lm-deluge 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lm_deluge/__init__.py +6 -0
- lm_deluge/api_requests/__init__.py +3 -0
- lm_deluge/api_requests/anthropic.py +177 -0
- lm_deluge/api_requests/base.py +375 -0
- lm_deluge/api_requests/cohere.py +138 -0
- lm_deluge/api_requests/common.py +18 -0
- lm_deluge/api_requests/deprecated/bedrock.py +288 -0
- lm_deluge/api_requests/deprecated/deepseek.py +118 -0
- lm_deluge/api_requests/deprecated/mistral.py +120 -0
- lm_deluge/api_requests/google.py +0 -0
- lm_deluge/api_requests/openai.py +145 -0
- lm_deluge/api_requests/vertex.py +365 -0
- lm_deluge/cache.py +144 -0
- lm_deluge/client.py +760 -0
- lm_deluge/embed.py +392 -0
- lm_deluge/errors.py +8 -0
- lm_deluge/gemini_limits.py +65 -0
- lm_deluge/image.py +200 -0
- lm_deluge/llm_tools/__init__.py +11 -0
- lm_deluge/llm_tools/extract.py +111 -0
- lm_deluge/llm_tools/score.py +71 -0
- lm_deluge/llm_tools/translate.py +44 -0
- lm_deluge/models.py +957 -0
- lm_deluge/prompt.py +355 -0
- lm_deluge/rerank.py +338 -0
- lm_deluge/sampling_params.py +25 -0
- lm_deluge/tool.py +106 -0
- lm_deluge/tracker.py +12 -0
- lm_deluge/util/json.py +167 -0
- lm_deluge/util/logprobs.py +446 -0
- lm_deluge/util/pdf.py +45 -0
- lm_deluge/util/validation.py +46 -0
- lm_deluge/util/xml.py +291 -0
- lm_deluge-0.0.3.dist-info/METADATA +127 -0
- lm_deluge-0.0.3.dist-info/RECORD +37 -0
- lm_deluge-0.0.3.dist-info/WHEEL +5 -0
- lm_deluge-0.0.3.dist-info/top_level.txt +1 -0
lm_deluge/api_requests/openai.py
ADDED
@@ -0,0 +1,145 @@
import asyncio
import warnings
from aiohttp import ClientResponse
import json
import os
import time
from tqdm.auto import tqdm
from typing import Optional, Callable

from .base import APIRequestBase, APIResponse
from ..prompt import Conversation
from ..tracker import StatusTracker
from ..sampling_params import SamplingParams
from ..models import APIModel


class OpenAIRequest(APIRequestBase):
    def __init__(
        self,
        task_id: int,
        # should always be 'role', 'content' keys.
        # internal logic should handle translating to specific API format
        model_name: str,  # must correspond to registry
        prompt: Conversation,
        attempts_left: int,
        status_tracker: StatusTracker,
        retry_queue: asyncio.Queue,
        results_arr: list,
        request_timeout: int = 30,
        sampling_params: SamplingParams = SamplingParams(),
        logprobs: bool = False,
        top_logprobs: Optional[int] = None,
        pbar: Optional[tqdm] = None,
        callback: Optional[Callable] = None,
        debug: bool = False,
        all_model_names: list[str] | None = None,
        all_sampling_params: list[SamplingParams] | None = None,
    ):
        super().__init__(
            task_id=task_id,
            model_name=model_name,
            prompt=prompt,
            attempts_left=attempts_left,
            status_tracker=status_tracker,
            retry_queue=retry_queue,
            results_arr=results_arr,
            request_timeout=request_timeout,
            sampling_params=sampling_params,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            pbar=pbar,
            callback=callback,
            debug=debug,
            all_model_names=all_model_names,
            all_sampling_params=all_sampling_params,
        )
        self.model = APIModel.from_registry(model_name)
        self.url = f"{self.model.api_base}/chat/completions"
        self.request_header = {
            "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
        }
        self.request_json = {
            "model": self.model.name,
            "messages": prompt.to_openai(),
            "temperature": sampling_params.temperature,
            "top_p": sampling_params.top_p,
            "max_completion_tokens": sampling_params.max_new_tokens,
        }
        if self.model.reasoning_model:
            self.request_json["temperature"] = 1.0
            self.request_json["top_p"] = 1.0
            self.request_json["reasoning_effort"] = sampling_params.reasoning_effort
        else:
            if sampling_params.reasoning_effort:
                warnings.warn(
                    f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
                )
        if logprobs:
            self.request_json["logprobs"] = True
            if top_logprobs is not None:
                self.request_json["top_logprobs"] = top_logprobs
        if sampling_params.json_mode and self.model.supports_json:
            self.request_json["response_format"] = {"type": "json_object"}

    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
        is_error = False
        error_message = None
        completion = None
        input_tokens = None
        output_tokens = None
        logprobs = None
        status_code = http_response.status
        mimetype = http_response.headers.get("Content-Type", None)
        data = None
        if status_code >= 200 and status_code < 300:
            try:
                data = await http_response.json()
            except Exception:
                is_error = True
                error_message = (
                    f"Error calling .json() on response w/ status {status_code}"
                )
            if not is_error:
                assert data is not None, "data is None"
                try:
                    completion = data["choices"][0]["message"]["content"]
                    input_tokens = data["usage"]["prompt_tokens"]
                    output_tokens = data["usage"]["completion_tokens"]
                    if self.logprobs and "logprobs" in data["choices"][0]:
                        logprobs = data["choices"][0]["logprobs"]["content"]
                except Exception:
                    is_error = True
                    error_message = f"Error getting 'choices' and 'usage' from {self.model.name} response."
        elif mimetype and "json" in mimetype.lower():
            is_error = True  # expected status is 200, otherwise it's an error
            data = await http_response.json()
            error_message = json.dumps(data)
        else:
            is_error = True
            text = await http_response.text()
            error_message = text

        # handle special kinds of errors
        if is_error and error_message is not None:
            if "rate limit" in error_message.lower() or status_code == 429:
                error_message += " (Rate limit error, triggering cooldown.)"
                self.status_tracker.time_of_last_rate_limit_error = time.time()
                self.status_tracker.num_rate_limit_errors += 1
            if "context length" in error_message:
                error_message += " (Context length exceeded, set retries to 0.)"
                self.attempts_left = 0

        return APIResponse(
            id=self.task_id,
            status_code=status_code,
            is_error=is_error,
            error_message=error_message,
            prompt=self.prompt,
            logprobs=logprobs,
            completion=completion,
            model_internal=self.model_name,
            sampling_params=self.sampling_params,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )
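For orientation, the extraction that handle_response performs on a successful Chat Completions reply can be reproduced in isolation. The sketch below is illustrative only: the payload is hand-written and parse_chat_completion is a hypothetical helper, not part of the package; it simply mirrors the field access shown above.

# Illustrative sketch (not in the package): mirrors the field access in
# OpenAIRequest.handle_response against a hand-written sample payload.
sample = {
    "choices": [
        {
            "message": {"content": "Hello!"},
            "logprobs": {"content": [{"token": "Hello", "logprob": -0.01}]},
        }
    ],
    "usage": {"prompt_tokens": 12, "completion_tokens": 3},
}

def parse_chat_completion(data: dict, want_logprobs: bool = False):
    completion = data["choices"][0]["message"]["content"]
    input_tokens = data["usage"]["prompt_tokens"]
    output_tokens = data["usage"]["completion_tokens"]
    logprobs = None
    if want_logprobs and "logprobs" in data["choices"][0]:
        logprobs = data["choices"][0]["logprobs"]["content"]
    return completion, input_tokens, output_tokens, logprobs

print(parse_chat_completion(sample, want_logprobs=True))
# -> ('Hello!', 12, 3, [{'token': 'Hello', 'logprob': -0.01}])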
lm_deluge/api_requests/vertex.py
ADDED
@@ -0,0 +1,365 @@
# consider: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-gemini-using-openai-library#call-chat-completions-api
import asyncio
from aiohttp import ClientResponse
import json
import os
import time
from tqdm import tqdm
from typing import Optional, Callable

from lm_deluge.prompt import Conversation
from .base import APIRequestBase, APIResponse
from ..tracker import StatusTracker
from ..sampling_params import SamplingParams
from ..models import APIModel

from google.oauth2 import service_account
from google.auth.transport.requests import Request


def get_access_token(service_account_file: str):
    """
    Get access token from environment variables if another process/coroutine
    has already got them, otherwise get from service account file.
    """
    LAST_REFRESHED = os.getenv("VERTEX_TOKEN_LAST_REFRESHED", None)
    LAST_REFRESHED = int(LAST_REFRESHED) if LAST_REFRESHED is not None else 0
    VERTEX_API_TOKEN = os.getenv("VERTEX_API_TOKEN", None)

    if VERTEX_API_TOKEN is not None and time.time() - LAST_REFRESHED < 60 * 50:
        return VERTEX_API_TOKEN
    else:
        credentials = service_account.Credentials.from_service_account_file(
            service_account_file,
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
        credentials.refresh(Request())
        token = credentials.token
        os.environ["VERTEX_API_TOKEN"] = token
        os.environ["VERTEX_TOKEN_LAST_REFRESHED"] = str(int(time.time()))

        return token


class VertexAnthropicRequest(APIRequestBase):
    """
    For Claude on Vertex, you'll also have to set the PROJECT_ID environment variable.
    """

    def __init__(
        self,
        task_id: int,
        model_name: str,  # must correspond to registry
        prompt: Conversation,
        attempts_left: int,
        status_tracker: StatusTracker,
        retry_queue: asyncio.Queue,
        results_arr: list,
        request_timeout: int = 30,
        sampling_params: SamplingParams = SamplingParams(),
        pbar: Optional[tqdm] = None,
        callback: Optional[Callable] = None,
        debug: bool = False,
    ):
        super().__init__(
            task_id=task_id,
            model_name=model_name,
            prompt=prompt,
            attempts_left=attempts_left,
            status_tracker=status_tracker,
            retry_queue=retry_queue,
            results_arr=results_arr,
            request_timeout=request_timeout,
            sampling_params=sampling_params,
            pbar=pbar,
            callback=callback,
            debug=debug,
        )
        creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if not creds:
            raise RuntimeError(
                "GOOGLE_APPLICATION_CREDENTIALS not provided in environment"
            )
        token = get_access_token(creds)

        self.model = APIModel.from_registry(model_name)
        project_id = os.getenv("PROJECT_ID")
        region = self.model.sample_region()

        endpoint = f"https://{region}-aiplatform.googleapis.com"
        self.url = f"{endpoint}/v1/projects/{project_id}/locations/{region}/publishers/anthropic/models/{self.model.name}:generateContent"
        self.request_header = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        self.system_message, messages = prompt.to_anthropic()

        self.request_json = {
            "anthropic_version": "vertex-2023-10-16",
            "messages": messages,
            "temperature": self.sampling_params.temperature,
            "top_p": self.sampling_params.top_p,
            "max_tokens": self.sampling_params.max_new_tokens,
        }
        if self.system_message is not None:
            self.request_json["system"] = self.system_message

    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
        is_error = False
        error_message = None
        completion = None
        input_tokens = None
        output_tokens = None
        status_code = http_response.status
        mimetype = http_response.headers.get("Content-Type", None)
        if status_code >= 200 and status_code < 300:
            try:
                data = await http_response.json()
                completion = data["content"][0]["text"]
                input_tokens = data["usage"]["input_tokens"]
                output_tokens = data["usage"]["output_tokens"]
            except Exception as e:
                is_error = True
                error_message = (
                    f"Error calling .json() on response w/ status {status_code}: {e}"
                )
        elif "json" in (mimetype or "").lower():
            is_error = True  # expected status is 200, otherwise it's an error
            data = await http_response.json()
            error_message = json.dumps(data)
        else:
            is_error = True
            text = await http_response.text()
            error_message = text

        # handle special kinds of errors. TODO: make sure these are correct for anthropic
        if is_error and error_message is not None:
            if (
                "rate limit" in error_message.lower()
                or "overloaded" in error_message.lower()
                or status_code == 429
            ):
                error_message += " (Rate limit error, triggering cooldown.)"
                self.status_tracker.time_of_last_rate_limit_error = time.time()
                self.status_tracker.num_rate_limit_errors += 1
            if "context length" in error_message:
                error_message += " (Context length exceeded, set retries to 0.)"
                self.attempts_left = 0

        return APIResponse(
            id=self.task_id,
            status_code=status_code,
            is_error=is_error,
            error_message=error_message,
            prompt=self.prompt,
            completion=completion,
            model_internal=self.model_name,
            sampling_params=self.sampling_params,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )


SAFETY_SETTING_CATEGORIES = [
    "HARM_CATEGORY_DANGEROUS_CONTENT",
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
]


class GeminiRequest(APIRequestBase):
    """
    For Gemini, you'll also have to set the PROJECT_ID environment variable.
    """

    def __init__(
        self,
        task_id: int,
        model_name: str,  # must correspond to registry
        prompt: Conversation,
        attempts_left: int,
        status_tracker: StatusTracker,
        retry_queue: asyncio.Queue,
        results_arr: list,
        request_timeout: int = 30,
        sampling_params: SamplingParams = SamplingParams(),
        pbar: Optional[tqdm] = None,
        callback: Optional[Callable] = None,
        debug: bool = False,
        all_model_names: list[str] | None = None,
        all_sampling_params: list[SamplingParams] | None = None,
    ):
        super().__init__(
            task_id=task_id,
            model_name=model_name,
            prompt=prompt,
            attempts_left=attempts_left,
            status_tracker=status_tracker,
            retry_queue=retry_queue,
            results_arr=results_arr,
            request_timeout=request_timeout,
            sampling_params=sampling_params,
            pbar=pbar,
            callback=callback,
            debug=debug,
            all_model_names=all_model_names,
            all_sampling_params=all_sampling_params,
        )
        self.model = APIModel.from_registry(model_name)
        credentials_file = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if not credentials_file:
            raise RuntimeError(
                "no credentials file found. ensure you provide a google credentials file and point to it with GOOGLE_APPLICATION_CREDENTIALS environment variable."
            )
        token = get_access_token(credentials_file)
        self.project_id = os.getenv("PROJECT_ID")
        # sample weighted by region counts
        self.region = self.model.sample_region()
        assert self.region is not None, "unable to sample region"
        self.url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.region}/publishers/google/models/{self.model.name}:generateContent"

        self.request_header = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        self.system_message, contents = prompt.to_gemini()
        self.request_json = {
            "contents": contents,
            "generationConfig": {
                "stopSequences": [],
                "temperature": sampling_params.temperature,
                "maxOutputTokens": sampling_params.max_new_tokens,
                "topP": sampling_params.top_p,
                "topK": None,
            },
            "safetySettings": [
                {"category": category, "threshold": "BLOCK_NONE"}
                for category in SAFETY_SETTING_CATEGORIES
            ],
        }
        if sampling_params.json_mode and self.model.supports_json:
            self.request_json["generationConfig"]["responseMimeType"] = (
                "application/json"
            )

        if self.system_message is not None:
            self.request_json["systemInstruction"] = {
                "role": "SYSTEM",
                "parts": [{"text": self.system_message}],
            }

    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
        is_error = False
        error_message = None
        completion = None
        input_tokens = None
        output_tokens = None
        finish_reason = None
        data = None
        retry_with_different_model = False
        give_up_if_no_other_models = False
        status_code = http_response.status
        mimetype = http_response.headers.get("Content-Type", None)
        if status_code >= 200 and status_code < 300:
            try:
                data = await http_response.json()
                if "candidates" not in data:
                    is_error = True
                    if "promptFeedback" in data:
                        error_message = "Prompt rejected. Feedback: " + str(
                            data["promptFeedback"]
                        )
                    else:
                        error_message = "No candidates in response."
                        retry_with_different_model = True
                        give_up_if_no_other_models = True
                else:
                    candidate = data["candidates"][0]
                    finish_reason = candidate["finishReason"]
                    if "content" in candidate:
                        parts = candidate["content"]["parts"]
                        completion = " ".join([part["text"] for part in parts])
                        usage = data["usageMetadata"]
                        input_tokens = usage["promptTokenCount"]
                        output_tokens = usage["candidatesTokenCount"]
                    elif finish_reason == "RECITATION":
                        is_error = True
                        citations = candidate.get("citationMetadata", {}).get(
                            "citations", []
                        )
                        urls = ",".join(
                            [citation.get("uri", "") for citation in citations]
                        )
                        error_message = "Finish reason RECITATION. URLS: " + urls
                        retry_with_different_model = True
                    elif finish_reason == "OTHER":
                        is_error = True
                        error_message = "Finish reason OTHER."
                        retry_with_different_model = True
                    elif finish_reason == "SAFETY":
                        is_error = True
                        error_message = "Finish reason SAFETY."
                        retry_with_different_model = True
                    else:
                        print("Actual structure of response:")
                        print(data)
                        is_error = True
                        error_message = "No content in response."
            except Exception as e:
                is_error = True
                error_message = f"Error calling .json() on response w/ status {status_code}: {e.__class__} {e}"
                if isinstance(e, KeyError):
                    print("Actual structure of response:")
                    print(data)
        elif "json" in (mimetype or "").lower():
            is_error = True
            data = await http_response.json()
            error_message = json.dumps(data)
        else:
            is_error = True
            text = await http_response.text()
            error_message = text

        old_region = self.region
        if is_error and error_message is not None:
            if (
                "rate limit" in error_message.lower()
                or "temporarily out of capacity" in error_message.lower()
                or "exceeded" in error_message.lower()
                or status_code == 429  # 429 code
            ):
                error_message += " (Rate limit error, triggering cooldown & retrying with different model.)"
                self.status_tracker.time_of_last_rate_limit_error = time.time()
                self.status_tracker.num_rate_limit_errors += 1
                retry_with_different_model = True  # if possible, retry with a different model
        if is_error:
            # change the region in case error is due to region unavailability
            self.region = self.model.sample_region()
            assert self.region is not None, "Unable to sample region"
            self.url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.region}/publishers/google/models/{self.model.name}:generateContent"

        return APIResponse(
            id=self.task_id,
            status_code=status_code,
            is_error=is_error,
            error_message=error_message,
            prompt=self.prompt,
            completion=completion,
            model_internal=self.model_name,
            sampling_params=self.sampling_params,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            region=old_region,
            finish_reason=finish_reason,
            retry_with_different_model=retry_with_different_model,
            give_up_if_no_other_models=give_up_if_no_other_models,
        )


# class LlamaEndpointRequest(APIRequestBase):
#     raise NotImplementedError("Llama endpoints are not implemented and never will be because Vertex AI sucks ass.")
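As its docstring notes, get_access_token above stashes the refreshed credential in process environment variables so that other coroutines reuse it for roughly 50 minutes (60 * 50 seconds) before refreshing again. A minimal sketch of that reuse pattern, assuming google-auth is installed and GOOGLE_APPLICATION_CREDENTIALS points at a valid service-account key file:

# Sketch only: exercises the env-var token cache in get_access_token.
import os, time
from lm_deluge.api_requests.vertex import get_access_token

creds_file = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]  # must be set by the caller

token_1 = get_access_token(creds_file)  # refreshes via google-auth, writes VERTEX_API_TOKEN
token_2 = get_access_token(creds_file)  # reuses the cached token from the environment
assert token_1 == token_2

age = time.time() - int(os.environ["VERTEX_TOKEN_LAST_REFRESHED"])
print(f"token age: {age:.0f}s; a fresh token is fetched once this exceeds {60 * 50}s")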
lm_deluge/cache.py
ADDED
@@ -0,0 +1,144 @@
import tempfile
import json
import sqlite3
from typing import Any
from .prompt import Conversation
from .api_requests.base import APIResponse

try:
    import plyvel  # type: ignore
except ImportError:
    plyvel = None
    print("Warning: plyvel not installed, cannot use LevelDB.")


def encode_api_response(response: APIResponse) -> bytes:
    """
    Encode an API response as a string.
    """
    return json.dumps(response.to_dict()).encode()


def decode_api_response(data: bytes) -> APIResponse:
    """
    Decode an API response from a string.
    """
    return APIResponse.from_dict(json.loads(data.decode()))


class DistributedDictCache:
    """
    Use distributed dictionary (e.g. Modal Dict) as a cache.
    Pass in the dictionary object to use. Cache must implement
    'get' and 'put' methods.
    """

    def __init__(self, cache: Any, cache_key: str = "default"):
        self.cache = cache
        self.cache_key = cache_key  # for namespacing

    def get(self, prompt: Conversation) -> APIResponse | None:
        """
        Get an API response from the cache.
        """
        data = self.cache.get(f"{self.cache_key}:{prompt.fingerprint}")
        if data is not None:
            return decode_api_response(data)
        return None

    def put(self, prompt: Conversation, response: APIResponse) -> None:
        """
        Put an API response into the cache.
        """
        key = f"{self.cache_key}:{prompt.fingerprint}"
        self.cache.put(key, encode_api_response(response))


class LevelDBCache:
    """
    Store API responses based on their input messages.
    """

    def __init__(self, path: str | None = None, cache_key: str = "default"):
        if path is None:
            self.temp_file = tempfile.TemporaryFile(suffix=".db")
            path = self.temp_file.name
            print(f"Using temporary cache at {path}")
        else:
            self.temp_file = None
        self.path = path
        if plyvel is not None:
            self.db = plyvel.DB(path, create_if_missing=True)
        else:
            raise ImportError("plyvel not installed, cannot use LevelDBCache.")
        self.cache_key = cache_key  # for namespacing

    def get(self, prompt: Conversation) -> APIResponse | None:
        """
        Get an API response from the cache.
        """
        key = f"{self.cache_key}:{prompt.fingerprint}"
        data = self.db.get(key.encode())
        if data is not None:
            return decode_api_response(data)
        return None

    def put(self, prompt: Conversation, response: APIResponse):
        """
        Put an API response into the cache.
        """
        key = f"{self.cache_key}:{prompt.fingerprint}"
        self.db.put(key.encode(), encode_api_response(response))

    def close(self):
        """
        Close the cache.
        """
        self.db.close()
        if self.temp_file is not None:
            self.temp_file.close()


class SqliteCache:
    """
    Same interface as LevelDBCache, but uses SQLite as KV store instead.
    Good to use on systems where LevelDB installation is problematic.
    """

    def __init__(self, path: str, cache_key: str = "default"):
        self.path = path
        self.cache_key = cache_key  # for namespacing
        self.conn = sqlite3.connect(path)
        self.cursor = self.conn.cursor()
        self.cursor.execute(
            "CREATE TABLE IF NOT EXISTS cache (key TEXT PRIMARY KEY, value BLOB)"
        )
        self.conn.commit()

    def get(self, prompt: Conversation) -> APIResponse | None:
        """
        Get an API response from the cache.
        """
        key = f"{self.cache_key}:{prompt.fingerprint}"
        self.cursor.execute("SELECT value FROM cache WHERE key=?", (key,))
        data = self.cursor.fetchone()
        if data is not None and len(data) > 0:
            return decode_api_response(data[0])
        return None

    def put(self, prompt: Conversation, response: APIResponse):
        """
        Put an API response into the cache.
        """
        key = f"{self.cache_key}:{prompt.fingerprint}"
        self.cursor.execute(
            "INSERT OR REPLACE INTO cache (key, value) VALUES (?, ?)",
            (key, encode_api_response(response)),
        )
        self.conn.commit()

    def close(self):
        """
        Close the cache.
        """
        self.conn.close()
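Since DistributedDictCache only requires its backing store to expose get and put (the docstring points at Modal's Dict as the intended distributed store), a plain in-memory mapping is enough for local testing. A minimal sketch, with the InMemoryDict wrapper invented here for illustration:

# Illustrative only: a tiny get/put backend for DistributedDictCache.
from lm_deluge.cache import DistributedDictCache

class InMemoryDict:
    def __init__(self):
        self._data = {}

    def get(self, key):
        return self._data.get(key)

    def put(self, key, value):
        self._data[key] = value

cache = DistributedDictCache(InMemoryDict(), cache_key="test-run")
# cache.put(prompt, response) stores the encoded APIResponse under
# "test-run:<prompt.fingerprint>"; cache.get(prompt) decodes it back,
# where prompt is a Conversation and response is an APIResponse.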