lm-deluge 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lm-deluge might be problematic. Click here for more details.
- lm_deluge/__init__.py +6 -0
- lm_deluge/api_requests/__init__.py +3 -0
- lm_deluge/api_requests/anthropic.py +177 -0
- lm_deluge/api_requests/base.py +375 -0
- lm_deluge/api_requests/cohere.py +138 -0
- lm_deluge/api_requests/common.py +18 -0
- lm_deluge/api_requests/deprecated/bedrock.py +288 -0
- lm_deluge/api_requests/deprecated/deepseek.py +118 -0
- lm_deluge/api_requests/deprecated/mistral.py +120 -0
- lm_deluge/api_requests/google.py +0 -0
- lm_deluge/api_requests/openai.py +145 -0
- lm_deluge/api_requests/vertex.py +365 -0
- lm_deluge/cache.py +144 -0
- lm_deluge/client.py +760 -0
- lm_deluge/embed.py +392 -0
- lm_deluge/errors.py +8 -0
- lm_deluge/gemini_limits.py +65 -0
- lm_deluge/image.py +200 -0
- lm_deluge/llm_tools/__init__.py +11 -0
- lm_deluge/llm_tools/extract.py +111 -0
- lm_deluge/llm_tools/score.py +71 -0
- lm_deluge/llm_tools/translate.py +44 -0
- lm_deluge/models.py +957 -0
- lm_deluge/prompt.py +355 -0
- lm_deluge/rerank.py +338 -0
- lm_deluge/sampling_params.py +25 -0
- lm_deluge/tool.py +106 -0
- lm_deluge/tracker.py +12 -0
- lm_deluge/util/json.py +167 -0
- lm_deluge/util/logprobs.py +446 -0
- lm_deluge/util/pdf.py +45 -0
- lm_deluge/util/validation.py +46 -0
- lm_deluge/util/xml.py +291 -0
- lm_deluge-0.0.3.dist-info/METADATA +127 -0
- lm_deluge-0.0.3.dist-info/RECORD +37 -0
- lm_deluge-0.0.3.dist-info/WHEEL +5 -0
- lm_deluge-0.0.3.dist-info/top_level.txt +1 -0
lm_deluge/client.py
ADDED
|
@@ -0,0 +1,760 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import requests
|
|
3
|
+
import asyncio
|
|
4
|
+
import numpy as np
|
|
5
|
+
import time
|
|
6
|
+
import yaml
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Sequence, overload, Literal, Optional, Union, Any
|
|
9
|
+
from tqdm.auto import tqdm
|
|
10
|
+
|
|
11
|
+
from lm_deluge.prompt import Conversation
|
|
12
|
+
|
|
13
|
+
from .tracker import StatusTracker
|
|
14
|
+
from .sampling_params import SamplingParams
|
|
15
|
+
from .models import registry
|
|
16
|
+
from .api_requests.base import APIResponse, APIRequestBase
|
|
17
|
+
from .api_requests import create_api_request
|
|
18
|
+
# from .cache import LevelDBCache, SqliteCache
|
|
19
|
+
|
|
20
|
+
# TODO: get completions as they finish, not all at once at the end.
|
|
21
|
+
# relatedly, would be nice to cache them as they finish too.
|
|
22
|
+
|
|
23
|
+
# TODO: add optional max_input_tokens to client so we can reject long prompts to prevent abuse
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class ClientConfig:
    """Serializable settings for an LLMClient.

    Holds the LLMClient constructor arguments that can be written to / read
    from a plain dict or YAML file (see ``from_dict``, ``from_yaml``,
    ``to_dict``) and turned into a client with ``LLMClient.from_config``.
    """

    model_names: list[str]
    max_requests_per_minute: int
    max_tokens_per_minute: int
    max_concurrent_requests: int
    max_attempts: int
    request_timeout: int
    sampling_params: Union[SamplingParams, list[SamplingParams]]
    model_weights: Union[list[float], Literal["uniform", "rate_limit"]]
    logprobs: bool = False
    top_logprobs: Optional[int] = None
    # runtime-only object; intentionally not emitted by to_dict
    cache: Optional[Any] = None

    @classmethod
    def from_dict(cls, config_dict: dict):
        """Build a config from a plain dict (e.g. parsed YAML).

        ``sampling_params`` may be a dict of SamplingParams keyword arguments
        or a list of them; dicts are converted to SamplingParams, anything else
        (e.g. an already-built SamplingParams) is kept as-is. The input dict is
        not modified.
        """
        config_dict = dict(config_dict)  # shallow copy: don't mutate the caller's dict
        raw = config_dict["sampling_params"]
        if isinstance(raw, list):
            config_dict["sampling_params"] = [
                SamplingParams(**x) if isinstance(x, dict) else x for x in raw
            ]
        elif isinstance(raw, dict):
            # keyword-expand, same as the list branch; passing the dict
            # positionally would bind it to the first SamplingParams field
            config_dict["sampling_params"] = SamplingParams(**raw)

        return cls(**config_dict)

    @classmethod
    def from_yaml(cls, file_path: str):
        """Load a config from a YAML file (parsed with ``yaml.safe_load``)."""
        with open(file_path, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        return cls.from_dict(config_dict)

    def to_dict(self):
        """Return a plain-dict form of the config (``cache`` is omitted).

        Sampling params are copied from each object's ``__dict__`` so that
        editing the returned dict does not mutate this config.
        """
        if isinstance(self.sampling_params, list):
            sp = [dict(x.__dict__) for x in self.sampling_params]
        else:
            sp = dict(self.sampling_params.__dict__)

        return {
            "model_names": self.model_names,
            "max_requests_per_minute": self.max_requests_per_minute,
            "max_tokens_per_minute": self.max_tokens_per_minute,
            "max_concurrent_requests": self.max_concurrent_requests,
            "max_attempts": self.max_attempts,
            "request_timeout": self.request_timeout,
            "sampling_params": sp,
            "model_weights": self.model_weights,
            "logprobs": self.logprobs,
            "top_logprobs": self.top_logprobs,
        }
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class LLMClient:
    """
    LLMClient abstracts all the fixed arguments to process_prompts_async, so you can create it
    once and use it for more stuff without having to configure all the arguments.
    Handles models, sampling params for each model, model weights, rate limits, etc.
    """

    def __init__(
        self,
        model_names: list[str],
        max_requests_per_minute: int,
        max_tokens_per_minute: int,
        max_concurrent_requests: int,
        sampling_params: Union[SamplingParams, list[SamplingParams], None] = None,
        model_weights: Union[list[float], Literal["uniform", "rate_limit"]] = "uniform",
        max_attempts: int = 5,
        request_timeout: int = 30,
        logprobs: bool = False,
        top_logprobs: Optional[int] = None,
        use_qps: bool = False,
        debug: bool = False,
        cache: Optional[Any] = None,
    ):
        """Store and validate the client configuration.

        Args:
            model_names: registry names of the models to spread requests across.
            max_requests_per_minute: request budget used for throttling.
            max_tokens_per_minute: token budget used for throttling.
            max_concurrent_requests: maximum number of requests in flight.
            sampling_params: one SamplingParams applied to every model, or a
                list with exactly one entry per model. Defaults to a fresh
                ``SamplingParams()`` per client (not a shared instance).
            model_weights: "uniform", "rate_limit" (proportional to each
                model's registry ``requests_per_minute``), or explicit
                per-model weights. Weights are normalized to sum to 1.
            max_attempts: attempts allowed per prompt.
            request_timeout: passed to each API request.
            logprobs: request token log-probabilities; every model must have
                ``supports_logprobs`` set in the registry.
            top_logprobs: alternatives per token (0-20). Defaults to 0 when
                ``logprobs`` is on; forced to None when it is off.
            use_qps: throttle request bursts per second instead of per minute.
            debug: kept for backward compatibility; not read by this class.
            cache: optional object with ``get(prompt)`` and
                ``put(prompt, response)``.

        Raises:
            ValueError: on mismatched list lengths, an unknown weighting mode,
                weights that don't sum to a positive number, logprobs on a
                model that doesn't support them, or out-of-range top_logprobs.
        """
        self.models = model_names

        if sampling_params is None:
            sampling_params = SamplingParams()
        if isinstance(sampling_params, (list, tuple)):
            if len(sampling_params) != len(model_names):
                raise ValueError(
                    "If sampling_params is a list, it must have the same length as model_names."
                )
            self.sampling_params = list(sampling_params)
        else:
            self.sampling_params = [sampling_params for _ in model_names]

        # check isinstance(str) first so array-like weights never hit `== "uniform"`
        if isinstance(model_weights, str):
            if model_weights == "uniform":
                raw_weights = [1.0 for _ in model_names]
            elif model_weights == "rate_limit":
                raw_weights = [
                    registry[model]["requests_per_minute"] for model in model_names
                ]
            else:
                raise ValueError(
                    f"Unknown model_weights mode {model_weights!r}; "
                    "expected 'uniform', 'rate_limit', or a list of floats."
                )
        else:
            raw_weights = list(model_weights)
            if len(raw_weights) != len(model_names):
                raise ValueError(
                    "model_weights must have the same length as model_names."
                )
        total_weight = sum(raw_weights)
        if raw_weights and total_weight <= 0:
            raise ValueError("model_weights must sum to a positive number.")
        self.model_weights = [w / total_weight for w in raw_weights]

        self.logprobs = logprobs
        self.top_logprobs = top_logprobs

        # logprobs and top_logprobs are only supported for OpenAI models
        if self.logprobs:
            for model in self.models:
                if registry[model].get("supports_logprobs", False) is False:
                    raise ValueError(
                        "logprobs can only be enabled if all models support it."
                    )
            if self.top_logprobs is None:
                self.top_logprobs = 0  # will just return logprob of the chosen token
            elif self.top_logprobs > 20 or self.top_logprobs < 0:
                raise ValueError("top_logprobs must be between 0 and 20.")
            for sp in self.sampling_params:
                if sp.max_new_tokens > 10:
                    print(
                        "WARNING: using logprobs with large max_new_tokens can result in very large outputs. you may want to avoid saving these outputs to disk/db."
                    )
                    break
        else:
            self.top_logprobs = None

        self.max_requests_per_minute = max_requests_per_minute
        self.max_tokens_per_minute = max_tokens_per_minute
        self.max_concurrent_requests = max_concurrent_requests
        self.max_attempts = max_attempts
        self.request_timeout = request_timeout
        self.use_qps = use_qps
        self.debug = debug  # unused; kept so existing callers don't break
        self.cache = cache

    @classmethod
    def from_config(cls, config: ClientConfig, cache: Optional[Any] = None):
        """Build a client from a ClientConfig, including its logprobs settings."""
        return cls(
            model_names=config.model_names,
            max_requests_per_minute=config.max_requests_per_minute,
            max_tokens_per_minute=config.max_tokens_per_minute,
            max_concurrent_requests=config.max_concurrent_requests,
            sampling_params=config.sampling_params,
            model_weights=config.model_weights,
            max_attempts=config.max_attempts,
            request_timeout=config.request_timeout,
            logprobs=config.logprobs,
            top_logprobs=config.top_logprobs,
            cache=cache,
        )

    @classmethod
    def from_yaml(cls, file_path: str, cache: Optional[Any] = None):
        """Build a client from a YAML file parsed by ``ClientConfig.from_yaml``."""
        return cls.from_config(ClientConfig.from_yaml(file_path), cache=cache)

    @classmethod
    def basic(
        cls,
        model: Union[str, list[str]],
        max_requests_per_minute: int = 5_000,
        max_tokens_per_minute: int = 1_000_000,
        max_concurrent_requests: int = 1_000,
        temperature: float = 0.75,
        max_new_tokens: int = 1000,
        reasoning_effort: Literal[None, "low", "medium", "high"] = None,
        model_weights: Union[list[float], Literal["uniform", "rate_limit"]] = "uniform",
        logprobs: bool = False,
        top_logprobs: Optional[int] = None,
        max_attempts: int = 5,
        request_timeout: int = 30,
        cache: Optional[Any] = None,
    ):
        """Convenience constructor with generous default limits.

        Builds one SamplingParams from ``temperature``, ``max_new_tokens`` and
        ``reasoning_effort`` and applies it to every model.
        """
        model_names = model if isinstance(model, list) else [model]
        return cls(
            model_names=model_names,
            max_requests_per_minute=max_requests_per_minute,
            max_tokens_per_minute=max_tokens_per_minute,
            max_concurrent_requests=max_concurrent_requests,
            sampling_params=SamplingParams(
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                reasoning_effort=reasoning_effort,
            ),
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            model_weights=model_weights,
            max_attempts=max_attempts,
            request_timeout=request_timeout,
            cache=cache,
        )

    @property
    def config(self):
        """Current settings as a ClientConfig (the cache is not included)."""
        return ClientConfig(
            model_names=self.models,
            model_weights=self.model_weights,
            max_requests_per_minute=self.max_requests_per_minute,
            max_tokens_per_minute=self.max_tokens_per_minute,
            max_concurrent_requests=self.max_concurrent_requests,
            max_attempts=self.max_attempts,
            request_timeout=self.request_timeout,
            sampling_params=self.sampling_params,
            logprobs=self.logprobs,
            top_logprobs=self.top_logprobs,
        )

    @staticmethod
    def _normalize_prompts(prompts) -> list:
        """Wrap string prompts in single-user-turn Conversations; reject other non-Conversations."""
        conversations = []
        for p in prompts:
            if isinstance(p, Conversation):
                conversations.append(p)
            elif isinstance(p, str):
                conversations.append(Conversation.user(p))
            else:
                raise ValueError("All prompts must be valid.")
        return conversations

    @overload
    async def process_prompts_async(
        self,
        prompts: Sequence[str | list[dict] | Conversation],
        return_completions_only: bool = ...,
        show_progress: bool = ...,
        *,
        dry_run: Literal[True],
        verbose: bool = ...,
    ) -> dict[str, Any]: ...

    @overload
    async def process_prompts_async(
        self,
        prompts: Sequence[str | list[dict] | Conversation],
        return_completions_only: Literal[True],
        show_progress: bool = ...,
        dry_run: Literal[False] = ...,
        verbose: bool = ...,
    ) -> list[str | None]: ...

    @overload
    async def process_prompts_async(
        self,
        prompts: Sequence[str | list[dict] | Conversation],
        return_completions_only: Literal[False] = ...,
        show_progress: bool = ...,
        dry_run: Literal[False] = ...,
        verbose: bool = ...,
    ) -> list[APIResponse | None]: ...

    async def process_prompts_async(
        self,
        prompts: Sequence[str | list[dict] | Conversation],
        return_completions_only: bool = False,
        show_progress: bool = True,
        dry_run: bool = False,
        verbose: bool = False,
    ) -> list[APIResponse | None] | list[str | None] | dict[str, Any]:
        """Run every prompt through the configured models.

        Prompts must be strings or Conversations (anything else raises
        ValueError). When a cache is configured, cached responses are reused
        and only the remaining prompts are sent; new responses that have a
        completion are stored back in the cache.

        Returns:
            One APIResponse (or None) per prompt, in input order; the
            completions instead if ``return_completions_only``; or, with
            ``dry_run``, the token/cost estimate from ``api_prompts_dry_run``
            for the prompts that would actually be sent.
        """
        conversations = self._normalize_prompts(prompts)
        ids = list(range(len(conversations)))

        # look up every prompt in the cache first; hits are never re-sent
        cache_hits: dict[int, APIResponse] = {}
        if self.cache is not None:
            for idx, conversation in zip(ids, conversations):
                cached = self.cache.get(conversation)
                if cached is not None:
                    cache_hits[idx] = cached
        remaining_ids = [idx for idx in ids if idx not in cache_hits]
        remaining_prompts = [conversations[idx] for idx in remaining_ids]
        if self.cache is not None and verbose:
            print(f"{len(cache_hits)} cache hits from previous completions.")
            print(f"{len(remaining_ids)} prompts remaining after cache hits.")
            print(f"Processing {len(remaining_prompts)} prompts.")

        if dry_run:
            dry_run_results = api_prompts_dry_run(
                remaining_ids,
                remaining_prompts,
                self.models,
                self.model_weights,
                self.sampling_params,
                max_tokens_per_minute=self.max_tokens_per_minute,
                max_requests_per_minute=self.max_requests_per_minute,
            )
            print("Dry run results:")
            print(dry_run_results)
            return dry_run_results

        results: list[APIResponse | None] = [None for _ in conversations]
        for idx, cached in cache_hits.items():
            results[idx] = cached

        if remaining_prompts:
            pbar = tqdm(total=len(conversations), disable=(not show_progress))
            try:
                pbar.update(len(cache_hits))  # cache hits count as already done
                api_results: list[APIResponse] = await process_api_prompts_async(
                    remaining_ids,
                    remaining_prompts,
                    self.models,
                    self.model_weights,
                    self.sampling_params,
                    logprobs=self.logprobs,
                    top_logprobs=self.top_logprobs,
                    max_attempts=self.max_attempts,
                    max_tokens_per_minute=self.max_tokens_per_minute,
                    max_requests_per_minute=self.max_requests_per_minute,
                    max_concurrent_requests=self.max_concurrent_requests,
                    request_timeout=self.request_timeout,
                    progress_bar=pbar,
                    use_qps=self.use_qps,
                    verbose=verbose,
                )
            finally:
                pbar.close()
            for res in api_results:
                results[res.id] = res
                # only cache responses that actually produced a completion
                if self.cache is not None and res.completion:
                    self.cache.put(conversations[res.id], res)

        if return_completions_only:
            return [r.completion if r is not None else None for r in results]

        return results

    def process_prompts_sync(
        self,
        prompts: Sequence[str | list[dict] | Conversation],
        return_completions_only: bool = False,
        show_progress: bool = True,
        dry_run: bool = False,
        verbose: bool = False,
    ):
        """Blocking wrapper around ``process_prompts_async`` via ``asyncio.run``.

        Note: ``asyncio.run`` cannot be called while an event loop is already
        running in this thread; use the async method there instead.
        """
        return asyncio.run(
            self.process_prompts_async(
                prompts=prompts,
                return_completions_only=return_completions_only,
                show_progress=show_progress,
                dry_run=dry_run,
                verbose=verbose,
            )
        )

    def _submit_one_batch(self, batch_requests: list):
        """Upload one JSONL file of requests to OpenAI and start a batch job.

        Returns the batch id; raises ValueError if OPENAI_API_KEY is missing
        or if either the file upload or the batch creation is not HTTP 200.
        """
        import json
        import tempfile

        api_key = os.environ.get("OPENAI_API_KEY", None)
        if api_key is None:
            raise ValueError("OPENAI_API_KEY environment variable must be set.")
        headers = {
            "Authorization": f"Bearer {api_key}",
        }

        # write to a private temp dir (a fixed name in the CWD would be clobbered
        # by concurrent submissions); the file is deleted after upload and both
        # handles are closed
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "openai_requests_temp.jsonl")
            with open(path, "w", encoding="utf-8") as f:
                for request in batch_requests:
                    f.write(json.dumps(request) + "\n")
            with open(path, "rb") as f:
                response = requests.post(
                    "https://api.openai.com/v1/files",
                    files={"file": ("openai_requests_temp.jsonl", f)},
                    data={"purpose": "batch"},
                    headers=headers,
                )

        if response.status_code == 200:
            print("File uploaded successfully")
            file_id = response.json()["id"]
        else:
            print("File upload failed")
            raise ValueError(f"Error uploading file: {response.text}")

        response = requests.post(
            "https://api.openai.com/v1/batches",
            json={
                "input_file_id": file_id,
                "endpoint": "/v1/chat/completions",
                "completion_window": "24h",
            },
            headers=headers,
        )
        if response.status_code == 200:
            batch_id = response.json()["id"]
            print("Batch job started successfully: id = ", batch_id)
            return batch_id
        print("Batch job failed to start")
        raise ValueError(f"Error starting batch job: {response.text}")

    def submit_batch_job(self, prompts: Sequence[str | list[dict] | Conversation]):
        """Submit prompts to the OpenAI Batch API and return the batch job ids.

        Requires exactly one model whose registry ``api_spec`` is "openai".
        Each request's ``custom_id`` is the prompt's index; requests are split
        into jobs of at most 50,000 (the per-job limit).
        """
        # make sure 1) only 1 model is used and 2) it's an openai model
        if len(self.models) != 1:
            raise ValueError("Batch jobs can only be submitted with a single model.")
        model = self.models[0]
        if registry[model].get("api_spec", None) != "openai":
            raise ValueError("Batch jobs can only be submitted with OpenAI models.")

        conversations = self._normalize_prompts(prompts)
        sp = self.sampling_params[0]
        batch_requests = [
            {
                "custom_id": str(idx),
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": model,
                    "messages": prompt.to_openai(),
                    "max_tokens": sp.max_new_tokens,
                    "temperature": sp.temperature,
                    "top_p": sp.top_p,
                },
            }
            for idx, prompt in enumerate(conversations)
        ]

        # since the api only accepts up to 50,000 requests per batch job, we chunk into 50k chunks
        BATCH_SIZE = 50_000
        batches = [
            batch_requests[i : i + BATCH_SIZE]
            for i in range(0, len(batch_requests), BATCH_SIZE)
        ]
        batch_ids = [self._submit_one_batch(batch) for batch in tqdm(batches)]

        print(f"Submitted {len(batches)} batch jobs.")
        return batch_ids
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
def api_prompts_dry_run(
    ids: Union[np.ndarray, list[int]],
    prompts: list[Conversation],
    models: Union[str, list[str]],
    model_weights: list[float],
    sampling_params: list[SamplingParams],
    max_tokens_per_minute: int = 500_000,
    max_requests_per_minute: int = 1_000,
) -> dict[str, Any]:
    """
    Count tokens and estimate costs for a batch of prompts, without calling any API.

    Each prompt is assigned a model at random according to ``model_weights``
    (as a real run would do), and ``prompt.dry_run(model, max_new_tokens)``
    supplies ``(input_tokens, output_tokens, min_cost, max_cost)``.

    Returns:
        A dict with ``total_input_tokens``, ``total_output_tokens``,
        ``total_min_cost``, ``total_max_cost``; ``minimum_time`` and
        ``maximum_time`` in minutes, derived from the rate limits; and
        ``limiting_factor`` ("requests", "tokens" or "depends").
    """
    # accept a single model name like process_api_prompts_async does;
    # otherwise len(models) would count the characters of the string
    if isinstance(models, str):
        models = [models]
    # np.random.choice requires probabilities that sum to 1
    total_weight = sum(model_weights)
    probabilities = (
        [w / total_weight for w in model_weights] if total_weight else list(model_weights)
    )

    total_input_tokens = 0
    total_output_tokens = 0
    total_min_cost = 0
    total_max_cost = 0
    for _, prompt in zip(ids, prompts):
        # choose a model
        model_idx = np.random.choice(range(len(models)), p=probabilities)
        input_tokens, output_tokens, min_cost, max_cost = prompt.dry_run(
            models[model_idx], sampling_params[model_idx].max_new_tokens
        )
        total_input_tokens += input_tokens
        total_output_tokens += output_tokens
        total_min_cost += min_cost
        total_max_cost += max_cost

    combined_results: dict[str, Any] = {
        "total_input_tokens": total_input_tokens,
        "total_output_tokens": total_output_tokens,
        "total_min_cost": total_min_cost,
        "total_max_cost": total_max_cost,
    }
    # input tokens are always spent; output tokens may be up to the maximum
    minimum_time_tpm = total_input_tokens / max_tokens_per_minute
    maximum_time_tpm = (total_input_tokens + total_output_tokens) / max_tokens_per_minute
    minimum_time_rpm = len(prompts) / max_requests_per_minute

    combined_results["minimum_time"] = max(minimum_time_tpm, minimum_time_rpm)
    combined_results["maximum_time"] = max(maximum_time_tpm, minimum_time_rpm)
    if minimum_time_rpm > maximum_time_tpm:
        limiting_factor = "requests"
    elif minimum_time_rpm < minimum_time_tpm:
        limiting_factor = "tokens"
    else:
        # the request bound lies between the token bounds
        limiting_factor = "depends"
    combined_results["limiting_factor"] = limiting_factor

    return combined_results
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
async def process_api_prompts_async(
    ids: Union[np.ndarray, list[int]],
    prompts: list[Conversation],
    models: Union[str, list[str]],
    model_weights: Optional[list[float]],
    sampling_params: list[SamplingParams],
    logprobs: bool,
    top_logprobs: Optional[int],
    max_attempts: int = 5,
    max_tokens_per_minute: int = 500_000,
    max_requests_per_minute: int = 1_000,
    max_concurrent_requests: int = 1_000,
    request_timeout: int = 30,
    progress_bar: Optional[tqdm] = None,
    use_qps: bool = False,
    verbose: bool = False,
):
    """Processes API requests in parallel, throttling to stay under rate limits.

    Requests are created lazily, one per prompt, with the model drawn at
    random by ``model_weights``. Each is launched as an asyncio task whenever
    the request, token and concurrency budgets allow. The request and token
    budgets refill continuously at ``max_*_per_minute / 60`` per second.
    Requests waiting in the retry queue are launched before new prompts.

    Args:
        ids: task id for each prompt (becomes the request's ``task_id``).
        prompts: conversations to send, aligned with ``ids``.
        models: registry name(s) of the models to sample from.
        model_weights: per-model sampling weights (normalized here);
            ``None`` means uniform.
        sampling_params: one entry per model, indexed like ``models``.
        logprobs, top_logprobs: forwarded to every request.
        max_attempts: passed to each request as ``attempts_left``.
        max_tokens_per_minute, max_requests_per_minute, max_concurrent_requests:
            throttling limits.
        request_timeout: forwarded to every request.
        progress_bar: optional tqdm bar; forwarded to requests and given a
            status postfix about once per second.
        use_qps: cap request bursts at one second's worth of requests instead
            of one minute's worth.
        verbose: print extra progress information.

    Returns:
        One APIResponse per distinct task id (the last result of a request).
        When several requests share a task id, a response with a completion
        is preferred.

    Raises:
        ValueError: if ``models`` is not a str/list, a model is missing from
            the registry, or ``model_weights`` has the wrong length.
    """
    # change ids to integer list
    if isinstance(ids, np.ndarray):
        ids = ids.tolist()  # pyright: ignore

    # validate models
    if isinstance(models, str):
        models = [models]
    if not isinstance(models, list):
        raise ValueError("models must be a string or a list of model strings.")
    for model in models:
        if model not in registry:
            raise ValueError(f"Model {model} not found in registry.")

    # resolve weights (the None check must come before any arithmetic on them)
    if model_weights is None:
        # if not given, spread requests evenly across models
        model_weights = [1 / len(models) for _ in models]
    elif len(model_weights) != len(models):
        raise ValueError(
            "model_weights must be None or a list of the same length as models."
        )
    else:
        total_weight = sum(model_weights)
        model_weights = [w / total_weight for w in model_weights]

    # constants
    seconds_to_pause_after_rate_limit_error = 5
    # calculate dynamically so we don't throttle RPM
    seconds_to_sleep_each_loop = (60.0 * 0.9) / max_requests_per_minute

    # initialize trackers
    retry_queue = asyncio.Queue()
    status_tracker = StatusTracker()
    next_request = None  # variable to hold the next request to call
    # the event loop keeps only weak references to tasks, so hold strong ones
    # until each finishes or it may be garbage-collected mid-request
    in_flight_tasks: set = set()

    # initialize available capacity counts. With use_qps the request bucket
    # holds at most one second's worth of requests (but at least one, so a
    # budget under 60 RPM can still send), since some models limit RPS
    # rather than RPM.
    if use_qps:
        request_capacity_cap = max(max_requests_per_minute / 60.0, 1.0)
        available_request_capacity = max_requests_per_minute / 60.0
    else:
        request_capacity_cap = max_requests_per_minute
        available_request_capacity = max_requests_per_minute
    available_token_capacity = max_tokens_per_minute
    last_update_time = time.time()
    last_pbar_update_time = time.time()

    prompts_not_finished = True
    prompts_iter = iter(zip(ids, prompts))
    results: list[APIRequestBase] = []
    while True:
        # get next request (if one is not already waiting for capacity)
        if next_request is None:
            if not retry_queue.empty():
                next_request = retry_queue.get_nowait()
                print(f"Retrying request {next_request.task_id}.")
            elif prompts_not_finished:
                try:
                    task_id, prompt = next(prompts_iter)
                except StopIteration:
                    prompts_not_finished = False
                    if verbose:
                        print("API requests finished, only retries remain.")
                else:
                    # select model
                    model_idx = np.random.choice(range(len(models)), p=model_weights)
                    next_request = create_api_request(
                        task_id=task_id,
                        model_name=models[model_idx],
                        prompt=prompt,
                        request_timeout=request_timeout,
                        attempts_left=max_attempts,
                        status_tracker=status_tracker,
                        retry_queue=retry_queue,
                        results_arr=results,
                        sampling_params=sampling_params[model_idx],
                        logprobs=logprobs,
                        top_logprobs=top_logprobs,
                        pbar=progress_bar,
                        all_model_names=models,
                        all_sampling_params=sampling_params,
                    )
                    status_tracker.num_tasks_started += 1
                    status_tracker.num_tasks_in_progress += 1
                    results.append(next_request)

        # update available capacity
        current_time = time.time()
        seconds_since_update = current_time - last_update_time
        available_request_capacity = min(
            available_request_capacity
            + max_requests_per_minute * seconds_since_update / 60.0,
            request_capacity_cap,
        )
        available_token_capacity = min(
            available_token_capacity
            + max_tokens_per_minute * seconds_since_update / 60.0,
            max_tokens_per_minute,
        )
        last_update_time = current_time

        # if enough capacity available, call API
        limiting_factor = None
        if next_request:
            next_request_tokens = next_request.num_tokens
            request_available = available_request_capacity >= 1
            # a request bigger than the whole per-minute budget could never fit;
            # let it go once the bucket is full instead of stalling forever
            tokens_available = available_token_capacity >= min(
                next_request_tokens, max_tokens_per_minute
            )
            concurrent_request_available = (
                status_tracker.num_tasks_in_progress < max_concurrent_requests
            )
            if request_available and tokens_available and concurrent_request_available:
                # update counters
                available_request_capacity -= 1
                available_token_capacity -= next_request_tokens
                next_request.attempts_left -= 1

                # call API
                task = asyncio.create_task(next_request.call_api())
                in_flight_tasks.add(task)
                task.add_done_callback(in_flight_tasks.discard)
                next_request = None  # reset next_request to empty
            elif not request_available:
                limiting_factor = "Requests"
            elif not concurrent_request_available:
                limiting_factor = "Concurrent Requests"
            elif not tokens_available:
                limiting_factor = "Tokens"

        # update pbar status
        if progress_bar and (current_time - last_pbar_update_time > 1):
            last_pbar_update_time = current_time
            progress_bar.set_postfix(
                {
                    "Token Capacity": f"{available_token_capacity/1_000:.1f}k",
                    "Req. Capacity": f"{available_request_capacity:.1f}",
                    "Reqs. in Progress": status_tracker.num_tasks_in_progress,
                    "Limiting Factor": limiting_factor,
                }
            )

        # if all tasks are finished, break
        if status_tracker.num_tasks_in_progress == 0:
            break

        # main loop sleeps briefly so concurrent tasks can run
        await asyncio.sleep(seconds_to_sleep_each_loop)

        # if a rate limit error was hit recently, pause to cool down
        seconds_since_rate_limit_error = (
            time.time() - status_tracker.time_of_last_rate_limit_error
        )
        if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
            remaining_seconds_to_pause = (
                seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
            )
            # announce the pause before taking it
            print(
                f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
            )
            await asyncio.sleep(remaining_seconds_to_pause)

    # after finishing, log final status
    if status_tracker.num_tasks_failed > 0:
        print(
            f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed."
        )
    if status_tracker.num_rate_limit_errors > 0:
        print(
            f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
        )
    if verbose:
        print(
            f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
        )

    # deduplicate results by task id, keeping the first response unless a later
    # one has a completion and the kept one does not
    deduplicated: dict = {}
    for request in results:
        latest = request.result[-1]
        current_response = deduplicated.get(request.task_id)
        if current_response is None:
            deduplicated[request.task_id] = latest
        elif latest.completion is not None and current_response.completion is None:
            deduplicated[request.task_id] = latest

    output = list(deduplicated.values())
    if verbose:
        print(f"Returning {len(output)} unique results.")

    return output
|