lm-deluge 0.0.7__tar.gz → 0.0.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lm-deluge might be problematic.
- lm_deluge-0.0.9/LICENSE +7 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/PKG-INFO +6 -8
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/README.md +3 -7
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/pyproject.toml +1 -1
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/anthropic.py +23 -7
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/base.py +38 -11
- lm_deluge-0.0.9/src/lm_deluge/api_requests/bedrock.py +283 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/common.py +2 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/mistral.py +2 -2
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/openai.py +37 -6
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/client.py +18 -3
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/models.py +114 -24
- lm_deluge-0.0.9/src/lm_deluge/prompt.py +693 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/tool.py +16 -35
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/PKG-INFO +6 -8
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/SOURCES.txt +4 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_all_models.py +4 -0
- lm_deluge-0.0.9/tests/test_bedrock_models.py +252 -0
- lm_deluge-0.0.9/tests/test_tool_calls.py +401 -0
- lm_deluge-0.0.7/src/lm_deluge/prompt.py +0 -357
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/setup.cfg +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/__init__.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/__init__.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/bedrock.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/cohere.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/deepseek.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/mistral.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/vertex.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/cache.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/embed.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/errors.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/gemini_limits.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/image.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/llm_tools/__init__.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/llm_tools/extract.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/llm_tools/score.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/llm_tools/translate.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/rerank.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/sampling_params.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/tracker.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/util/json.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/util/logprobs.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/util/validation.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/util/xml.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/dependency_links.txt +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/requires.txt +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/top_level.txt +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_cache.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_image_models.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_image_utils.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_json_utils.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_sampling_params.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_translate.py +0 -0
- {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_xml_utils.py +0 -0
lm_deluge-0.0.9/LICENSE
ADDED
@@ -0,0 +1,7 @@
+Copyright 2025, Taylor AI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/PKG-INFO
@@ -1,10 +1,11 @@
 Metadata-Version: 2.4
 Name: lm_deluge
-Version: 0.0.7
+Version: 0.0.9
 Summary: Python utility for using LLM API models.
 Author-email: Benjamin Anderson <ben@trytaylor.ai>
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
+License-File: LICENSE
 Requires-Dist: python-dotenv
 Requires-Dist: json5
 Requires-Dist: PyYAML
@@ -22,10 +23,11 @@ Requires-Dist: pdf2image
 Requires-Dist: pillow
 Requires-Dist: fasttext-wheel
 Requires-Dist: fasttext-langdetect
+Dynamic: license-file
 
-#
+# lm-deluge
 
-`
+`lm-deluge` is a lightweight helper library for maxing out your rate limits with LLM providers. It provides the following:
 
 - **Unified client** – Send prompts to all relevant models with a single client.
 - **Massive concurrency with throttling** – Set `max_tokens_per_minute` and `max_requests_per_minute` and let it fly. The client will process as many requests as possible while respecting rate limits and retrying failures.
@@ -77,11 +79,7 @@ print(resp[0].completion)
 
 API calls can be customized in a few ways.
 
-1. **Sampling Parameters.** This determines things like structured outputs, maximum completion tokens, nucleus sampling, etc. Provide a custom `SamplingParams` to the `LLMClient` to set temperature, top_p, json_mode, max_new_tokens, and/or reasoning_effort.
-
-You can pass 1 `SamplingParams` to use for all models, or a list of `SamplingParams` that's the same length as the list of models. You can also pass many of these arguments directly to `LLMClient.basic` so you don't have to construct an entire `SamplingParams` object.
-
-
+1. **Sampling Parameters.** This determines things like structured outputs, maximum completion tokens, nucleus sampling, etc. Provide a custom `SamplingParams` to the `LLMClient` to set temperature, top_p, json_mode, max_new_tokens, and/or reasoning_effort. You can pass 1 `SamplingParams` to use for all models, or a list of `SamplingParams` that's the same length as the list of models. You can also pass many of these arguments directly to `LLMClient.basic` so you don't have to construct an entire `SamplingParams` object.
 2. **Arguments to LLMClient.** This is where you set request timeout, rate limits, model name(s), model weight(s) for distributing requests across models, retries, and caching.
 3. **Arguments to process_prompts.** Per-call, you can set verbosity, whether to display progress, and whether to return just completions (rather than the full APIResponse object).
 
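Note: the project description above boils down to one throttled client: point it at a model (or several), set `max_requests_per_minute` / `max_tokens_per_minute`, and let `process_prompts` fan the work out while respecting the limits and retrying failures. A rough sketch of that workflow under those assumptions; the import path, constructor keywords, the call shape of `process_prompts`, and the model id are guesses, not taken from the package source:

# Hypothetical usage sketch based on the README prose; everything except the
# names mentioned there (LLMClient, process_prompts, max_*_per_minute) is assumed.
from lm_deluge import LLMClient  # assumed export location

client = LLMClient(
    "gpt-4.1-mini",                   # illustrative model name
    max_requests_per_minute=5_000,    # throttle knobs named in the README
    max_tokens_per_minute=1_000_000,
)

prompts = [f"Summarize document {i}." for i in range(1_000)]
resps = client.process_prompts(prompts)  # call shape (sync vs. async) is assumed
print(resps[0].completion)               # mirrors the `resp[0].completion` example above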
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/README.md
@@ -1,6 +1,6 @@
-#
+# lm-deluge
 
-`
+`lm-deluge` is a lightweight helper library for maxing out your rate limits with LLM providers. It provides the following:
 
 - **Unified client** – Send prompts to all relevant models with a single client.
 - **Massive concurrency with throttling** – Set `max_tokens_per_minute` and `max_requests_per_minute` and let it fly. The client will process as many requests as possible while respecting rate limits and retrying failures.
@@ -52,11 +52,7 @@ print(resp[0].completion)
 
 API calls can be customized in a few ways.
 
-1. **Sampling Parameters.** This determines things like structured outputs, maximum completion tokens, nucleus sampling, etc. Provide a custom `SamplingParams` to the `LLMClient` to set temperature, top_p, json_mode, max_new_tokens, and/or reasoning_effort.
-
-You can pass 1 `SamplingParams` to use for all models, or a list of `SamplingParams` that's the same length as the list of models. You can also pass many of these arguments directly to `LLMClient.basic` so you don't have to construct an entire `SamplingParams` object.
-
-
+1. **Sampling Parameters.** This determines things like structured outputs, maximum completion tokens, nucleus sampling, etc. Provide a custom `SamplingParams` to the `LLMClient` to set temperature, top_p, json_mode, max_new_tokens, and/or reasoning_effort. You can pass 1 `SamplingParams` to use for all models, or a list of `SamplingParams` that's the same length as the list of models. You can also pass many of these arguments directly to `LLMClient.basic` so you don't have to construct an entire `SamplingParams` object.
 2. **Arguments to LLMClient.** This is where you set request timeout, rate limits, model name(s), model weight(s) for distributing requests across models, retries, and caching.
 3. **Arguments to process_prompts.** Per-call, you can set verbosity, whether to display progress, and whether to return just completions (rather than the full APIResponse object).
 
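Note: the consolidated list item above names the sampling fields (`temperature`, `top_p`, `json_mode`, `max_new_tokens`, `reasoning_effort`) and says a single `SamplingParams` or a per-model list can be supplied, or the fields can go straight to `LLMClient.basic`. A sketch under those assumptions; the keyword used to attach the params to the client and the model ids are guesses:

# Sketch of the customization described above; SamplingParams field names come
# from the README text, the sampling_params= keyword on the client is assumed.
from lm_deluge import LLMClient, SamplingParams  # assumed export location

shared = SamplingParams(temperature=0.2, top_p=0.95, max_new_tokens=512, json_mode=True)
client = LLMClient(["model-a", "model-b"], sampling_params=shared)     # one for all models

per_model = [
    SamplingParams(temperature=0.0),
    SamplingParams(temperature=0.7, reasoning_effort="low"),
]
client = LLMClient(["model-a", "model-b"], sampling_params=per_model)  # one per model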
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/anthropic.py
@@ -6,7 +6,7 @@ import warnings
 from tqdm import tqdm
 from typing import Callable
 
-from lm_deluge.prompt import Conversation
+from lm_deluge.prompt import Conversation, Message, Text, ToolCall, Thinking
 from .base import APIRequestBase, APIResponse
 
 from ..tracker import StatusTracker
@@ -34,6 +34,7 @@ class AnthropicRequest(APIRequestBase):
         # for retries
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
+        tools: list | None = None,
     ):
         super().__init__(
             task_id=task_id,
@@ -50,6 +51,7 @@ class AnthropicRequest(APIRequestBase):
             debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
+            tools=tools,
         )
         self.model = APIModel.from_registry(model_name)
         self.url = f"{self.model.api_base}/messages"
@@ -94,12 +96,14 @@ class AnthropicRequest(APIRequestBase):
         )
         if self.system_message is not None:
             self.request_json["system"] = self.system_message
+        if tools:
+            self.request_json["tools"] = [tool.dump_for("anthropic") for tool in tools]
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
         error_message = None
         thinking = None
-
+        content = None
         input_tokens = None
         output_tokens = None
         status_code = http_response.status
@@ -119,14 +123,26 @@ class AnthropicRequest(APIRequestBase):
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
-
-
+                response_content = data["content"]
+
+                # Parse response into Message with parts
+                parts = []
+                for item in response_content:
                     if item["type"] == "text":
-
+                        parts.append(Text(item["text"]))
                     elif item["type"] == "thinking":
                         thinking = item["thinking"]
+                        parts.append(Thinking(item["thinking"]))
                     elif item["type"] == "tool_use":
-
+                        parts.append(
+                            ToolCall(
+                                id=item["id"],
+                                name=item["name"],
+                                arguments=item["input"],
+                            )
+                        )
+
+                content = Message("assistant", parts)
                 input_tokens = data["usage"]["input_tokens"]
                 output_tokens = data["usage"]["output_tokens"]
             except Exception as e:
@@ -162,7 +178,7 @@ class AnthropicRequest(APIRequestBase):
             is_error=is_error,
             error_message=error_message,
             prompt=self.prompt,
-
+            content=content,
             thinking=thinking,
             model_internal=self.model_name,
             sampling_params=self.sampling_params,
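Note: the rewritten `handle_response` above now builds a structured assistant `Message` out of the Anthropic content blocks instead of keeping only a completion string. The same loop reduced to a standalone snippet; the sample `response_content` payload is invented, while the part constructors are exactly the ones used in the hunk:

from lm_deluge.prompt import Message, Text, Thinking, ToolCall

# Invented example of the Anthropic-style "content" array the hunk iterates over.
response_content = [
    {"type": "thinking", "thinking": "User wants the weather in Paris."},
    {"type": "text", "text": "Let me check that."},
    {"type": "tool_use", "id": "toolu_123", "name": "get_weather", "input": {"city": "Paris"}},
]

parts = []
for item in response_content:
    if item["type"] == "text":
        parts.append(Text(item["text"]))
    elif item["type"] == "thinking":
        parts.append(Thinking(item["thinking"]))
    elif item["type"] == "tool_use":
        parts.append(ToolCall(id=item["id"], name=item["name"], arguments=item["input"]))

content = Message("assistant", parts)  # same constructor call as in the hunk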
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/base.py
@@ -7,7 +7,7 @@ from dataclasses import dataclass
 from abc import ABC, abstractmethod
 from typing import Callable
 
-from lm_deluge.prompt import Conversation
+from lm_deluge.prompt import Conversation, Message
 
 from ..tracker import StatusTracker
 from ..sampling_params import SamplingParams
@@ -30,10 +30,12 @@ class APIResponse:
     error_message: str | None
 
     # completion information
-    completion: str | None
     input_tokens: int | None
     output_tokens: int | None
 
+    # response content - structured format
+    content: Message | None = None
+
     # optional or calculated automatically
     thinking: str | None = None  # if model shows thinking tokens
     model_external: str | None = None  # the model tag used by the API
@@ -47,6 +49,13 @@ class APIResponse:
     # set to true if should NOT retry with the same model (unrecoverable error)
     give_up_if_no_other_models: bool | None = False
 
+    @property
+    def completion(self) -> str | None:
+        """Backward compatibility: extract text from content Message."""
+        if self.content is not None:
+            return self.content.completion
+        return None
+
     def __post_init__(self):
         # calculate cost & get external model name
         self.id = int(self.id)
@@ -63,7 +72,7 @@ class APIResponse:
                 self.input_tokens * api_model.input_cost / 1e6
                 + self.output_tokens * api_model.output_cost / 1e6
             )
-        elif self.completion is not None:
+        elif self.content is not None and self.completion is not None:
            print(
                 f"Warning: Completion provided without token counts for model {self.model_internal}."
             )
@@ -79,7 +88,8 @@ class APIResponse:
             "status_code": self.status_code,
             "is_error": self.is_error,
             "error_message": self.error_message,
-            "completion": self.completion,
+            "completion": self.completion,  # computed property
+            "content": self.content.to_log() if self.content else None,
             "input_tokens": self.input_tokens,
             "output_tokens": self.output_tokens,
             "finish_reason": self.finish_reason,
@@ -88,11 +98,18 @@ class APIResponse:
 
     @classmethod
     def from_dict(cls, data: dict):
+        # Handle backward compatibility for content/completion
+        content = None
+        if "content" in data and data["content"] is not None:
+            # Reconstruct message from log format
+            content = Message.from_log(data["content"])
+        elif "completion" in data and data["completion"] is not None:
+            # Backward compatibility: create a Message with just text
+            content = Message.ai(data["completion"])
+
         return cls(
             id=data.get("id", random.randint(0, 1_000_000_000)),
             model_internal=data["model_internal"],
-            model_external=data["model_external"],
-            region=data["region"],
             prompt=Conversation.from_log(data["prompt"]),
             sampling_params=SamplingParams(**data["sampling_params"]),
             status_code=data["status_code"],
@@ -100,9 +117,14 @@ class APIResponse:
             error_message=data["error_message"],
             input_tokens=data["input_tokens"],
             output_tokens=data["output_tokens"],
-
-
-
+            content=content,
+            thinking=data.get("thinking"),
+            model_external=data.get("model_external"),
+            region=data.get("region"),
+            logprobs=data.get("logprobs"),
+            finish_reason=data.get("finish_reason"),
+            cost=data.get("cost"),
+            cache_hit=data.get("cache_hit", False),
         )
 
     def write_to_file(self, filename):
@@ -145,6 +167,7 @@ class APIRequestBase(ABC):
         debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
+        tools: list | None = None,
     ):
         if all_model_names is None:
             raise ValueError("all_model_names must be provided.")
@@ -166,6 +189,7 @@ class APIRequestBase(ABC):
         self.debug = debug
         self.all_model_names = all_model_names
         self.all_sampling_params = all_sampling_params
+        self.tools = tools
         self.result = []  # list of APIResponse objects from each attempt
 
         # these should be set in the __init__ of the subclass
@@ -255,6 +279,7 @@ class APIRequestBase(ABC):
             callback=self.callback,
             all_model_names=self.all_model_names,
             all_sampling_params=self.all_sampling_params,
+            tools=self.tools,
         )
         # PROBLEM: new request is never put into results array, so we can't get the result.
         self.retry_queue.put_nowait(new_request)
@@ -297,7 +322,7 @@ class APIRequestBase(ABC):
                 status_code=None,
                 is_error=True,
                 error_message="Request timed out (terminated by client).",
-
+                content=None,
                 input_tokens=None,
                 output_tokens=None,
             )
@@ -315,7 +340,7 @@ class APIRequestBase(ABC):
                 status_code=None,
                 is_error=True,
                 error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
-
+                content=None,
                 input_tokens=None,
                 output_tokens=None,
             )
@@ -344,6 +369,7 @@ def create_api_request(
     callback: Callable | None = None,
     all_model_names: list[str] | None = None,
     all_sampling_params: list[SamplingParams] | None = None,
+    tools: list | None = None,
 ) -> APIRequestBase:
     from .common import CLASSES  # circular import so made it lazy, does this work?
 
@@ -368,5 +394,6 @@ def create_api_request(
         callback=callback,
         all_model_names=all_model_names,
         all_sampling_params=all_sampling_params,
+        tools=tools,
         **kwargs,
     )
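Note: two things happen to `APIResponse` here. The stored `completion` string becomes a structured `content: Message | None` field with `completion` kept alive as a read-only property, and `from_dict` switches to `.get()` lookups so logs written by 0.0.7 (which may lack `content`, `model_external`, or `region` keys) still deserialize. The same backward-compatibility pattern on a self-contained stand-in class; `Record` is a placeholder, not the package's dataclass:

from dataclasses import dataclass

@dataclass
class Record:                      # placeholder stand-in for APIResponse
    content: str | None = None     # new structured field

    @property
    def completion(self) -> str | None:
        # old callers keep reading .completion after the stored field is gone
        return self.content

    @classmethod
    def from_dict(cls, data: dict) -> "Record":
        # .get() tolerates logs written before the new field existed
        content = data.get("content")
        if content is None:
            content = data.get("completion")   # fall back to the legacy key
        return cls(content=content)

assert Record.from_dict({"completion": "hi"}).completion == "hi"   # old log shape
assert Record.from_dict({"content": "hi"}).completion == "hi"      # new log shape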
lm_deluge-0.0.9/src/lm_deluge/api_requests/bedrock.py
ADDED
@@ -0,0 +1,283 @@
+import asyncio
+import json
+import os
+from aiohttp import ClientResponse
+from tqdm import tqdm
+from typing import Callable
+
+try:
+    from requests_aws4auth import AWS4Auth
+except ImportError:
+    raise ImportError(
+        "aws4auth is required for bedrock support. Install with: pip install requests-aws4auth"
+    )
+
+from lm_deluge.prompt import Conversation, Message, Text, ToolCall, Thinking
+from .base import APIRequestBase, APIResponse
+
+from ..tracker import StatusTracker
+from ..sampling_params import SamplingParams
+from ..models import APIModel
+
+
+class BedrockRequest(APIRequestBase):
+    def __init__(
+        self,
+        task_id: int,
+        model_name: str,
+        prompt: Conversation,
+        attempts_left: int,
+        status_tracker: StatusTracker,
+        retry_queue: asyncio.Queue,
+        results_arr: list,
+        request_timeout: int = 30,
+        sampling_params: SamplingParams = SamplingParams(),
+        pbar: tqdm | None = None,
+        callback: Callable | None = None,
+        debug: bool = False,
+        all_model_names: list[str] | None = None,
+        all_sampling_params: list[SamplingParams] | None = None,
+        tools: list | None = None,
+    ):
+        super().__init__(
+            task_id=task_id,
+            model_name=model_name,
+            prompt=prompt,
+            attempts_left=attempts_left,
+            status_tracker=status_tracker,
+            retry_queue=retry_queue,
+            results_arr=results_arr,
+            request_timeout=request_timeout,
+            sampling_params=sampling_params,
+            pbar=pbar,
+            callback=callback,
+            debug=debug,
+            all_model_names=all_model_names,
+            all_sampling_params=all_sampling_params,
+            tools=tools,
+        )
+
+        self.model = APIModel.from_registry(model_name)
+
+        # Get AWS credentials from environment
+        self.access_key = os.getenv("AWS_ACCESS_KEY_ID")
+        self.secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
+        self.session_token = os.getenv("AWS_SESSION_TOKEN")
+
+        if not self.access_key or not self.secret_key:
+            raise ValueError(
+                "AWS credentials not found. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
+            )
+
+        # Determine region - use us-west-2 for cross-region inference models
+        if self.model.name.startswith("us.anthropic."):
+            # Cross-region inference profiles should use us-west-2
+            self.region = "us-west-2"
+        else:
+            # Direct model IDs can use default region
+            self.region = getattr(self.model, "region", "us-east-1")
+            if hasattr(self.model, "regions") and self.model.regions:
+                if isinstance(self.model.regions, list):
+                    self.region = self.model.regions[0]
+                elif isinstance(self.model.regions, dict):
+                    self.region = list(self.model.regions.keys())[0]
+
+        # Construct the endpoint URL
+        self.service = "bedrock"  # Service name for signing is 'bedrock' even though endpoint is bedrock-runtime
+        self.url = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model.name}/invoke"
+
+        # Convert prompt to Anthropic format for bedrock
+        self.system_message, messages = prompt.to_anthropic()
+
+        # Prepare request body in Anthropic's bedrock format
+        self.request_json = {
+            "anthropic_version": "bedrock-2023-05-31",
+            "max_tokens": sampling_params.max_new_tokens,
+            "temperature": sampling_params.temperature,
+            "top_p": sampling_params.top_p,
+            "messages": messages,
+        }
+
+        if self.system_message is not None:
+            self.request_json["system"] = self.system_message
+
+        if tools:
+            self.request_json["tools"] = [tool.dump_for("anthropic") for tool in tools]
+
+        # Setup AWS4Auth for signing
+        self.auth = AWS4Auth(
+            self.access_key,
+            self.secret_key,
+            self.region,
+            self.service,
+            session_token=self.session_token,
+        )
+
+        # Setup basic headers (AWS4Auth will add the Authorization header)
+        self.request_header = {
+            "Content-Type": "application/json",
+        }
+
+    async def call_api(self):
+        """Override call_api to handle AWS4Auth signing."""
+        try:
+            import aiohttp
+
+            self.status_tracker.total_requests += 1
+            timeout = aiohttp.ClientTimeout(total=self.request_timeout)
+
+            # Prepare the request data
+            payload = json.dumps(self.request_json, separators=(",", ":")).encode(
+                "utf-8"
+            )
+
+            # Create a fake requests.PreparedRequest object for AWS4Auth to sign
+            import requests
+
+            fake_request = requests.Request(
+                method="POST",
+                url=self.url,
+                data=payload,
+                headers=self.request_header.copy(),
+            )
+
+            # Prepare the request so AWS4Auth can sign it properly
+            prepared_request = fake_request.prepare()
+
+            # Let AWS4Auth sign the prepared request
+            signed_request = self.auth(prepared_request)
+
+            # Extract the signed headers
+            signed_headers = dict(signed_request.headers)
+
+            async with aiohttp.ClientSession(timeout=timeout) as session:
+                async with session.post(
+                    url=self.url,
+                    headers=signed_headers,
+                    data=payload,
+                ) as http_response:
+                    response: APIResponse = await self.handle_response(http_response)
+
+            self.result.append(response)
+            if response.is_error:
+                self.handle_error(
+                    create_new_request=response.retry_with_different_model or False,
+                    give_up_if_no_other_models=response.give_up_if_no_other_models
+                    or False,
+                )
+            else:
+                self.handle_success(response)
+
+        except asyncio.TimeoutError:
+            self.result.append(
+                APIResponse(
+                    id=self.task_id,
+                    model_internal=self.model_name,
+                    prompt=self.prompt,
+                    sampling_params=self.sampling_params,
+                    status_code=None,
+                    is_error=True,
+                    error_message="Request timed out (terminated by client).",
+                    content=None,
+                    input_tokens=None,
+                    output_tokens=None,
+                )
+            )
+            self.handle_error(create_new_request=False)
+
+        except Exception as e:
+            from ..errors import raise_if_modal_exception
+
+            raise_if_modal_exception(e)
+            self.result.append(
+                APIResponse(
+                    id=self.task_id,
+                    model_internal=self.model_name,
+                    prompt=self.prompt,
+                    sampling_params=self.sampling_params,
+                    status_code=None,
+                    is_error=True,
+                    error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
+                    content=None,
+                    input_tokens=None,
+                    output_tokens=None,
+                )
+            )
+            self.handle_error(create_new_request=False)
+
+    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+        is_error = False
+        error_message = None
+        thinking = None
+        content = None
+        input_tokens = None
+        output_tokens = None
+        status_code = http_response.status
+        mimetype = http_response.headers.get("Content-Type", None)
+
+        if status_code >= 200 and status_code < 300:
+            try:
+                data = await http_response.json()
+                response_content = data["content"]
+
+                # Parse response into Message with parts
+                parts = []
+                for item in response_content:
+                    if item["type"] == "text":
+                        parts.append(Text(item["text"]))
+                    elif item["type"] == "thinking":
+                        thinking = item["thinking"]
+                        parts.append(Thinking(item["thinking"]))
+                    elif item["type"] == "tool_use":
+                        parts.append(
+                            ToolCall(
+                                id=item["id"],
+                                name=item["name"],
+                                arguments=item["input"],
+                            )
+                        )
+
+                content = Message("assistant", parts)
+                input_tokens = data["usage"]["input_tokens"]
+                output_tokens = data["usage"]["output_tokens"]
+            except Exception as e:
+                is_error = True
+                error_message = (
+                    f"Error calling .json() on response w/ status {status_code}: {e}"
+                )
+        elif mimetype and "json" in mimetype.lower():
+            is_error = True
+            data = await http_response.json()
+            error_message = json.dumps(data)
+        else:
+            is_error = True
+            text = await http_response.text()
+            error_message = text
+
+        # Handle special kinds of errors
+        if is_error and error_message is not None:
+            if (
+                "rate limit" in error_message.lower()
+                or "throttling" in error_message.lower()
+                or status_code == 429
+            ):
+                error_message += " (Rate limit error, triggering cooldown.)"
+                self.status_tracker.rate_limit_exceeded()
+            if "context length" in error_message or "too long" in error_message:
+                error_message += " (Context length exceeded, set retries to 0.)"
+                self.attempts_left = 0
+
+        return APIResponse(
+            id=self.task_id,
+            status_code=status_code,
+            is_error=is_error,
+            error_message=error_message,
+            prompt=self.prompt,
+            content=content,
+            thinking=thinking,
+            model_internal=self.model_name,
+            region=self.region,
+            sampling_params=self.sampling_params,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+        )
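Note: `BedrockRequest` avoids boto3 entirely. It builds the same Anthropic-style body as the direct Anthropic client, signs a throwaway `requests.Request` with `requests_aws4auth.AWS4Auth` to obtain SigV4 headers, and then replays those headers on an aiohttp POST. The signing step in isolation; the model id is a placeholder and credentials are read from the environment, as in the hunk:

# Standalone sketch of the SigV4 signing trick used by BedrockRequest above.
import json
import os

import requests
from requests_aws4auth import AWS4Auth

region = "us-west-2"
model_id = "<bedrock-model-id>"  # placeholder, e.g. a us.anthropic.* inference profile
url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{model_id}/invoke"

payload = json.dumps(
    {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    separators=(",", ":"),
).encode("utf-8")

auth = AWS4Auth(
    os.environ["AWS_ACCESS_KEY_ID"],
    os.environ["AWS_SECRET_ACCESS_KEY"],
    region,
    "bedrock",  # signing service name, as in the hunk
    session_token=os.getenv("AWS_SESSION_TOKEN"),
)

# Sign a throwaway requests.Request, then reuse its headers with any async client.
prepared = requests.Request(
    "POST", url, data=payload, headers={"Content-Type": "application/json"}
).prepare()
signed_headers = dict(auth(prepared).headers)
# signed_headers now carries Authorization, X-Amz-Date, etc. for the aiohttp POST.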
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/common.py
@@ -1,9 +1,11 @@
 from .openai import OpenAIRequest
 from .anthropic import AnthropicRequest
 from .mistral import MistralRequest
+from .bedrock import BedrockRequest
 
 CLASSES = {
     "openai": OpenAIRequest,
     "anthropic": AnthropicRequest,
     "mistral": MistralRequest,
+    "bedrock": BedrockRequest,
 }
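Note: common.py stays a plain provider-key → request-class dict, which `create_api_request` in base.py imports lazily to dodge the circular import its own comment worries about. A hypothetical lookup helper to show how the new "bedrock" entry is reached; the helper itself and the idea that the key comes from the model registry entry are assumptions:

from lm_deluge.api_requests.common import CLASSES

def request_class_for(provider_key: str):
    # provider_key would be "openai", "anthropic", "mistral", or now "bedrock"
    try:
        return CLASSES[provider_key]
    except KeyError:
        raise ValueError(f"no request implementation registered for {provider_key!r}")

print(request_class_for("bedrock").__name__)  # BedrockRequest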
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/mistral.py
@@ -7,7 +7,7 @@ from tqdm.auto import tqdm
 from typing import Callable
 
 from .base import APIRequestBase, APIResponse
-from ..prompt import Conversation
+from ..prompt import Conversation, Message
 from ..tracker import StatusTracker
 from ..sampling_params import SamplingParams
 from ..models import APIModel
@@ -130,7 +130,7 @@ class MistralRequest(APIRequestBase):
             error_message=error_message,
             prompt=self.prompt,
             logprobs=logprobs,
-
+            content=Message.ai(completion),
             model_internal=self.model_name,
             sampling_params=self.sampling_params,
             input_tokens=input_tokens,