lm-deluge 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lm-deluge might be problematic. Click here for more details.
- lm_deluge/__init__.py +6 -0
- lm_deluge/api_requests/__init__.py +3 -0
- lm_deluge/api_requests/anthropic.py +177 -0
- lm_deluge/api_requests/base.py +375 -0
- lm_deluge/api_requests/cohere.py +138 -0
- lm_deluge/api_requests/common.py +18 -0
- lm_deluge/api_requests/deprecated/bedrock.py +288 -0
- lm_deluge/api_requests/deprecated/deepseek.py +118 -0
- lm_deluge/api_requests/deprecated/mistral.py +120 -0
- lm_deluge/api_requests/google.py +0 -0
- lm_deluge/api_requests/openai.py +145 -0
- lm_deluge/api_requests/vertex.py +365 -0
- lm_deluge/cache.py +144 -0
- lm_deluge/client.py +760 -0
- lm_deluge/embed.py +392 -0
- lm_deluge/errors.py +8 -0
- lm_deluge/gemini_limits.py +65 -0
- lm_deluge/image.py +200 -0
- lm_deluge/llm_tools/__init__.py +11 -0
- lm_deluge/llm_tools/extract.py +111 -0
- lm_deluge/llm_tools/score.py +71 -0
- lm_deluge/llm_tools/translate.py +44 -0
- lm_deluge/models.py +957 -0
- lm_deluge/prompt.py +355 -0
- lm_deluge/rerank.py +338 -0
- lm_deluge/sampling_params.py +25 -0
- lm_deluge/tool.py +106 -0
- lm_deluge/tracker.py +12 -0
- lm_deluge/util/json.py +167 -0
- lm_deluge/util/logprobs.py +446 -0
- lm_deluge/util/pdf.py +45 -0
- lm_deluge/util/validation.py +46 -0
- lm_deluge/util/xml.py +291 -0
- lm_deluge-0.0.3.dist-info/METADATA +127 -0
- lm_deluge-0.0.3.dist-info/RECORD +37 -0
- lm_deluge-0.0.3.dist-info/WHEEL +5 -0
- lm_deluge-0.0.3.dist-info/top_level.txt +1 -0
lm_deluge/api_requests/anthropic.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from aiohttp import ClientResponse
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import warnings
|
|
6
|
+
import time
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
from typing import Optional, Callable
|
|
9
|
+
|
|
10
|
+
from lm_deluge.prompt import Conversation
|
|
11
|
+
from .base import APIRequestBase, APIResponse
|
|
12
|
+
|
|
13
|
+
from ..tracker import StatusTracker
|
|
14
|
+
from ..sampling_params import SamplingParams
|
|
15
|
+
from ..models import APIModel
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AnthropicRequest(APIRequestBase):
    """Request against Anthropic's Messages API (``{api_base}/messages``).

    ``__init__`` sets the ``url``, ``request_header`` and ``request_json``
    attributes that ``APIRequestBase.call_api`` posts; ``handle_response``
    converts the HTTP response into an ``APIResponse``.
    """

    # Thinking-token budget used for each supported ``reasoning_effort`` level.
    _REASONING_BUDGET_TOKENS = {"low": 1024, "medium": 4096, "high": 16384}

    # Rate-limit headers read from every response (printed only in debug mode).
    _RATE_LIMIT_HEADERS = (
        "anthropic-ratelimit-requests-limit",
        "anthropic-ratelimit-requests-remaining",
        "anthropic-ratelimit-requests-reset",
        "anthropic-ratelimit-tokens-limit",
        "anthropic-ratelimit-tokens-remaining",
        "anthropic-ratelimit-tokens-reset",
    )

    # HTTP statuses Anthropic uses for rate limiting (429) and overload (529).
    _RATE_LIMIT_STATUS_CODES = (429, 529)

    def __init__(
        self,
        task_id: int,
        # should always be 'role', 'content' keys.
        # internal logic should handle translating to specific API format
        model_name: str,  # must correspond to registry
        prompt: Conversation,
        attempts_left: int,
        status_tracker: StatusTracker,
        retry_queue: asyncio.Queue,
        results_arr: list,
        request_timeout: int = 30,
        sampling_params: SamplingParams = SamplingParams(),
        pbar: Optional[tqdm] = None,
        callback: Optional[Callable] = None,
        debug: bool = False,
        # for retries
        all_model_names: list[str] | None = None,
        all_sampling_params: list[SamplingParams] | None = None,
    ):
        """Build the Messages API request for ``prompt``.

        Raises:
            ValueError: if ``sampling_params.reasoning_effort`` is set for a
                reasoning model but is not one of "low", "medium", "high".
        """
        super().__init__(
            task_id=task_id,
            model_name=model_name,
            prompt=prompt,
            attempts_left=attempts_left,
            status_tracker=status_tracker,
            retry_queue=retry_queue,
            results_arr=results_arr,
            request_timeout=request_timeout,
            sampling_params=sampling_params,
            pbar=pbar,
            callback=callback,
            debug=debug,
            all_model_names=all_model_names,
            all_sampling_params=all_sampling_params,
        )
        self.model = APIModel.from_registry(model_name)
        self.url = f"{self.model.api_base}/messages"

        # The system prompt is sent as a top-level field, not as a message.
        self.system_message, messages = prompt.to_anthropic()
        self.request_header = {
            "x-api-key": os.getenv(self.model.api_key_env_var),
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
        }

        self.request_json = {
            "model": self.model.name,
            "messages": messages,
            "temperature": self.sampling_params.temperature,
            "top_p": self.sampling_params.top_p,
            "max_tokens": self.sampling_params.max_new_tokens,
        }

        # handle thinking
        if self.model.reasoning_model:
            if sampling_params.reasoning_effort:
                # translate reasoning effort of low, medium, high to budget tokens
                budget = self._REASONING_BUDGET_TOKENS.get(
                    sampling_params.reasoning_effort
                )
                if budget is None:
                    raise ValueError(
                        f"Unsupported reasoning_effort "
                        f"{sampling_params.reasoning_effort!r}; expected one of "
                        f"{sorted(self._REASONING_BUDGET_TOKENS)}."
                    )
                self.request_json["thinking"] = {
                    "type": "enabled",
                    "budget_tokens": budget,
                }
                # Thinking requests are sent without top_p and with temperature 1.0.
                self.request_json.pop("top_p")
                self.request_json["temperature"] = 1.0
                # assume max tokens is max completion tokens, so add thinking room
                self.request_json["max_tokens"] += budget
            else:
                # no thinking
                self.request_json["thinking"] = {"type": "disabled"}
        elif sampling_params.reasoning_effort:
            warnings.warn(
                f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
            )

        if self.system_message is not None:
            self.request_json["system"] = self.system_message

    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
        """Convert an Anthropic HTTP response into an ``APIResponse``.

        On a 2xx status, every ``text`` content block is concatenated into
        ``completion`` and every ``thinking`` block into ``thinking``; other
        block types (e.g. ``tool_use``) are skipped. Any other status yields
        ``is_error=True`` with the response body as ``error_message``.

        Side effects on error: rate-limit/overload errors record a cooldown on
        ``self.status_tracker``; prompt-too-long errors set
        ``self.attempts_left`` to 0 so the request is not retried.
        """
        is_error = False
        error_message = None
        thinking = None
        completion = None
        input_tokens = None
        output_tokens = None
        status_code = http_response.status
        mimetype = http_response.headers.get("Content-Type", None)
        rate_limits = {
            header: http_response.headers.get(header, None)
            for header in self._RATE_LIMIT_HEADERS
        }
        if self.debug:
            print(f"Rate limits: {rate_limits}")

        if 200 <= status_code < 300:
            try:
                data = await http_response.json()
                if self.debug:
                    print("response data:", data)
                content = data["content"]
                if self.debug:
                    print("content is length", len(content))
                text_parts = []
                thinking_parts = []
                for item in content:
                    if item["type"] == "text":
                        text_parts.append(item["text"])
                    elif item["type"] == "thinking":
                        thinking_parts.append(item["thinking"])
                    # Other block types (e.g. "tool_use") are skipped.
                    # TODO: implement and report tool use
                # A response may hold several text blocks (e.g. with citations);
                # keeping only the last one would truncate the completion.
                if text_parts:
                    completion = "".join(text_parts)
                if thinking_parts:
                    thinking = "\n\n".join(thinking_parts)
                input_tokens = data["usage"]["input_tokens"]
                output_tokens = data["usage"]["output_tokens"]
            except Exception as e:
                is_error = True
                error_message = (
                    f"Error calling .json() on response w/ status {status_code}: {e}"
                )
        elif mimetype and "json" in mimetype.lower():
            is_error = True  # expected status is 2xx, otherwise it's an error
            try:
                data = await http_response.json()
                error_message = json.dumps(data)
            except ValueError:
                # Body was labelled JSON but did not parse; keep the raw text so
                # the error classification below still runs.
                error_message = await http_response.text()
        else:
            is_error = True
            error_message = await http_response.text()

        # handle special kinds of errors. TODO: make sure these are correct for anthropic
        if is_error and error_message is not None:
            lowered = error_message.lower()
            if (
                status_code in self._RATE_LIMIT_STATUS_CODES
                or "rate limit" in lowered
                or "overloaded" in lowered
            ):
                error_message += " (Rate limit error, triggering cooldown.)"
                self.status_tracker.time_of_last_rate_limit_error = time.time()
                self.status_tracker.num_rate_limit_errors += 1
            # "prompt is too long" appears to be Anthropic's wording for an
            # over-length prompt — TODO confirm against current API errors.
            if "context length" in lowered or "prompt is too long" in lowered:
                error_message += " (Context length exceeded, set retries to 0.)"
                self.attempts_left = 0

        return APIResponse(
            id=self.task_id,
            status_code=status_code,
            is_error=is_error,
            error_message=error_message,
            prompt=self.prompt,
            completion=completion,
            thinking=thinking,
            model_internal=self.model_name,
            sampling_params=self.sampling_params,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )
