lm-deluge 0.0.3__tar.gz → 0.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lm-deluge might be problematic.
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/PKG-INFO +2 -2
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/pyproject.toml +2 -2
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/anthropic.py +4 -10
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/base.py +23 -27
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/cohere.py +6 -12
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/deprecated/bedrock.py +4 -4
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/deprecated/deepseek.py +2 -2
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/deprecated/mistral.py +2 -2
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/openai.py +5 -7
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/vertex.py +9 -13
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/client.py +28 -43
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/embed.py +13 -28
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/llm_tools/extract.py +5 -5
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/models.py +4 -5
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/rerank.py +15 -29
- lm_deluge-0.0.5/src/lm_deluge/tracker.py +43 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/util/logprobs.py +2 -2
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge.egg-info/PKG-INFO +2 -2
- lm_deluge-0.0.3/src/lm_deluge/tracker.py +0 -12
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/README.md +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/setup.cfg +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/__init__.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/__init__.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/common.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/google.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/cache.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/errors.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/gemini_limits.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/image.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/llm_tools/__init__.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/llm_tools/score.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/llm_tools/translate.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/prompt.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/sampling_params.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/tool.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/util/json.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/util/pdf.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/util/validation.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/util/xml.py +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge.egg-info/SOURCES.txt +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge.egg-info/dependency_links.txt +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge.egg-info/requires.txt +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge.egg-info/top_level.txt +0 -0
- {lm_deluge-0.0.3 → lm_deluge-0.0.5}/tests/test_heal_json.py +0 -0
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/PKG-INFO

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: lm_deluge
-Version: 0.0.3
+Version: 0.0.5
 Summary: Python utility for using LLM API models.
 Author-email: Benjamin Anderson <ben@trytaylor.ai>
-Requires-Python: >=3.
+Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 Requires-Dist: python-dotenv
 Requires-Dist: json5
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/pyproject.toml

@@ -3,11 +3,11 @@ requires = ["setuptools", "wheel"]
 
 [project]
 name = "lm_deluge"
-version = "0.0.3"
+version = "0.0.5"
 authors = [{ name = "Benjamin Anderson", email = "ben@trytaylor.ai" }]
 description = "Python utility for using LLM API models."
 readme = "README.md"
-requires-python = ">=3.
+requires-python = ">=3.10"
 keywords = []
 license = { text = "" }
 classifiers = []
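The bump to requires-python >= 3.10 fits the annotation style used throughout the code hunks below, which write optionals as PEP 604 unions; that connection is an inference from the diff, not a stated changelog entry. A minimal sketch of the two spellings:

    from typing import Callable, Optional

    # pre-3.10 spelling (what the removed, truncated annotations most likely used)
    def old_style(callback: Optional[Callable] = None) -> None: ...

    # spelling used in 0.0.5; `Callable | None` in annotations needs Python 3.10+
    # unless `from __future__ import annotations` is in effect
    def new_style(callback: Callable | None = None) -> None: ...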
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/anthropic.py

@@ -3,9 +3,8 @@ from aiohttp import ClientResponse
 import json
 import os
 import warnings
-import time
 from tqdm import tqdm
-from typing import
+from typing import Callable
 
 from lm_deluge.prompt import Conversation
 from .base import APIRequestBase, APIResponse

@@ -29,8 +28,8 @@ class AnthropicRequest(APIRequestBase):
         results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        pbar:
-        callback:
+        pbar: tqdm | None = None,
+        callback: Callable | None = None,
         debug: bool = False,
         # for retries
         all_model_names: list[str] | None = None,

@@ -96,8 +95,6 @@ class AnthropicRequest(APIRequestBase):
         if self.system_message is not None:
             self.request_json["system"] = self.system_message
 
-        # print("request data:", self.request_json)
-
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
         error_message = None

@@ -122,9 +119,7 @@ class AnthropicRequest(APIRequestBase):
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
-                print("response data:", data)
                 content = data["content"]  # [0]["text"]
-                print("content is length", len(content))
                 for item in content:
                     if item["type"] == "text":
                         completion = item["text"]

@@ -156,8 +151,7 @@ class AnthropicRequest(APIRequestBase):
                 or "overloaded" in error_message.lower()
             ):
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.
-                self.status_tracker.num_rate_limit_errors += 1
+                self.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
                 self.attempts_left = 0
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/base.py

@@ -5,7 +5,7 @@ import random
 from tqdm import tqdm
 from dataclasses import dataclass
 from abc import ABC, abstractmethod
-from typing import
+from typing import Callable
 
 from lm_deluge.prompt import Conversation
 

@@ -26,25 +26,25 @@ class APIResponse:
 
     # http response information
     status_code: int | None
-    is_error:
-    error_message:
+    is_error: bool | None
+    error_message: str | None
 
     # completion information
-    completion:
-    input_tokens:
-    output_tokens:
+    completion: str | None
+    input_tokens: int | None
+    output_tokens: int | None
 
     # optional or calculated automatically
-    thinking:
-    model_external:
-    region:
-    logprobs:
-    finish_reason:
-    cost:
+    thinking: str | None = None  # if model shows thinking tokens
+    model_external: str | None = None  # the model tag used by the API
+    region: str | None = None
+    logprobs: list | None = None
+    finish_reason: str | None = None  # make required later
+    cost: float | None = None  # calculated automatically
     # set to true if is_error and should be retried with a different model
-    retry_with_different_model:
+    retry_with_different_model: bool | None = False
     # set to true if should NOT retry with the same model (unrecoverable error)
-    give_up_if_no_other_models:
+    give_up_if_no_other_models: bool | None = False
 
     def __post_init__(self):
         # calculate cost & get external model name

@@ -138,9 +138,9 @@ class APIRequestBase(ABC):
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
         logprobs: bool = False,
-        top_logprobs:
-        pbar:
-        callback:
+        top_logprobs: int | None = None,
+        pbar: tqdm | None = None,
+        callback: Callable | None = None,
         debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,

@@ -185,8 +185,7 @@ class APIRequestBase(ABC):
     def handle_success(self, data):
         self.call_callback()
         self.increment_pbar()
-        self.status_tracker.
-        self.status_tracker.num_tasks_succeeded += 1
+        self.status_tracker.task_succeeded(self.task_id)
 
     def handle_error(self, create_new_request=False, give_up_if_no_other_models=False):
         """

@@ -215,8 +214,7 @@ class APIRequestBase(ABC):
                 print(
                     f"No other models to try for task {self.task_id}. Giving up."
                 )
-                self.status_tracker.
-                self.status_tracker.num_tasks_failed += 1
+                self.status_tracker.task_failed(self.task_id)
             else:
                 print(
                     f"No other models to try for task {self.task_id}. Retrying with same model."

@@ -263,8 +261,7 @@ class APIRequestBase(ABC):
             self.results_arr.append(new_request)
         else:
             print(f"Task {self.task_id} out of tries.")
-            self.status_tracker.
-            self.status_tracker.num_tasks_failed += 1
+            self.status_tracker.task_failed(self.task_id)
 
     async def call_api(self):
         try:

@@ -308,7 +305,6 @@ class APIRequestBase(ABC):
 
         except Exception as e:
             raise_if_modal_exception(e)
-            # print(f"Unexpected error {type(e).__name__}: {str(e) or 'No message.'}")
             self.result.append(
                 APIResponse(
                     id=self.task_id,

@@ -342,9 +338,9 @@ def create_api_request(
     request_timeout: int = 30,
     sampling_params: SamplingParams = SamplingParams(),
     logprobs: bool = False,
-    top_logprobs:
-    pbar:
-    callback:
+    top_logprobs: int | None = None,
+    pbar: tqdm | None = None,
+    callback: Callable | None = None,
     all_model_names: list[str] | None = None,
     all_sampling_params: list[SamplingParams] | None = None,
 ) -> APIRequestBase:
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/cohere.py

@@ -4,9 +4,8 @@ import asyncio
 from aiohttp import ClientResponse
 import json
 import os
-import time
 from tqdm import tqdm
-from typing import
+from typing import Callable
 from lm_deluge.prompt import Conversation
 from .base import APIRequestBase, APIResponse
 

@@ -29,8 +28,8 @@ class CohereRequest(APIRequestBase):
         retry_queue: asyncio.Queue,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        pbar:
-        callback:
+        pbar: tqdm | None = None,
+        callback: Callable | None = None,
         debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,

@@ -56,7 +55,7 @@ class CohereRequest(APIRequestBase):
 
         self.model = APIModel.from_registry(model_name)
         self.url = f"{self.model.api_base}/chat"
-
+        messages = prompt.to_cohere()
 
         self.request_header = {
             "Authorization": f"bearer {os.getenv(self.model.api_key_env_var)}",

@@ -66,16 +65,12 @@ class CohereRequest(APIRequestBase):
 
         self.request_json = {
             "model": self.model.name,
-            "
-            "message": last_user_message,
+            "messages": messages,
             "temperature": sampling_params.temperature,
             "top_p": sampling_params.top_p,
             "max_tokens": sampling_params.max_new_tokens,
         }
 
-        if self.system_message:
-            self.request_json["preamble"] = self.system_message
-
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
         error_message = None

@@ -118,8 +113,7 @@ class CohereRequest(APIRequestBase):
                 or "overloaded" in error_message.lower()
             ):
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.
-                self.status_tracker.num_rate_limit_errors += 1
+                self.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
                 self.attempts_left = 0
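The request-body hunks above move the Cohere payload from a single "message" field (with the system prompt in "preamble") to a "messages" list built by prompt.to_cohere(). The exact entry structure returned by to_cohere() is not shown in this diff, so the role/content dicts below are an assumption; only the top-level keys come from the hunks:

    # 0.0.3-style body: single-turn "message" plus "preamble" (values are placeholders)
    old_body = {
        "model": "command-r",
        "message": "last user message",
        "preamble": "system prompt",
        "temperature": 0.7,
        "top_p": 1.0,
        "max_tokens": 512,
    }

    # 0.0.5-style body: the whole conversation goes into "messages"
    new_body = {
        "model": "command-r",
        "messages": [
            {"role": "system", "content": "system prompt"},
            {"role": "user", "content": "last user message"},
        ],
        "temperature": 0.7,
        "top_p": 1.0,
        "max_tokens": 512,
    }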
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/deprecated/bedrock.py

@@ -55,8 +55,8 @@
 #     retry_queue: asyncio.Queue,
 #     request_timeout: int = 30,
 #     sampling_params: SamplingParams = SamplingParams(),
-#     pbar:
-#     callback:
+#     pbar: tqdm | None = None,
+#     callback: Callable | None = None,
 #     debug: bool = False,
 #     all_model_names: list[str] | None = None,
 #     all_sampling_params: list[SamplingParams] | None = None,

@@ -175,8 +175,8 @@
 #     results_arr: list,
 #     request_timeout: int = 30,
 #     sampling_params: SamplingParams = SamplingParams(),
-#     pbar:
-#     callback:
+#     pbar: tqdm | None = None,
+#     callback: Callable | None = None,
 #     debug: bool = False,
 #     all_model_names: list[str] | None = None,
 #     all_sampling_params: list[SamplingParams] | None = None,

{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/deprecated/deepseek.py

@@ -25,8 +25,8 @@
 #     results_arr: list,
 #     request_timeout: int = 30,
 #     sampling_params: SamplingParams = SamplingParams(),
-#     pbar:
-#     callback:
+#     pbar: tqdm | None = None,
+#     callback: Callable | None = None,
 #     debug: bool = False,
 #     all_model_names: list[str] = None,
 #     all_sampling_params: list[SamplingParams] = None,

{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/deprecated/mistral.py

@@ -27,8 +27,8 @@
 #     results_arr: list,
 #     request_timeout: int = 30,
 #     sampling_params: SamplingParams = SamplingParams(),
-#     pbar:
-#     callback:
+#     pbar: tqdm | None = None,
+#     callback: Callable | None = None,
 #     debug: bool = False,
 #     all_model_names: list[str] = None,
 #     all_sampling_params: list[SamplingParams] = None,
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/openai.py

@@ -3,9 +3,8 @@ import warnings
 from aiohttp import ClientResponse
 import json
 import os
-import time
 from tqdm.auto import tqdm
-from typing import
+from typing import Callable
 
 from .base import APIRequestBase, APIResponse
 from ..prompt import Conversation

@@ -29,9 +28,9 @@ class OpenAIRequest(APIRequestBase):
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
         logprobs: bool = False,
-        top_logprobs:
-        pbar:
-        callback:
+        top_logprobs: int | None = None,
+        pbar: tqdm | None = None,
+        callback: Callable | None = None,
         debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,

@@ -124,8 +123,7 @@ class OpenAIRequest(APIRequestBase):
         if is_error and error_message is not None:
             if "rate limit" in error_message.lower() or status_code == 429:
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.
-                self.status_tracker.num_rate_limit_errors += 1
+                self.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
                 self.attempts_left = 0
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/api_requests/vertex.py

@@ -5,7 +5,7 @@ import json
 import os
 import time
 from tqdm import tqdm
-from typing import
+from typing import Callable
 
 from lm_deluge.prompt import Conversation
 from .base import APIRequestBase, APIResponse

@@ -57,8 +57,8 @@ class VertexAnthropicRequest(APIRequestBase):
         results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        pbar:
-        callback:
+        pbar: tqdm | None = None,
+        callback: Callable | None = None,
         debug: bool = False,
     ):
         super().__init__(

@@ -141,8 +141,7 @@ class VertexAnthropicRequest(APIRequestBase):
                 or status_code == 429
             ):
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.
-                self.status_tracker.num_rate_limit_errors += 1
+                self.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
                 self.attempts_left = 0

@@ -185,8 +184,8 @@ class GeminiRequest(APIRequestBase):
         results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
-        pbar:
-        callback:
+        pbar: tqdm | None = None,
+        callback: Callable | None = None,
         debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,

@@ -302,16 +301,14 @@ class GeminiRequest(APIRequestBase):
                     error_message = "Finish reason SAFETY."
                     retry_with_different_model = True
                 else:
-                    print("Actual structure of response:")
-                    print(data)
+                    print("Actual structure of response:", data)
                     is_error = True
                     error_message = "No content in response."
             except Exception as e:
                 is_error = True
                 error_message = f"Error calling .json() on response w/ status {status_code}: {e.__class__} {e}"
                 if isinstance(e, KeyError):
-                    print("Actual structure of response:")
-                    print(data)
+                    print("Actual structure of response:", data)
         elif "json" in (mimetype or "").lower():
             is_error = True
             data = await http_response.json()

@@ -332,8 +329,7 @@ class GeminiRequest(APIRequestBase):
                 status_code == 429
             ):
                 error_message += " (Rate limit error, triggering cooldown & retrying with different model.)"
-                self.status_tracker.
-                self.status_tracker.num_rate_limit_errors += 1
+                self.status_tracker.rate_limit_exceeded()
                 retry_with_different_model = (
                     True  # if possible, retry with a different model
                 )
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/client.py

@@ -5,7 +5,7 @@ import numpy as np
 import time
 import yaml
 from dataclasses import dataclass
-from typing import Sequence, overload, Literal,
+from typing import Sequence, overload, Literal, Any
 from tqdm.auto import tqdm
 
 from lm_deluge.prompt import Conversation

@@ -31,11 +31,11 @@ class ClientConfig:
     max_concurrent_requests: int
     max_attempts: int
     request_timeout: int
-    sampling_params:
-    model_weights:
+    sampling_params: SamplingParams | list[SamplingParams]
+    model_weights: list[float] | Literal["uniform", "rate_limit"]
     logprobs: bool = False
-    top_logprobs:
-    cache:
+    top_logprobs: int | None = None
+    cache: Any = None
 
     @classmethod
     def from_dict(cls, config_dict: dict):

@@ -82,23 +82,21 @@ class LLMClient:
     Handles models, sampling params for each model, model weights, rate limits, etc.
     """
 
-    pass
-
     def __init__(
         self,
         model_names: list[str],
         max_requests_per_minute: int,
         max_tokens_per_minute: int,
         max_concurrent_requests: int,
-        sampling_params:
-        model_weights:
+        sampling_params: SamplingParams | list[SamplingParams] = SamplingParams(),
+        model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
         max_attempts: int = 5,
         request_timeout: int = 30,
         logprobs: bool = False,
-        top_logprobs:
+        top_logprobs: int | None = None,
         use_qps: bool = False,
         debug: bool = False,
-        cache:
+        cache: Any = None,
     ):
         self.models = model_names
         if isinstance(sampling_params, SamplingParams):

@@ -154,7 +152,7 @@ class LLMClient:
         self.cache = cache
 
     @classmethod
-    def from_config(cls, config: ClientConfig, cache:
+    def from_config(cls, config: ClientConfig, cache: Any = None):
         return cls(
             model_names=config.model_names,
             max_requests_per_minute=config.max_requests_per_minute,

@@ -168,25 +166,25 @@ class LLMClient:
         )
 
     @classmethod
-    def from_yaml(cls, file_path: str, cache:
+    def from_yaml(cls, file_path: str, cache: Any = None):
         return cls.from_config(ClientConfig.from_yaml(file_path), cache=cache)
 
     @classmethod
     def basic(
         cls,
-        model:
+        model: str | list[str],
         max_requests_per_minute: int = 5_000,
         max_tokens_per_minute: int = 1_000_000,
         max_concurrent_requests: int = 1_000,
         temperature: float = 0.75,
         max_new_tokens: int = 1000,
         reasoning_effort: Literal[None, "low", "medium", "high"] = None,
-        model_weights:
+        model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
         logprobs: bool = False,
-        top_logprobs:
+        top_logprobs: int | None = None,
         max_attempts: int = 5,
         request_timeout: int = 30,
-        cache:
+        cache: Any = None,
     ):
         model_names = model if isinstance(model, list) else [model]
         return cls(

@@ -222,8 +220,6 @@ class LLMClient:
             top_logprobs=self.top_logprobs,
         )
 
-    from typing import Union, Literal
-
     @overload
     async def process_prompts_async(
         self,

@@ -485,9 +481,9 @@
 
 
 def api_prompts_dry_run(
-    ids:
+    ids: np.ndarray | list[int],
     prompts: list[Conversation],
-    models:
+    models: str | list[str],
     model_weights: list[float],
     sampling_params: list[SamplingParams],
     max_tokens_per_minute: int = 500_000,

@@ -543,19 +539,19 @@ def api_prompts_dry_run(
 
 
 async def process_api_prompts_async(
-    ids:
+    ids: np.ndarray | list[int],
     prompts: list[Conversation],
-    models:
+    models: str | list[str],
     model_weights: list[float],
     sampling_params: list[SamplingParams],
     logprobs: bool,
-    top_logprobs:
+    top_logprobs: int | None,
     max_attempts: int = 5,
     max_tokens_per_minute: int = 500_000,
     max_requests_per_minute: int = 1_000,
     max_concurrent_requests: int = 1_000,
     request_timeout: int = 30,
-    progress_bar:
+    progress_bar: tqdm | None = None,
     use_qps: bool = False,
     verbose: bool = False,
 ):

@@ -712,28 +708,17 @@ async def process_api_prompts_async(
             await asyncio.sleep(seconds_to_sleep_each_loop)
 
             # if a rate limit error was hit recently, pause to cool down
-
-
+            remaining_seconds_to_pause = max(
+                0,
+                seconds_to_pause_after_rate_limit_error
+                - status_tracker.time_since_rate_limit_error,
             )
-            if
-                remaining_seconds_to_pause = (
-                    seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
-                )
+            if remaining_seconds_to_pause > 0:
                 await asyncio.sleep(remaining_seconds_to_pause)
-
-                print(
-                    f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
-                )
+                print(f"Pausing {remaining_seconds_to_pause}s to cool down.")
 
     # after finishing, log final status
-
-    print(
-        f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed."
-    )
-    if status_tracker.num_rate_limit_errors > 0:
-        print(
-            f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
-        )
+    status_tracker.log_final_status()
     if verbose:
         print(
             f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
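The cooldown hunk above (and the matching hunks in embed.py and rerank.py below) collapses the old two-step pause computation into one max() expression built on the new StatusTracker.time_since_rate_limit_error property. A minimal sketch of that logic outside the processing loop; the 15-second constant is illustrative, not taken from the package:

    import time
    from lm_deluge.tracker import StatusTracker  # module added in 0.0.5

    seconds_to_pause_after_rate_limit_error = 15  # illustrative value

    tracker = StatusTracker()
    tracker.rate_limit_exceeded()  # stamps time_of_last_rate_limit_error
    tracker.time_of_last_rate_limit_error = time.time() - 5  # pretend the error was 5 s ago

    # clamp at zero so an old enough error yields no pause instead of a negative one
    remaining_seconds_to_pause = max(
        0,
        seconds_to_pause_after_rate_limit_error - tracker.time_since_rate_limit_error,
    )
    if remaining_seconds_to_pause > 0:
        print(f"Pausing {remaining_seconds_to_pause}s to cool down.")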
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/embed.py

@@ -5,7 +5,7 @@ import aiohttp
 from tqdm.auto import tqdm
 import asyncio
 import time
-from typing import Any
+from typing import Any
 from dataclasses import dataclass
 from .tracker import StatusTracker
 

@@ -58,7 +58,7 @@ class EmbeddingRequest:
         status_tracker: StatusTracker,
         retry_queue: asyncio.Queue,
         request_timeout: int,
-        pbar:
+        pbar: tqdm | None = None,
         **kwargs,  # openai or cohere specific params
     ):
         self.task_id = task_id

@@ -78,8 +78,7 @@ class EmbeddingRequest:
 
     def handle_success(self):
         self.increment_pbar()
-        self.status_tracker.
-        self.status_tracker.num_tasks_succeeded += 1
+        self.status_tracker.task_succeeded(self.task_id)
 
     def handle_error(self):
         last_result: EmbeddingResponse = self.result[-1]

@@ -94,8 +93,7 @@ class EmbeddingRequest:
             return
         else:
             print(f"Task {self.task_id} out of tries.")
-            self.status_tracker.
-            self.status_tracker.num_tasks_failed += 1
+            self.status_tracker.task_failed(self.task_id)
 
     async def handle_response(self, response: aiohttp.ClientResponse):
         try:

@@ -217,7 +215,7 @@ class EmbeddingResponse:
     id: int
     status_code: int | None
     is_error: bool
-    error_message:
+    error_message: str | None
     texts: list[str]
     embeddings: list[list[float]]
 

@@ -282,8 +280,7 @@ async def embed_parallel_async(
                         pbar=pbar,
                         **kwargs,
                     )
-                    status_tracker.
-                    status_tracker.num_tasks_in_progress += 1
+                    status_tracker.start_task(batch_id)
                     results.append(next_request)
 
                 except StopIteration:

@@ -333,29 +330,17 @@ async def embed_parallel_async(
             await asyncio.sleep(seconds_to_sleep_each_loop)
 
             # if a rate limit error was hit recently, pause to cool down
-
-
+            remaining_seconds_to_pause = max(
+                0,
+                seconds_to_pause_after_rate_limit_error
+                - status_tracker.time_since_rate_limit_error,
            )
-            if
-                remaining_seconds_to_pause = (
-                    seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
-                )
+            if remaining_seconds_to_pause > 0:
                 await asyncio.sleep(remaining_seconds_to_pause)
-
-                print(
-                    f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
-                )
+                print(f"Pausing {remaining_seconds_to_pause}s to cool down.")
 
     # after finishing, log final status
-
-    print(
-        f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed."
-    )
-    if status_tracker.num_rate_limit_errors > 0:
-        print(
-            f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
-        )
-
+    status_tracker.log_final_status()
     print(
         f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
     )
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/llm_tools/extract.py

@@ -3,7 +3,7 @@ import json
 from ..prompt import Conversation
 import asyncio
 from ..client import LLMClient
-from typing import
+from typing import Any
 from ..util.json import load_json
 
 try:

@@ -16,8 +16,8 @@ async def extract_async(
     inputs: list[str | Any],
     schema: Any,
     client: LLMClient,
-    document_name:
-    object_name:
+    document_name: str | None = None,
+    object_name: str | None = None,
     show_progress: bool = True,
     return_prompts: bool = False,
 ):

@@ -93,8 +93,8 @@ def extract(
     inputs: list[str | Any],
     schema: Any,
     client: LLMClient,
-    document_name:
-    object_name:
+    document_name: str | None = None,
+    object_name: str | None = None,
     show_progress: bool = True,
     return_prompts: bool = False,
 ):
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/models.py

@@ -1,6 +1,5 @@
 import random
 from dataclasses import dataclass, field
-from typing import Optional
 from .gemini_limits import gemini_1_5_pro_limits, gemini_flash_limits
 
 registry = {

@@ -928,15 +927,15 @@ class APIModel:
     api_base: str
     api_key_env_var: str
     api_spec: str
-    input_cost:
-    output_cost:
+    input_cost: float | None = 0  # $ per million input tokens
+    output_cost: float | None = 0  # $ per million output tokens
     supports_json: bool = False
     supports_logprobs: bool = False
     reasoning_model: bool = False
     regions: list[str] | dict[str, int] = field(default_factory=list)
     tokens_per_minute: int | None = None
     requests_per_minute: int | None = None
-    gpus:
+    gpus: list[str] | None = None
 
     @classmethod
     def from_registry(cls, name: str):

@@ -950,7 +949,7 @@ class APIModel:
             regions = self.regions
             weights = [1] * len(regions)
         elif isinstance(self.regions, dict):
-            regions = self.regions.keys()
+            regions = list(self.regions.keys())
             weights = self.regions.values()
         else:
             raise ValueError("no regions to sample")
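The last models.py hunk materializes dict_keys into a list before the regions are used alongside their weights. If those regions are later drawn with weighted random sampling (the call site is not shown in this diff, so random.choices below is an assumption, and the region/weight values are hypothetical), the raw view object would fail because it cannot be indexed:

    import random

    regions = {"us-east5": 2, "europe-west1": 1}  # hypothetical region -> weight mapping

    try:
        random.choices(regions.keys(), weights=regions.values(), k=1)
    except TypeError as err:
        print(err)  # dict_keys is a view, not a sequence, so it cannot be indexed

    # materializing the keys, as the 0.0.5 hunk does, makes the sampling work
    print(random.choices(list(regions.keys()), weights=regions.values(), k=1))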
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/rerank.py

@@ -4,7 +4,6 @@ import aiohttp
 from tqdm.auto import tqdm
 import asyncio
 import time
-from typing import Optional
 from dataclasses import dataclass
 from .tracker import StatusTracker
 

@@ -28,7 +27,7 @@ class RerankingRequest:
         status_tracker: StatusTracker,
         retry_queue: asyncio.Queue,
         request_timeout: int,
-        pbar:
+        pbar: tqdm | None = None,
     ):
         self.task_id = task_id
         self.model_name = model_name

@@ -48,8 +47,7 @@ class RerankingRequest:
 
     def handle_success(self):
         self.increment_pbar()
-        self.status_tracker.
-        self.status_tracker.num_tasks_succeeded += 1
+        self.status_tracker.task_succeeded(self.task_id)
 
     def handle_error(self):
         """

@@ -69,8 +67,7 @@ class RerankingRequest:
             return
         else:
             print(f"Task {self.task_id} out of tries.")
-            self.status_tracker.
-            self.status_tracker.num_tasks_failed += 1
+            self.status_tracker.task_failed(self.task_id)
 
     async def handle_response(self, response: aiohttp.ClientResponse):
         try:

@@ -127,8 +124,9 @@ class RerankingRequest:
         try:
             self.status_tracker.total_requests += 1
             async with aiohttp.ClientSession() as session:
+                timeout = aiohttp.ClientTimeout(total=self.request_timeout)
                 async with session.post(
-                    url, headers=headers, json=data, timeout=
+                    url, headers=headers, json=data, timeout=timeout
                 ) as response:
                     # print("got response!!")
                     response_obj: RerankingResponse = await self.handle_response(
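The hunk above wraps the request timeout in an aiohttp.ClientTimeout object instead of passing a bare value. A minimal standalone sketch of that pattern outside the class; the URL and payload are placeholders, not taken from the package:

    import asyncio
    import aiohttp

    async def post_with_timeout(url: str, payload: dict, request_timeout: int = 10) -> dict:
        # aiohttp models timeouts as a ClientTimeout object; `total` caps the whole request
        timeout = aiohttp.ClientTimeout(total=request_timeout)
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload, timeout=timeout) as response:
                return await response.json()

    # asyncio.run(post_with_timeout("https://example.com/rerank", {"query": "q"}))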
@@ -176,7 +174,7 @@ class RerankingResponse:
     id: int
     status_code: int | None
     is_error: bool
-    error_message:
+    error_message: str | None
     query: str
     documents: list[str]
     top_k_indices: list[int]

@@ -196,7 +194,7 @@ async def rerank_parallel_async(
     max_requests_per_minute: int = 4_000,
     max_concurrent_requests: int = 500,
     request_timeout: int = 10,
-    progress_bar:
+    progress_bar: tqdm | None = None,
 ):
     """Processes rerank requests in parallel, throttling to stay under rate limits."""
     ids = range(len(queries))

@@ -243,8 +241,7 @@ async def rerank_parallel_async(
                         request_timeout=request_timeout,
                         pbar=progress_bar,
                     )
-                    status_tracker.
-                    status_tracker.num_tasks_in_progress += 1
+                    status_tracker.start_task(req_id)
                     results.append(next_request)
 
                 except StopIteration:

@@ -294,28 +291,17 @@ async def rerank_parallel_async(
             await asyncio.sleep(seconds_to_sleep_each_loop)
 
             # if a rate limit error was hit recently, pause to cool down
-
-
+            remaining_seconds_to_pause = max(
+                0,
+                seconds_to_pause_after_rate_limit_error
+                - status_tracker.time_since_rate_limit_error,
             )
-            if
-                remaining_seconds_to_pause = (
-                    seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
-                )
+            if remaining_seconds_to_pause > 0:
                 await asyncio.sleep(remaining_seconds_to_pause)
-
-                print(
-                    f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
-                )
+                print(f"Pausing {remaining_seconds_to_pause}s to cool down.")
 
     # after finishing, log final status
-
-    print(
-        f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed."
-    )
-    if status_tracker.num_rate_limit_errors > 0:
-        print(
-            f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
-        )
+    status_tracker.log_final_status()
 
     print(
         f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
lm_deluge-0.0.5/src/lm_deluge/tracker.py

@@ -0,0 +1,43 @@
+import time
+from dataclasses import dataclass
+
+
+@dataclass
+class StatusTracker:
+    num_tasks_started: int = 0
+    num_tasks_in_progress: int = 0
+    num_tasks_succeeded: int = 0
+    num_tasks_failed: int = 0
+    num_rate_limit_errors: int = 0
+    time_of_last_rate_limit_error: int | float = 0
+    total_requests = 0
+
+    @property
+    def time_since_rate_limit_error(self):
+        return time.time() - self.time_of_last_rate_limit_error
+
+    def start_task(self, task_id):
+        self.num_tasks_started += 1
+        self.num_tasks_in_progress += 1
+
+    def rate_limit_exceeded(self):
+        self.time_of_last_rate_limit_error = time.time()
+        self.num_rate_limit_errors += 1
+
+    def task_succeeded(self, task_id):
+        self.num_tasks_in_progress -= 1
+        self.num_tasks_succeeded += 1
+
+    def task_failed(self, task_id):
+        self.num_tasks_in_progress -= 1
+        self.num_tasks_failed += 1
+
+    def log_final_status(self):
+        if self.num_tasks_failed > 0:
+            print(
+                f"{self.num_tasks_failed} / {self.num_tasks_started} requests failed."
+            )
+        if self.num_rate_limit_errors > 0:
+            print(
+                f"{self.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
+            )
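The new tracker.py centralizes the counter bookkeeping that the request classes previously did inline (status_tracker.num_tasks_succeeded += 1 and friends). A small walk-through of the lifecycle implied by the hunks above, using only the methods shown:

    from lm_deluge.tracker import StatusTracker

    tracker = StatusTracker()
    tracker.start_task(task_id=0)   # num_tasks_started and num_tasks_in_progress -> 1
    tracker.rate_limit_exceeded()   # records the error time used for the cooldown pause
    tracker.task_failed(task_id=0)  # moves the task out of "in progress"

    tracker.start_task(task_id=1)
    tracker.task_succeeded(task_id=1)

    tracker.log_final_status()
    # prints "1 / 2 requests failed." and the rate-limit warning, per log_final_status above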
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge/util/logprobs.py

@@ -1,6 +1,6 @@
 import re
 import numpy as np
-from typing import TypedDict,
+from typing import TypedDict, Callable
 
 
 class TopLogprob(TypedDict):

@@ -403,7 +403,7 @@ def extract_prob(
     normalize_top_logprobs: bool = True,  # if using top_logprobs, normalize by all the present tokens so they add up to 1
     use_complement: bool = False,  # if True, assume there's 2 choices, and return 1 - p if the top token doesn't match
     token_index: int = 0,  # get from the first token of the completion by default
-    token_match_fn:
+    token_match_fn: Callable[[str, str], bool] | None = is_match,
 ):
     """
     Extract the probability of the token from the logprobs object of a single
{lm_deluge-0.0.3 → lm_deluge-0.0.5}/src/lm_deluge.egg-info/PKG-INFO

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: lm_deluge
-Version: 0.0.3
+Version: 0.0.5
 Summary: Python utility for using LLM API models.
 Author-email: Benjamin Anderson <ben@trytaylor.ai>
-Requires-Python: >=3.
+Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 Requires-Dist: python-dotenv
 Requires-Dist: json5
@@ -1,12 +0,0 @@
|
|
|
1
|
-
from dataclasses import dataclass
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
@dataclass
|
|
5
|
-
class StatusTracker:
|
|
6
|
-
num_tasks_started: int = 0
|
|
7
|
-
num_tasks_in_progress: int = 0
|
|
8
|
-
num_tasks_succeeded: int = 0
|
|
9
|
-
num_tasks_failed: int = 0
|
|
10
|
-
num_rate_limit_errors: int = 0
|
|
11
|
-
time_of_last_rate_limit_error: int | float = 0
|
|
12
|
-
total_requests = 0
|
|