lm-deluge 0.0.3__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Potentially problematic release.
This version of lm-deluge might be problematic.
- lm_deluge/__init__.py +6 -0
- lm_deluge/api_requests/__init__.py +3 -0
- lm_deluge/api_requests/anthropic.py +177 -0
- lm_deluge/api_requests/base.py +375 -0
- lm_deluge/api_requests/cohere.py +138 -0
- lm_deluge/api_requests/common.py +18 -0
- lm_deluge/api_requests/deprecated/bedrock.py +288 -0
- lm_deluge/api_requests/deprecated/deepseek.py +118 -0
- lm_deluge/api_requests/deprecated/mistral.py +120 -0
- lm_deluge/api_requests/google.py +0 -0
- lm_deluge/api_requests/openai.py +145 -0
- lm_deluge/api_requests/vertex.py +365 -0
- lm_deluge/cache.py +144 -0
- lm_deluge/client.py +760 -0
- lm_deluge/embed.py +392 -0
- lm_deluge/errors.py +8 -0
- lm_deluge/gemini_limits.py +65 -0
- lm_deluge/image.py +200 -0
- lm_deluge/llm_tools/__init__.py +11 -0
- lm_deluge/llm_tools/extract.py +111 -0
- lm_deluge/llm_tools/score.py +71 -0
- lm_deluge/llm_tools/translate.py +44 -0
- lm_deluge/models.py +957 -0
- lm_deluge/prompt.py +355 -0
- lm_deluge/rerank.py +338 -0
- lm_deluge/sampling_params.py +25 -0
- lm_deluge/tool.py +106 -0
- lm_deluge/tracker.py +12 -0
- lm_deluge/util/json.py +167 -0
- lm_deluge/util/logprobs.py +446 -0
- lm_deluge/util/pdf.py +45 -0
- lm_deluge/util/validation.py +46 -0
- lm_deluge/util/xml.py +291 -0
- lm_deluge-0.0.3.dist-info/METADATA +127 -0
- lm_deluge-0.0.3.dist-info/RECORD +37 -0
- lm_deluge-0.0.3.dist-info/WHEEL +5 -0
- lm_deluge-0.0.3.dist-info/top_level.txt +1 -0
lm_deluge/api_requests/cohere.py
@@ -0,0 +1,138 @@
+# https://docs.cohere.com/reference/chat
+# https://cohere.com/pricing
+import asyncio
+from aiohttp import ClientResponse
+import json
+import os
+import time
+from tqdm import tqdm
+from typing import Optional, Callable
+from lm_deluge.prompt import Conversation
+from .base import APIRequestBase, APIResponse
+
+from ..tracker import StatusTracker
+from ..sampling_params import SamplingParams
+from ..models import APIModel
+
+
+class CohereRequest(APIRequestBase):
+    def __init__(
+        self,
+        task_id: int,
+        # should always be 'role', 'content' keys.
+        # internal logic should handle translating to specific API format
+        model_name: str,  # must correspond to registry
+        prompt: Conversation,
+        attempts_left: int,
+        status_tracker: StatusTracker,
+        results_arr: list,
+        retry_queue: asyncio.Queue,
+        request_timeout: int = 30,
+        sampling_params: SamplingParams = SamplingParams(),
+        pbar: Optional[tqdm] = None,
+        callback: Optional[Callable] = None,
+        debug: bool = False,
+        all_model_names: list[str] | None = None,
+        all_sampling_params: list[SamplingParams] | None = None,
+    ):
+        super().__init__(
+            task_id=task_id,
+            model_name=model_name,
+            prompt=prompt,
+            attempts_left=attempts_left,
+            status_tracker=status_tracker,
+            retry_queue=retry_queue,
+            results_arr=results_arr,
+            request_timeout=request_timeout,
+            sampling_params=sampling_params,
+            pbar=pbar,
+            callback=callback,
+            debug=debug,
+            all_model_names=all_model_names,
+            all_sampling_params=all_sampling_params,
+        )
+        self.system_message = None
+        self.last_user_message = None
+
+        self.model = APIModel.from_registry(model_name)
+        self.url = f"{self.model.api_base}/chat"
+        self.system_message, chat_history, last_user_message = prompt.to_cohere()
+
+        self.request_header = {
+            "Authorization": f"bearer {os.getenv(self.model.api_key_env_var)}",
+            "content-type": "application/json",
+            "accept": "application/json",
+        }
+
+        self.request_json = {
+            "model": self.model.name,
+            "chat_history": chat_history,
+            "message": last_user_message,
+            "temperature": sampling_params.temperature,
+            "top_p": sampling_params.top_p,
+            "max_tokens": sampling_params.max_new_tokens,
+        }
+
+        if self.system_message:
+            self.request_json["preamble"] = self.system_message
+
+    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+        is_error = False
+        error_message = None
+        completion = None
+        input_tokens = None
+        output_tokens = None
+        status_code = http_response.status
+        mimetype = http_response.headers.get("Content-Type", None)
+        if status_code >= 200 and status_code < 300:
+            try:
+                data = await http_response.json()
+            except Exception:
+                data = None
+                is_error = True
+                error_message = (
+                    f"Error calling .json() on response w/ status {status_code}"
+                )
+            if not is_error and isinstance(data, dict):
+                try:
+                    completion = data["text"]
+                    input_tokens = data["meta"]["billed_units"]["input_tokens"]
+                    output_tokens = data["meta"]["billed_units"]["output_tokens"]
+                except Exception:
+                    is_error = True
+                    error_message = f"Error getting 'text' or 'meta' from {self.model.name} response."
+        elif mimetype is not None and "json" in mimetype.lower():
+            is_error = True  # expected status is 200, otherwise it's an error
+            data = await http_response.json()
+            error_message = json.dumps(data)
+
+        else:
+            is_error = True
+            text = await http_response.text()
+            error_message = text
+
+        # handle special kinds of errors. TODO: make sure these are correct for anthropic
+        if is_error and error_message is not None:
+            if (
+                "rate limit" in error_message.lower()
+                or "overloaded" in error_message.lower()
+            ):
+                error_message += " (Rate limit error, triggering cooldown.)"
+                self.status_tracker.time_of_last_rate_limit_error = time.time()
+                self.status_tracker.num_rate_limit_errors += 1
+            if "context length" in error_message:
+                error_message += " (Context length exceeded, set retries to 0.)"
+                self.attempts_left = 0
+
+        return APIResponse(
+            id=self.task_id,
+            status_code=status_code,
+            is_error=is_error,
+            error_message=error_message,
+            prompt=self.prompt,
+            completion=completion,
+            model_internal=self.model_name,
+            sampling_params=self.sampling_params,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+        )
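For reference, a minimal sketch (not part of the package) of the successful-response shape that `handle_response` above expects from Cohere's `/chat` endpoint. The field names mirror the parsing in cohere.py; the sample values are made up.

```python
# Hypothetical payload, shaped like what cohere.py's handle_response reads on success.
sample = {
    "text": "Hello!",
    "meta": {"billed_units": {"input_tokens": 12, "output_tokens": 3}},
}

completion = sample["text"]
input_tokens = sample["meta"]["billed_units"]["input_tokens"]
output_tokens = sample["meta"]["billed_units"]["output_tokens"]
print(completion, input_tokens, output_tokens)  # Hello! 12 3
```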
lm_deluge/api_requests/common.py
@@ -0,0 +1,18 @@
+# from .vertex import VertexAnthropicRequest, GeminiRequest
+# from .bedrock import BedrockAnthropicRequest, MistralBedrockRequest
+# from .deepseek import DeepseekRequest
+from .openai import OpenAIRequest
+from .cohere import CohereRequest
+from .anthropic import AnthropicRequest
+
+CLASSES = {
+    "openai": OpenAIRequest,
+    # "deepseek": DeepseekRequest,
+    "anthropic": AnthropicRequest,
+    # "vertex_anthropic": VertexAnthropicRequest,
+    # "vertex_gemini": GeminiRequest,
+    "cohere": CohereRequest,
+    # "bedrock_anthropic": BedrockAnthropicRequest,
+    # "bedrock_mistral": MistralBedrockRequest,
+    # "mistral": MistralRequest,
+}
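A brief usage sketch (not from the package) of how the `CLASSES` mapping above might be used to dispatch to a request class by provider key. The module path is inferred from the file list, and the key string and constructor arguments are assumptions based on the registry keys and the `CohereRequest` signature shown earlier.

```python
# Hypothetical dispatch: look up a request class by its provider key.
from lm_deluge.api_requests.common import CLASSES

request_cls = CLASSES["cohere"]  # -> CohereRequest
print(request_cls.__name__)

# A caller would then construct it with the shared keyword arguments, e.g.
# request_cls(task_id=0, model_name="...", prompt=..., attempts_left=5, ...)
```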
lm_deluge/api_requests/deprecated/bedrock.py
@@ -0,0 +1,288 @@
+# import asyncio
+# import requests
+# from requests.structures import CaseInsensitiveDict
+# from requests_aws4auth import AWS4Auth
+# from aiohttp import ClientResponse
+# import json
+# import os
+# import time
+# from tqdm import tqdm
+# from typing import Optional, Callable
+# from lm_deluge.prompt import Conversation
+# from .base import APIRequestBase, APIResponse
+# from ..tracker import StatusTracker
+# from ..sampling_params import SamplingParams
+# from ..models import APIModel
+
+
+# def get_aws_headers(
+#     access_key_id: str,
+#     secret_access_key: str,
+#     region: str,
+#     url: str,
+#     request_json: dict,
+#     service: str = "bedrock",
+# ):
+#     auth = AWS4Auth(
+#         access_key_id,
+#         secret_access_key,
+#         region,
+#         service,
+#     )
+
+#     headers = CaseInsensitiveDict()
+#     mock_request = requests.Request(
+#         method="POST", url=url, headers=headers, json=request_json
+#     ).prepare()
+#     auth(mock_request)
+#     # print("headers:", mock_request.headers)
+#     return mock_request.headers
+
+
+# class BedrockAnthropicRequest(APIRequestBase):
+#     """
+#     For Claude on Bedrock, you'll also have to set the PROJECT_ID environment variable.
+#     """
+
+#     def __init__(
+#         self,
+#         task_id: int,
+#         model_name: str,  # must correspond to registry
+#         prompt: Conversation,
+#         attempts_left: int,
+#         results_arr: list,
+#         status_tracker: StatusTracker,
+#         retry_queue: asyncio.Queue,
+#         request_timeout: int = 30,
+#         sampling_params: SamplingParams = SamplingParams(),
+#         pbar: Optional[tqdm] = None,
+#         callback: Optional[Callable] = None,
+#         debug: bool = False,
+#         all_model_names: list[str] | None = None,
+#         all_sampling_params: list[SamplingParams] | None = None,
+#     ):
+#         super().__init__(
+#             task_id=task_id,
+#             model_name=model_name,
+#             prompt=prompt,
+#             attempts_left=attempts_left,
+#             status_tracker=status_tracker,
+#             retry_queue=retry_queue,
+#             results_arr=results_arr,
+#             request_timeout=request_timeout,
+#             sampling_params=sampling_params,
+#             pbar=pbar,
+#             callback=callback,
+#             debug=debug,
+#             all_model_names=all_model_names,
+#             all_sampling_params=all_sampling_params,
+#         )
+#         self.model = APIModel.from_registry(model_name)
+#         region = self.model.sample_region()
+#         assert region is not None, "unable to sample a region"
+#         self.url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{self.model.name}/invoke"
+#         self.system_message, messages = prompt.to_anthropic()
+
+#         self.request_json = {
+#             "anthropic_version": "bedrock-2023-05-31",
+#             "messages": messages,
+#             "temperature": self.sampling_params.temperature,
+#             "top_p": self.sampling_params.top_p,
+#             "max_tokens": self.sampling_params.max_new_tokens,
+#         }
+#         if self.system_message is not None:
+#             self.request_json["system"] = self.system_message
+
+#         self.request_header = dict(
+#             get_aws_headers(
+#                 access_key_id=os.getenv("AWS_ACCESS_KEY_ID", ""),
+#                 secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY", ""),
+#                 region=region,
+#                 url=self.url,
+#                 request_json=self.request_json,
+#             )
+#         )
+
+#     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+#         is_error = False
+#         error_message = None
+#         completion = None
+#         input_tokens = None
+#         output_tokens = None
+#         status_code = http_response.status
+#         mimetype = http_response.headers.get("Content-Type", None)
+#         if status_code >= 200 and status_code < 300:
+#             try:
+#                 data = await http_response.json()
+#                 completion = data["content"][0]["text"]
+#                 input_tokens = data["usage"]["input_tokens"]
+#                 output_tokens = data["usage"]["output_tokens"]
+#             except Exception as e:
+#                 is_error = True
+#                 error_message = (
+#                     f"Error calling .json() on response w/ status {status_code}: {e}"
+#                 )
+#         elif "json" in mimetype.lower() if mimetype else "":
+#             is_error = True  # expected status is 200, otherwise it's an error
+#             data = await http_response.json()
+#             error_message = json.dumps(data)
+
+#         else:
+#             is_error = True
+#             text = await http_response.text()
+#             error_message = text
+
+#         # handle special kinds of errors. TODO: make sure these are correct for anthropic
+#         if is_error and error_message is not None:
+#             if (
+#                 "rate limit" in error_message.lower()
+#                 or "overloaded" in error_message.lower()
+#             ):
+#                 error_message += " (Rate limit error, triggering cooldown.)"
+#                 self.status_tracker.time_of_last_rate_limit_error = time.time()
+#                 self.status_tracker.num_rate_limit_errors += 1
+#             if "context length" in error_message:
+#                 error_message += " (Context length exceeded, set retries to 0.)"
+#                 self.attempts_left = 0
+
+#         return APIResponse(
+#             id=self.task_id,
+#             status_code=status_code,
+#             is_error=is_error,
+#             error_message=error_message,
+#             prompt=self.prompt,
+#             completion=completion,
+#             model_internal=self.model_name,
+#             sampling_params=self.sampling_params,
+#             input_tokens=input_tokens,
+#             output_tokens=output_tokens,
+#         )
+
+
+# class MistralBedrockRequest(APIRequestBase):
+#     """
+#     Documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html#model-parameters-mistral-request-response
+#     """
+
+#     def __init__(
+#         self,
+#         task_id: int,
+#         model_name: str,  # must correspond to registry
+#         prompt: Conversation,
+#         attempts_left: int,
+#         status_tracker: StatusTracker,
+#         retry_queue: asyncio.Queue,
+#         results_arr: list,
+#         request_timeout: int = 30,
+#         sampling_params: SamplingParams = SamplingParams(),
+#         pbar: Optional[tqdm] = None,
+#         callback: Optional[Callable] = None,
+#         debug: bool = False,
+#         all_model_names: list[str] | None = None,
+#         all_sampling_params: list[SamplingParams] | None = None,
+#     ):
+#         super().__init__(
+#             task_id=task_id,
+#             model_name=model_name,
+#             prompt=prompt,
+#             attempts_left=attempts_left,
+#             status_tracker=status_tracker,
+#             retry_queue=retry_queue,
+#             results_arr=results_arr,
+#             request_timeout=request_timeout,
+#             sampling_params=sampling_params,
+#             pbar=pbar,
+#             callback=callback,
+#             debug=debug,
+#             all_model_names=all_model_names,
+#             all_sampling_params=all_sampling_params,
+#         )
+#         self.model = APIModel.from_registry(model_name)
+#         self.region = self.model.sample_region()
+#         assert self.region is not None, "unable to select a region"
+#         self.url = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model.name}/invoke"
+#         self.system_message = None
+#         self.request_json = {
+#             "prompt": prompt.to_mistral_bedrock(),
+#             "max_tokens": self.sampling_params.max_new_tokens,
+#             "temperature": self.sampling_params.temperature,
+#             "top_p": self.sampling_params.top_p,
+#         }
+#         self.request_header = dict(
+#             get_aws_headers(
+#                 access_key_id=os.getenv("AWS_ACCESS_KEY_ID", ""),
+#                 secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY", ""),
+#                 region=self.region,
+#                 url=self.url,
+#                 request_json=self.request_json,
+#             )
+#         )
+
+#     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+#         is_error = False
+#         error_message: str | None = None
+#         completion = None
+#         input_tokens = None
+#         output_tokens = None
+#         status_code = http_response.status
+#         mimetype = http_response.headers.get("Content-Type", None)
+#         if status_code >= 200 and status_code < 300:
+#             try:
+#                 data = await http_response.json()
+#                 completion = data["outputs"][0]["text"]
+#                 input_tokens = len(self.request_json["prompt"]) // 4  # approximate
+#                 output_tokens = len(completion) // 4  # approximate
+#             except Exception as e:
+#                 is_error = True
+#                 error_message = (
+#                     f"Error calling .json() on response w/ status {status_code}: {e}"
+#                 )
+#         elif "json" in (mimetype.lower() if mimetype else ""):
+#             is_error = True  # expected status is 200, otherwise it's an error
+#             data = await http_response.json()
+#             error_message = json.dumps(data)
+
+#         else:
+#             is_error = True
+#             text = await http_response.text()
+#             error_message = (
+#                 text if isinstance(text, str) else (str(text) if text else "")
+#             )
+
+#         # TODO: Handle rate-limit errors
+#         # TODO: in the future, instead of slowing down, switch models?
+#         if status_code == 429:
+#             assert isinstance(error_message, str)
+#             error_message += " (Rate limit error, triggering cooldown.)"
+#             self.status_tracker.time_of_last_rate_limit_error = time.time()
+#             self.status_tracker.num_rate_limit_errors += 1
+
+#         # if error, change the region
+#         old_region = self.region
+#         if is_error:
+#             self.region = self.model.sample_region()
+#             assert self.region is not None, "could not select a region"
+#             self.url = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model.name}/invoke"
+#             self.request_header = dict(
+#                 get_aws_headers(
+#                     access_key_id=os.getenv("AWS_ACCESS_KEY_ID", ""),
+#                     secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY", ""),
+#                     region=self.region,
+#                     url=self.url,
+#                     request_json=self.request_json,
+#                 )
+#             )
+
+#         return APIResponse(
+#             id=self.task_id,
+#             status_code=status_code,
+#             is_error=is_error,
+#             error_message=error_message,
+#             prompt=self.prompt,
+#             completion=completion,
+#             model_internal=self.model_name,
+#             sampling_params=self.sampling_params,
+#             input_tokens=input_tokens,
+#             output_tokens=output_tokens,
+#             region=old_region,
+#         )
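A hedged sketch of the flow the commented-out Bedrock code above follows: sign an invoke request with requests_aws4auth, then send the same JSON body with the signed headers. This is not shipped code; the region, model id, and body below are illustrative only.

```python
# Sketch only: mirrors the deprecated Bedrock helper above, not part of the package.
import os
import requests
from requests.structures import CaseInsensitiveDict
from requests_aws4auth import AWS4Auth


def get_aws_headers(access_key_id, secret_access_key, region, url, request_json, service="bedrock"):
    # Same idea as the deprecated helper: sign a throwaway prepared request
    # and reuse its headers for the real call.
    auth = AWS4Auth(access_key_id, secret_access_key, region, service)
    mock = requests.Request(
        method="POST", url=url, headers=CaseInsensitiveDict(), json=request_json
    ).prepare()
    auth(mock)
    return mock.headers


region = "us-east-1"  # illustrative
model_id = "anthropic.claude-3-haiku-20240307-v1:0"  # illustrative model id
url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{model_id}/invoke"
body = {
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 64,
    "messages": [{"role": "user", "content": "Hello"}],
}
headers = get_aws_headers(
    os.getenv("AWS_ACCESS_KEY_ID", ""),
    os.getenv("AWS_SECRET_ACCESS_KEY", ""),
    region,
    url,
    body,
)
# resp = requests.post(url, json=body, headers=dict(headers))  # would perform the call
```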
lm_deluge/api_requests/deprecated/deepseek.py
@@ -0,0 +1,118 @@
+# import asyncio
+# from aiohttp import ClientResponse
+# import json
+# import os
+# import time
+# from tqdm import tqdm
+# from typing import Optional, Callable
+
+# from .base import APIRequestBase, APIResponse
+# from ..prompt import Prompt
+# from ..tracker import StatusTracker
+# from ..sampling_params import SamplingParams
+# from ..models import APIModel
+
+
+# class DeepseekRequest(APIRequestBase):
+#     def __init__(
+#         self,
+#         task_id: int,
+#         model_name: str,  # must correspond to registry
+#         prompt: Prompt,
+#         attempts_left: int,
+#         status_tracker: StatusTracker,
+#         retry_queue: asyncio.Queue,
+#         results_arr: list,
+#         request_timeout: int = 30,
+#         sampling_params: SamplingParams = SamplingParams(),
+#         pbar: Optional[tqdm] = None,
+#         callback: Optional[Callable] = None,
+#         debug: bool = False,
+#         all_model_names: list[str] = None,
+#         all_sampling_params: list[SamplingParams] = None,
+#     ):
+#         super().__init__(
+#             task_id=task_id,
+#             model_name=model_name,
+#             prompt=prompt,
+#             attempts_left=attempts_left,
+#             status_tracker=status_tracker,
+#             retry_queue=retry_queue,
+#             results_arr=results_arr,
+#             request_timeout=request_timeout,
+#             sampling_params=sampling_params,
+#             pbar=pbar,
+#             callback=callback,
+#             debug=debug,
+#             all_model_names=all_model_names,
+#             all_sampling_params=all_sampling_params,
+#         )
+#         self.model = APIModel.from_registry(model_name)
+#         self.url = f"{self.model.api_base}/chat/completions"
+#         self.request_header = {
+#             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
+#         }
+#         if prompt.image is not None:
+#             raise ValueError("Deepseek does not support images.")
+
+#         self.request_json = {
+#             "model": self.model.name,
+#             "messages": prompt.to_openai(),
+#             "temperature": sampling_params.temperature,
+#             "top_p": sampling_params.top_p,
+#             "max_tokens": sampling_params.max_new_tokens,
+#         }
+#         if sampling_params.json_mode and self.model.supports_json:
+#             self.request_json["response_format"] = {"type": "json_object"}
+
+#     async def handle_response(self, response: ClientResponse) -> APIResponse:
+#         is_error = False
+#         error_message = None
+#         completion = None
+#         input_tokens = None
+#         output_tokens = None
+#         status_code = response.status
+#         mimetype = response.headers.get("Content-Type", None)
+#         if status_code >= 200 and status_code < 300:
+#             try:
+#                 data = await response.json()
+#                 completion = data["choices"][0]["message"]["content"]
+#                 input_tokens = data["usage"]["prompt_tokens"]
+#                 output_tokens = data["usage"]["completion_tokens"]
+
+#             except Exception:
+#                 is_error = True
+#                 error_message = (
+#                     f"Error calling .json() on response w/ status {status_code}"
+#                 )
+#         elif "json" in mimetype.lower():
+#             is_error = True  # expected status is 200, otherwise it's an error
+#             data = await response.json()
+#             error_message = json.dumps(data)
+#         else:
+#             is_error = True
+#             text = await response.text()
+#             error_message = text
+
+#         # handle special kinds of errors
+#         if is_error and error_message is not None:
+#             if "rate limit" in error_message.lower():
+#                 error_message += " (Rate limit error, triggering cooldown.)"
+#                 self.status_tracker.time_of_last_rate_limit_error = time.time()
+#                 self.status_tracker.num_rate_limit_errors += 1
+#             if "context length" in error_message:
+#                 error_message += " (Context length exceeded, set retries to 0.)"
+#                 self.attempts_left = 0
+
+#         return APIResponse(
+#             id=self.task_id,
+#             status_code=status_code,
+#             is_error=is_error,
+#             error_message=error_message,
+#             prompt=self.prompt,
+#             completion=completion,
+#             model_internal=self.model_name,
+#             sampling_params=self.sampling_params,
+#             input_tokens=input_tokens,
+#             output_tokens=output_tokens,
+#         )
lm_deluge/api_requests/deprecated/mistral.py
@@ -0,0 +1,120 @@
+# import asyncio
+# from aiohttp import ClientResponse
+# import json
+# import os
+# import time
+# from tqdm import tqdm
+# from typing import Optional, Callable
+
+# from .base import APIRequestBase, APIResponse
+# from ..prompt import Prompt
+# from ..tracker import StatusTracker
+# from ..sampling_params import SamplingParams
+# from ..models import APIModel
+
+
+# class MistralRequest(APIRequestBase):
+#     def __init__(
+#         self,
+#         task_id: int,
+#         # should always be 'role', 'content' keys.
+#         # internal logic should handle translating to specific API format
+#         model_name: str,  # must correspond to registry
+#         prompt: Prompt,
+#         attempts_left: int,
+#         status_tracker: StatusTracker,
+#         retry_queue: asyncio.Queue,
+#         results_arr: list,
+#         request_timeout: int = 30,
+#         sampling_params: SamplingParams = SamplingParams(),
+#         pbar: Optional[tqdm] = None,
+#         callback: Optional[Callable] = None,
+#         debug: bool = False,
+#         all_model_names: list[str] = None,
+#         all_sampling_params: list[SamplingParams] = None,
+#     ):
+#         super().__init__(
+#             task_id=task_id,
+#             model_name=model_name,
+#             prompt=prompt,
+#             attempts_left=attempts_left,
+#             status_tracker=status_tracker,
+#             retry_queue=retry_queue,
+#             results_arr=results_arr,
+#             request_timeout=request_timeout,
+#             sampling_params=sampling_params,
+#             pbar=pbar,
+#             callback=callback,
+#             debug=debug,
+#             all_model_names=all_model_names,
+#             all_sampling_params=all_sampling_params,
+#         )
+#         self.model = APIModel.from_registry(model_name)
+#         self.url = f"{self.model.api_base}/chat/completions"
+#         self.request_header = {
+#             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
+#         }
+#         if prompt.image is not None:
+#             raise ValueError("Mistral does not support images.")
+
+#         self.request_json = {
+#             "model": self.model.name,
+#             "messages": prompt.to_openai(),
+#             "temperature": sampling_params.temperature,
+#             "top_p": sampling_params.top_p,
+#             "max_tokens": sampling_params.max_new_tokens,
+#         }
+#         if sampling_params.json_mode and self.model.supports_json:
+#             self.request_json["response_format"] = {"type": "json_object"}
+
+#     async def handle_response(self, response: ClientResponse) -> APIResponse:
+#         is_error = False
+#         error_message = None
+#         completion = None
+#         input_tokens = None
+#         output_tokens = None
+#         status_code = response.status
+#         mimetype = response.headers.get("Content-Type", None)
+#         if status_code >= 200 and status_code < 300:
+#             try:
+#                 data = await response.json()
+#                 completion = data["choices"][0]["message"]["content"]
+#                 input_tokens = data["usage"]["prompt_tokens"]
+#                 output_tokens = data["usage"]["completion_tokens"]
+
+#             except Exception:
+#                 is_error = True
+#                 error_message = (
+#                     f"Error calling .json() on response w/ status {status_code}"
+#                 )
+#         elif "json" in mimetype.lower():
+#             is_error = True  # expected status is 200, otherwise it's an error
+#             data = await response.json()
+#             error_message = json.dumps(data)
+#         else:
+#             is_error = True
+#             text = await response.text()
+#             error_message = text
+
+#         # handle special kinds of errors
+#         if is_error and error_message is not None:
+#             if "rate limit" in error_message.lower():
+#                 error_message += " (Rate limit error, triggering cooldown.)"
+#                 self.status_tracker.time_of_last_rate_limit_error = time.time()
+#                 self.status_tracker.num_rate_limit_errors += 1
+#             if "context length" in error_message:
+#                 error_message += " (Context length exceeded, set retries to 0.)"
+#                 self.attempts_left = 0
+
+#         return APIResponse(
+#             id=self.task_id,
+#             status_code=status_code,
+#             is_error=is_error,
+#             error_message=error_message,
+#             prompt=self.prompt,
+#             completion=completion,
+#             model_internal=self.model_name,
+#             sampling_params=self.sampling_params,
+#             input_tokens=input_tokens,
+#             output_tokens=output_tokens,
+#         )
lm_deluge/api_requests/google.py: file without changes.