|
|
lm_deluge/api_requests/base.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
1
|
+
import aiohttp
|
|
2
|
+
import asyncio
|
|
3
|
+
import json
|
|
4
|
+
import random
|
|
5
|
+
from tqdm import tqdm
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from typing import Optional, Callable
|
|
9
|
+
|
|
10
|
+
from lm_deluge.prompt import Conversation
|
|
11
|
+
|
|
12
|
+
from ..tracker import StatusTracker
|
|
13
|
+
from ..sampling_params import SamplingParams
|
|
14
|
+
from ..models import APIModel
|
|
15
|
+
from ..errors import raise_if_modal_exception
|
|
16
|
+
from aiohttp import ClientResponse
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class APIResponse:
    """Outcome of one API call attempt, normalized across providers.

    ``__post_init__`` coerces ``id`` to int, fills ``model_external`` from the
    model registry, and recomputes ``cost`` (USD, from per-million-token
    prices) when both token counts and both registry prices are available;
    otherwise ``cost`` is None.
    """

    # request information
    id: int  # should be unique to the request within a given prompt-processing call
    model_internal: str  # our internal model tag
    prompt: Conversation
    sampling_params: SamplingParams

    # http response information
    status_code: int | None
    is_error: Optional[bool]
    error_message: Optional[str]

    # completion information
    completion: Optional[str]
    input_tokens: Optional[int]
    output_tokens: Optional[int]

    # optional or calculated automatically
    thinking: Optional[str] = None  # if model shows thinking tokens
    model_external: Optional[str] = None  # the model tag used by the API
    region: Optional[str] = None
    logprobs: Optional[list] = None
    finish_reason: Optional[str] = None  # make required later
    cost: Optional[float] = None  # calculated automatically
    # set to true if is_error and should be retried with a different model
    retry_with_different_model: Optional[bool] = False
    # set to true if should NOT retry with the same model (unrecoverable error)
    give_up_if_no_other_models: Optional[bool] = False

    def __post_init__(self):
        # calculate cost & get external model name
        self.id = int(self.id)
        api_model = APIModel.from_registry(self.model_internal)
        self.model_external = api_model.name
        self.cost = None
        if (
            self.input_tokens is not None
            and self.output_tokens is not None
            and api_model.input_cost is not None
            and api_model.output_cost is not None
        ):
            # Registry prices are per million tokens.
            self.cost = (
                self.input_tokens * api_model.input_cost / 1e6
                + self.output_tokens * api_model.output_cost / 1e6
            )
        elif self.completion is not None:
            print(
                f"Warning: Completion provided without token counts for model {self.model_internal}."
            )

    def to_dict(self):
        """Return a JSON-serializable dict of this response (inverse of ``from_dict``).

        The prompt is serialized with ``Conversation.to_log`` (the original
        comment notes this destroys images if present).
        """
        return {
            "id": self.id,
            "model_internal": self.model_internal,
            "model_external": self.model_external,
            "region": self.region,
            "prompt": self.prompt.to_log(),  # destroys image if present
            "sampling_params": self.sampling_params.__dict__,
            "status_code": self.status_code,
            "is_error": self.is_error,
            "error_message": self.error_message,
            "completion": self.completion,
            "thinking": self.thinking,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "finish_reason": self.finish_reason,
            "cost": self.cost,
        }

    @classmethod
    def from_dict(cls, data: dict):
        """Rebuild an ``APIResponse`` from a dict produced by ``to_dict``.

        Fields that have defaults (``model_external``, ``region``,
        ``thinking``, ``finish_reason``, ``cost``) are read with ``.get`` so
        records written before a field existed still load. A missing ``id``
        is replaced by a random integer. ``cost`` and ``model_external`` are
        recomputed by ``__post_init__`` regardless of the stored values.
        """
        return cls(
            id=data.get("id", random.randint(0, 1_000_000_000)),
            model_internal=data["model_internal"],
            model_external=data.get("model_external"),
            region=data.get("region"),
            prompt=Conversation.from_log(data["prompt"]),
            sampling_params=SamplingParams(**data["sampling_params"]),
            status_code=data["status_code"],
            is_error=data["is_error"],
            error_message=data["error_message"],
            input_tokens=data["input_tokens"],
            output_tokens=data["output_tokens"],
            completion=data["completion"],
            thinking=data.get("thinking"),
            finish_reason=data.get("finish_reason"),
            cost=data.get("cost"),
        )

    def write_to_file(self, filename):
        """
        Writes the APIResponse as a line to a file.
        If file exists, appends to it.
        """
        with open(filename, "a") as f:
            f.write(json.dumps(self.to_dict()) + "\n")
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class APIRequestBase(ABC):
    """
    Class for handling API requests. All model/endpoint-specific logic should be
    handled by overriding __init__ and implementing the handle_response method.
    For call_api to work, the __init__ must handle setting:
    - url
    - request_header
    - request_json
    """

    def __init__(
        self,
        task_id: int,
        # should always be 'role', 'content' keys.
        # internal logic should handle translating to specific API format
        model_name: str,  # must correspond to registry
        prompt: Conversation,
        attempts_left: int,
        status_tracker: StatusTracker,
        retry_queue: asyncio.Queue,
        # needed in order to retry with a different model and not throw the output away
        results_arr: list["APIRequestBase"],
        request_timeout: int = 30,
        sampling_params: SamplingParams = SamplingParams(),
        logprobs: bool = False,
        top_logprobs: Optional[int] = None,
        pbar: Optional[tqdm] = None,
        callback: Optional[Callable] = None,
        debug: bool = False,
        all_model_names: list[str] | None = None,
        all_sampling_params: list[SamplingParams] | None = None,
    ):
        """Store request state shared by all providers.

        Raises:
            ValueError: if ``all_model_names`` is None.
        """
        if all_model_names is None:
            raise ValueError("all_model_names must be provided.")
        self.task_id = task_id
        self.model_name = model_name
        self.system_prompt = None
        self.prompt = prompt
        self.attempts_left = attempts_left
        self.status_tracker = status_tracker
        self.retry_queue = retry_queue
        self.request_timeout = request_timeout
        self.sampling_params = sampling_params
        self.logprobs = logprobs  # len(completion) logprobs
        self.top_logprobs = top_logprobs
        self.pbar = pbar
        self.callback = callback
        self.num_tokens = prompt.count_tokens(sampling_params.max_new_tokens)
        self.results_arr = results_arr
        self.debug = debug
        self.all_model_names = all_model_names
        self.all_sampling_params = all_sampling_params
        self.result = []  # list of APIResponse objects from each attempt

        # these should be set in the __init__ of the subclass
        self.url = None
        self.request_header = None
        self.request_json = None
        self.region = None

    def increment_pbar(self):
        """Advance the progress bar by one, if one was supplied."""
        if self.pbar is not None:
            self.pbar.update(1)

    def call_callback(self):
        """Invoke the user callback with the latest APIResponse and the tracker."""
        if self.callback is not None:
            # the APIResponse in self.result includes all the information
            self.callback(self.result[-1], self.status_tracker)

    def handle_success(self, data):
        """Record a successful attempt: callback, progress bar, tracker counters.

        ``data`` is accepted for interface compatibility and is not used.
        """
        self.call_callback()
        self.increment_pbar()
        self.status_tracker.num_tasks_in_progress -= 1
        self.status_tracker.num_tasks_succeeded += 1

    def _pick_other_model_idx(self) -> Optional[int]:
        """Return a random index into ``all_model_names`` naming a different model.

        Returns None when fewer than two model names are configured, or when
        every configured name equals ``self.model_name`` (e.g. a list of
        duplicates), so callers never wait on a selection that cannot succeed.
        """
        names = self.all_model_names
        if names is None or len(names) < 2:
            return None
        candidates = [
            idx for idx, name in enumerate(names) if name != self.model_name
        ]
        if not candidates:
            return None
        # Uniform over the indices whose model differs from the current one.
        return random.choice(candidates)

    def handle_error(self, create_new_request=False, give_up_if_no_other_models=False):
        """
        If create_new_request is True, will create a new API request (so that it
        has a chance of being sent to a different model). If false, will retry
        the same request.

        If no different model is available, the request is retried with the
        same model, unless give_up_if_no_other_models is True, in which case
        the task is counted as failed. Once attempts_left reaches 0 the task
        is counted as failed.
        """
        last_result: APIResponse = self.result[-1]
        error_to_print = f"Error task {self.task_id}. "
        error_to_print += (
            f"Model: {last_result.model_internal} Code: {last_result.status_code}, "
        )
        if self.region is not None:
            error_to_print += f"Region: {self.region}, "
        error_to_print += f"Message: {last_result.error_message}."
        print(error_to_print)

        if self.attempts_left <= 0:
            print(f"Task {self.task_id} out of tries.")
            self.status_tracker.num_tasks_in_progress -= 1
            self.status_tracker.num_tasks_failed += 1
            return

        self.attempts_left -= 1
        if not create_new_request:
            self.retry_queue.put_nowait(self)
            return

        # make sure we have another model to send it to besides the current one
        new_model_idx = self._pick_other_model_idx()
        if new_model_idx is None:
            if give_up_if_no_other_models:
                print(
                    f"No other models to try for task {self.task_id}. Giving up."
                )
                self.status_tracker.num_tasks_in_progress -= 1
                self.status_tracker.num_tasks_failed += 1
            else:
                print(
                    f"No other models to try for task {self.task_id}. Retrying with same model."
                )
                self.retry_queue.put_nowait(self)
            return

        # two things to change: model_name and sampling_params
        new_model_name = self.all_model_names[new_model_idx]
        if isinstance(self.all_sampling_params, list):
            # assumes all_sampling_params is index-aligned with all_model_names
            new_sampling_params = self.all_sampling_params[new_model_idx]
        elif isinstance(self.all_sampling_params, SamplingParams):
            new_sampling_params = self.all_sampling_params
        else:
            new_sampling_params = self.sampling_params

        print("Creating new request with model", new_model_name)
        new_request = create_api_request(
            task_id=self.task_id,
            model_name=new_model_name,
            prompt=self.prompt,
            attempts_left=self.attempts_left,
            status_tracker=self.status_tracker,
            retry_queue=self.retry_queue,
            results_arr=self.results_arr,
            request_timeout=self.request_timeout,
            sampling_params=new_sampling_params,
            logprobs=self.logprobs,
            top_logprobs=self.top_logprobs,
            pbar=self.pbar,
            callback=self.callback,
            all_model_names=self.all_model_names,
            all_sampling_params=self.all_sampling_params,
        )
        self.retry_queue.put_nowait(new_request)
        # The new request must also be tracked in results_arr so its result is
        # not lost; entries must be deduplicated by task_id later.
        self.results_arr.append(new_request)

    async def call_api(self):
        """POST ``request_json`` to ``url`` and route the outcome.

        A timeout or any other exception is recorded as an error APIResponse
        and retried with the same model. NOTE(review): handle_success and
        handle_error run inside the ``try``, so an exception raised by the
        user callback is also caught here and treated as a failed attempt.
        """
        try:
            self.status_tracker.total_requests += 1
            timeout = aiohttp.ClientTimeout(total=self.request_timeout)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                assert self.url is not None, "URL is not set"
                async with session.post(
                    url=self.url,
                    headers=self.request_header,
                    json=self.request_json,
                ) as http_response:
                    response: APIResponse = await self.handle_response(http_response)

            self.result.append(response)
            if response.is_error:
                self.handle_error(
                    create_new_request=response.retry_with_different_model or False,
                    give_up_if_no_other_models=response.give_up_if_no_other_models
                    or False,
                )
            else:
                self.handle_success(response)

        except asyncio.TimeoutError:
            self.result.append(
                APIResponse(
                    id=self.task_id,
                    model_internal=self.model_name,
                    prompt=self.prompt,
                    sampling_params=self.sampling_params,
                    status_code=None,
                    is_error=True,
                    error_message="Request timed out (terminated by client).",
                    completion=None,
                    input_tokens=None,
                    output_tokens=None,
                )
            )
            self.handle_error(create_new_request=False)

        except Exception as e:
            raise_if_modal_exception(e)
            self.result.append(
                APIResponse(
                    id=self.task_id,
                    model_internal=self.model_name,
                    prompt=self.prompt,
                    sampling_params=self.sampling_params,
                    status_code=None,
                    is_error=True,
                    error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
                    completion=None,
                    input_tokens=None,
                    output_tokens=None,
                )
            )
            # maybe consider making True?
            self.handle_error(create_new_request=False)

    @abstractmethod
    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
        """Parse a provider HTTP response into an APIResponse (subclass hook)."""
        raise NotImplementedError
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def create_api_request(
    task_id: int,
    model_name: str,
    prompt: Conversation,
    attempts_left: int,
    status_tracker: StatusTracker,
    retry_queue: asyncio.Queue,
    results_arr: list["APIRequestBase"],
    request_timeout: int = 30,
    sampling_params: SamplingParams = SamplingParams(),
    logprobs: bool = False,
    top_logprobs: Optional[int] = None,
    pbar: Optional[tqdm] = None,
    callback: Optional[Callable] = None,
    all_model_names: list[str] | None = None,
    all_sampling_params: list[SamplingParams] | None = None,
) -> APIRequestBase:
    """Instantiate the request class registered for ``model_name``'s API spec.

    The class is looked up in ``CLASSES`` by the registry entry's
    ``api_spec``. ``logprobs`` / ``top_logprobs`` are forwarded only when
    ``logprobs`` is truthy.

    Raises:
        ValueError: if no request class is registered for the API spec.
    """
    # Lazy import: importing .common at module level would be circular.
    from .common import CLASSES

    spec = APIModel.from_registry(model_name).api_spec
    request_cls = CLASSES.get(spec)
    if request_cls is None:
        raise ValueError(f"Unsupported API spec: {spec}")

    # Not every request class accepts logprob keywords (e.g. AnthropicRequest's
    # __init__ has no such parameters), so only pass them when requested.
    logprob_kwargs = {}
    if logprobs:
        logprob_kwargs.update(logprobs=logprobs, top_logprobs=top_logprobs)

    return request_cls(
        task_id=task_id,
        model_name=model_name,
        prompt=prompt,
        attempts_left=attempts_left,
        status_tracker=status_tracker,
        retry_queue=retry_queue,
        results_arr=results_arr,
        request_timeout=request_timeout,
        sampling_params=sampling_params,
        pbar=pbar,
        callback=callback,
        all_model_names=all_model_names,
        all_sampling_params=all_sampling_params,
        **logprob_kwargs,
    )
|