lm-deluge 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



lm_deluge/client.py ADDED
@@ -0,0 +1,760 @@
+import os
+import requests
+import asyncio
+import numpy as np
+import time
+import yaml
+from dataclasses import dataclass
+from typing import Sequence, overload, Literal, Optional, Union, Any
+from tqdm.auto import tqdm
+
+from lm_deluge.prompt import Conversation
+
+from .tracker import StatusTracker
+from .sampling_params import SamplingParams
+from .models import registry
+from .api_requests.base import APIResponse, APIRequestBase
+from .api_requests import create_api_request
+# from .cache import LevelDBCache, SqliteCache
+
+# TODO: get completions as they finish, not all at once at the end.
+# relatedly, would be nice to cache them as they finish too.
+
+# TODO: add optional max_input_tokens to client so we can reject long prompts to prevent abuse
+
+
+@dataclass
+class ClientConfig:
+    model_names: list[str]
+    max_requests_per_minute: int
+    max_tokens_per_minute: int
+    max_concurrent_requests: int
+    max_attempts: int
+    request_timeout: int
+    sampling_params: Union[SamplingParams, list[SamplingParams]]
+    model_weights: Union[list[float], Literal["uniform", "rate_limit"]]
+    logprobs: bool = False
+    top_logprobs: Optional[int] = None
+    cache: Optional[Any] = None
+
+    @classmethod
+    def from_dict(cls, config_dict: dict):
+        if isinstance(config_dict["sampling_params"], list):
+            config_dict["sampling_params"] = [
+                SamplingParams(**x) for x in config_dict["sampling_params"]
+            ]
+        else:
+            config_dict["sampling_params"] = SamplingParams(
+                **config_dict["sampling_params"]
+            )
+
+        return cls(**config_dict)
+
+    @classmethod
+    def from_yaml(cls, file_path: str):
+        config_dict = yaml.safe_load(open(file_path))
+        return cls.from_dict(config_dict)
+
+    def to_dict(self):
+        if isinstance(self.sampling_params, list):
+            sp = [x.__dict__ for x in self.sampling_params]
+        else:
+            sp = self.sampling_params.__dict__
+
+        return {
+            "model_names": self.model_names,
+            "max_requests_per_minute": self.max_requests_per_minute,
+            "max_tokens_per_minute": self.max_tokens_per_minute,
+            "max_concurrent_requests": self.max_concurrent_requests,
+            "max_attempts": self.max_attempts,
+            "request_timeout": self.request_timeout,
+            "sampling_params": sp,
+            "model_weights": self.model_weights,
+            "logprobs": self.logprobs,
+            "top_logprobs": self.top_logprobs,
+        }
+
+
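
For orientation, ClientConfig.from_dict (or from_yaml with an equivalent YAML file) can build a config like the sketch below. This is illustrative only: the top-level keys mirror the dataclass fields above, the nested sampling_params keys are assumed from how SamplingParams is constructed later in this file, and the model name is a placeholder rather than a confirmed registry entry.

# hypothetical example -- not part of the package
config = ClientConfig.from_dict(
    {
        "model_names": ["gpt-4o-mini"],  # assumed registry key
        "max_requests_per_minute": 1_000,
        "max_tokens_per_minute": 500_000,
        "max_concurrent_requests": 200,
        "max_attempts": 5,
        "request_timeout": 30,
        "sampling_params": {"temperature": 0.5, "max_new_tokens": 512},
        "model_weights": "uniform",
    }
)
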
+class LLMClient:
+    """
+    LLMClient abstracts all the fixed arguments to process_prompts_async, so you can
+    create it once and reuse it without re-specifying every argument.
+    Handles models, per-model sampling params, model weights, rate limits, etc.
+    """
+
+    def __init__(
+        self,
+        model_names: list[str],
+        max_requests_per_minute: int,
+        max_tokens_per_minute: int,
+        max_concurrent_requests: int,
+        sampling_params: Union[SamplingParams, list[SamplingParams]] = SamplingParams(),
+        model_weights: Union[list[float], Literal["uniform", "rate_limit"]] = "uniform",
+        max_attempts: int = 5,
+        request_timeout: int = 30,
+        logprobs: bool = False,
+        top_logprobs: Optional[int] = None,
+        use_qps: bool = False,
+        debug: bool = False,
+        cache: Optional[Any] = None,
+    ):
+        self.models = model_names
+        if isinstance(sampling_params, SamplingParams):
+            self.sampling_params = [sampling_params for _ in model_names]
+        else:
+            if len(sampling_params) != len(model_names):
+                raise ValueError(
+                    "If sampling_params is a list, it must have the same length as model_names."
+                )
+            self.sampling_params = sampling_params
+        if model_weights == "uniform":
+            self.model_weights = [1 / len(model_names) for _ in model_names]
+        elif model_weights == "rate_limit":
+            rpms = [registry[model]["requests_per_minute"] for model in model_names]
+            self.model_weights = [rpm / sum(rpms) for rpm in rpms]
+        elif sum(model_weights) != 1:
+            self.model_weights = [w / sum(model_weights) for w in model_weights]
+        else:
+            self.model_weights = model_weights
+
+        self.logprobs = logprobs
+        self.top_logprobs = top_logprobs
+
+        # logprobs and top_logprobs are only supported for OpenAI models
+        if self.logprobs:
+            for model in self.models:
+                if registry[model].get("supports_logprobs", False) is False:
+                    raise ValueError(
+                        "logprobs can only be enabled if all models support it."
+                    )
+            if self.top_logprobs is None:
+                self.top_logprobs = 0  # will just return logprob of the chosen token
+            elif self.top_logprobs > 20 or self.top_logprobs < 0:
+                raise ValueError("top_logprobs must be between 0 and 20.")
+            for sp in self.sampling_params:
+                if sp.max_new_tokens > 10:
+                    print(
+                        "WARNING: using logprobs with large max_new_tokens can result in very large outputs. you may want to avoid saving these outputs to disk/db."
+                    )
+                    break
+        else:
+            self.top_logprobs = None
+
+        self.max_requests_per_minute = max_requests_per_minute
+        self.max_tokens_per_minute = max_tokens_per_minute
+        self.max_concurrent_requests = max_concurrent_requests
+        self.max_attempts = max_attempts
+        self.request_timeout = request_timeout
+        self.use_qps = use_qps
+        self.debug = debug  # unused/deprecated; kept so existing callers don't break
+        self.cache = cache
+
+    @classmethod
+    def from_config(cls, config: ClientConfig, cache: Optional[Any] = None):
+        return cls(
+            model_names=config.model_names,
+            max_requests_per_minute=config.max_requests_per_minute,
+            max_tokens_per_minute=config.max_tokens_per_minute,
+            max_concurrent_requests=config.max_concurrent_requests,
+            sampling_params=config.sampling_params,
+            model_weights=config.model_weights,
+            max_attempts=config.max_attempts,
+            request_timeout=config.request_timeout,
+            cache=cache,
+        )
+
+    @classmethod
+    def from_yaml(cls, file_path: str, cache: Optional[Any] = None):
+        return cls.from_config(ClientConfig.from_yaml(file_path), cache=cache)
+
+    @classmethod
+    def basic(
+        cls,
+        model: Union[str, list[str]],
+        max_requests_per_minute: int = 5_000,
+        max_tokens_per_minute: int = 1_000_000,
+        max_concurrent_requests: int = 1_000,
+        temperature: float = 0.75,
+        max_new_tokens: int = 1000,
+        reasoning_effort: Literal[None, "low", "medium", "high"] = None,
+        model_weights: Union[list[float], Literal["uniform", "rate_limit"]] = "uniform",
+        logprobs: bool = False,
+        top_logprobs: Optional[int] = None,
+        max_attempts: int = 5,
+        request_timeout: int = 30,
+        cache: Optional[Any] = None,
+    ):
+        model_names = model if isinstance(model, list) else [model]
+        return cls(
+            model_names=model_names,
+            max_requests_per_minute=max_requests_per_minute,
+            max_tokens_per_minute=max_tokens_per_minute,
+            max_concurrent_requests=max_concurrent_requests,
+            sampling_params=SamplingParams(
+                temperature=temperature,
+                max_new_tokens=max_new_tokens,
+                reasoning_effort=reasoning_effort,
+            ),
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
+            model_weights=model_weights,
+            max_attempts=max_attempts,
+            request_timeout=request_timeout,
+            cache=cache,
+        )
+
+    @property
+    def config(self):
+        return ClientConfig(
+            model_names=self.models,
+            model_weights=self.model_weights,
+            max_requests_per_minute=self.max_requests_per_minute,
+            max_tokens_per_minute=self.max_tokens_per_minute,
+            max_concurrent_requests=self.max_concurrent_requests,
+            max_attempts=self.max_attempts,
+            request_timeout=self.request_timeout,
+            sampling_params=self.sampling_params,
+            logprobs=self.logprobs,
+            top_logprobs=self.top_logprobs,
+        )
+
+    @overload
+    async def process_prompts_async(
+        self,
+        prompts: Sequence[str | list[dict] | Conversation],
+        return_completions_only: bool,
+        show_progress: bool = ...,
+        dry_run: Literal[True] = ...,
+        verbose: bool = ...,
+    ) -> dict[str, int]: ...
+
+    @overload
+    async def process_prompts_async(
+        self,
+        prompts: Sequence[str | list[dict] | Conversation],
+        return_completions_only: Literal[True],
+        show_progress: bool = ...,
+        dry_run: Literal[False] = ...,
+        verbose: bool = ...,
+    ) -> list[str | None]: ...
+
+    @overload
+    async def process_prompts_async(
+        self,
+        prompts: Sequence[str | list[dict] | Conversation],
+        return_completions_only: Literal[False] = ...,
+        show_progress: bool = ...,
+        dry_run: Literal[False] = ...,
+        verbose: bool = ...,
+    ) -> list[APIResponse | None]: ...
+
+    async def process_prompts_async(
+        self,
+        prompts: Sequence[str | list[dict] | Conversation],
+        return_completions_only: bool = False,
+        show_progress: bool = True,
+        dry_run: bool = False,
+        verbose: bool = False,
+    ) -> list[APIResponse | None] | list[str | None] | dict[str, int]:
+        # if prompts are not Conversations, convert them.
+        # can only handle strings for now
+        prompts = [  # type: ignore
+            p
+            if isinstance(p, Conversation)
+            else Conversation.user(p)
+            if isinstance(p, str)
+            else None
+            for p in prompts
+        ]
+        if any(p is None for p in prompts):
+            raise ValueError("All prompts must be valid.")
+        ids = np.arange(len(prompts))
+
+        # if using cache, check for cached completions
+        if self.cache:
+            cached_results = [self.cache.get(prompt) for prompt in prompts]
+            cache_hit_ids = [
+                id for id, res in zip(ids, cached_results) if res is not None
+            ]
+            cache_hit_results = [res for res in cached_results if res is not None]
+            assert len(cache_hit_ids) == len(
+                cache_hit_results
+            ), "Cache hit ids and results must be the same length."
+            remaining_ids = np.array([i for i in ids if i not in cache_hit_ids])
+            remaining_prompts = [prompts[i] for i in remaining_ids]
+            if verbose:
+                print(f"{len(cache_hit_ids)} cache hits from previous completions.")
+                print(f"{len(remaining_ids)} prompts remaining after cache hits.")
+                print(f"Processing {len(remaining_prompts)} prompts.")
+
+        else:
+            cache_hit_ids = []
+            cache_hit_results = []
+            remaining_prompts = prompts
+            remaining_ids = ids
+
+        results: list[APIResponse | None] = [None for _ in range(len(prompts))]
+        if len(remaining_prompts) > 0:
+            # set up progress bar
+            pbar = tqdm(total=len(prompts), disable=(not show_progress))
+
+            # update progress bar with cache hits
+            pbar.update(len(cache_hit_ids))
+            api_task = None
+            if dry_run:
+                dry_run_results = api_prompts_dry_run(
+                    ids,
+                    prompts,  # type: ignore -- fix later for dry running conversations
+                    self.models,
+                    self.model_weights,
+                    self.sampling_params,
+                    max_tokens_per_minute=self.max_tokens_per_minute,
+                    max_requests_per_minute=self.max_requests_per_minute,
+                )
+                print("Dry run results:")
+                print(dry_run_results)
+                return dry_run_results
+
+            api_task = asyncio.create_task(
+                process_api_prompts_async(
+                    ids,
+                    prompts,  # type: ignore -- fix later for dry running conversations
+                    self.models,
+                    self.model_weights,
+                    self.sampling_params,
+                    logprobs=self.logprobs,
+                    top_logprobs=self.top_logprobs,
+                    max_attempts=self.max_attempts,
+                    max_tokens_per_minute=self.max_tokens_per_minute,
+                    max_requests_per_minute=self.max_requests_per_minute,
+                    max_concurrent_requests=self.max_concurrent_requests,
+                    request_timeout=self.request_timeout,
+                    progress_bar=pbar,
+                    use_qps=self.use_qps,
+                    verbose=verbose,
+                )
+            )
+            api_results: list[APIResponse] = await api_task
+            for res in api_results:
+                results[res.id] = res
+                # set to cache if result has a completion
+                if self.cache and res.completion:
+                    self.cache.put(prompts[res.id], res)
+
+        # add cache hits back in
+        for id, res in zip(cache_hit_ids, cache_hit_results):
+            results[id] = res
+
+        if return_completions_only:
+            return [r.completion if r is not None else None for r in results]
+
+        return results
+
+    def process_prompts_sync(
+        self,
+        prompts: Sequence[str | list[dict] | Conversation],
+        return_completions_only: bool = False,
+        show_progress=True,
+        dry_run: bool = False,
+        verbose: bool = False,
+    ):
+        return asyncio.run(
+            self.process_prompts_async(
+                prompts=prompts,
+                return_completions_only=return_completions_only,
+                show_progress=show_progress,
+                dry_run=dry_run,
+                verbose=verbose,
+            )
+        )
+
+    def _submit_one_batch(self, batch_requests: list):
+        # save the file
+        import pandas as pd
+
+        pd.DataFrame(batch_requests).to_json(
+            "openai_requests_temp.jsonl", orient="records", lines=True
+        )
+
+        # upload the file
+        api_key = os.environ.get("OPENAI_API_KEY", None)
+        if api_key is None:
+            raise ValueError("OPENAI_API_KEY environment variable must be set.")
+        url = "https://api.openai.com/v1/files"
+        files = {
+            "file": (
+                "openai_requests_temp.jsonl",
+                open("openai_requests_temp.jsonl", "rb"),
+            ),
+        }
+        data = {
+            "purpose": "batch",
+        }
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+        }
+        response = requests.post(url, files=files, data=data, headers=headers)
+
+        file_id = None
+        if response.status_code == 200:
+            print("File uploaded successfully")
+            data = response.json()
+            file_id = data["id"]
+
+        else:
+            print("File upload failed")
+            raise ValueError(f"Error uploading file: {response.text}")
+
+        url = "https://api.openai.com/v1/batches"
+        data = {
+            "input_file_id": file_id,
+            "endpoint": "/v1/chat/completions",
+            "completion_window": "24h",
+        }
+        response = requests.post(url, json=data, headers=headers)
+
+        batch_id = None
+        if response.status_code == 200:
+            data = response.json()
+            batch_id = data["id"]
+            print("Batch job started successfully: id = ", batch_id)
+            return batch_id
+        else:
+            print("Batch job failed to start")
+            raise ValueError(f"Error starting batch job: {response.text}")
+
+    def submit_batch_job(self, prompts: Sequence[str | list[dict] | Conversation]):
+        # make sure 1) only 1 model is used, 2) it's an openai model, 3) it supports json mode
+        if len(self.models) != 1:
+            raise ValueError("Batch jobs can only be submitted with a single model.")
+        model = self.models[0]
+        if registry[model].get("api_spec", None) != "openai":
+            raise ValueError("Batch jobs can only be submitted with OpenAI models.")
+
+        # if prompts are strings, convert them to message lists
+        prompts = [  # type: ignore
+            p
+            if isinstance(p, Conversation)
+            else Conversation.user(p)
+            if isinstance(p, str)
+            else None
+            for p in prompts
+        ]
+        if any(p is None for p in prompts):
+            raise ValueError("All prompts must be valid.")
+        ids = np.arange(len(prompts))
+
+        # create file with requests to send to batch api
+        batch_requests = []
+        for id, prompt in zip(ids, prompts):
+            assert isinstance(prompt, Conversation)
+            batch_requests.append(
+                {
+                    "custom_id": str(id),
+                    "method": "POST",
+                    "url": "/v1/chat/completions",
+                    "body": {
+                        "model": self.models[0],
+                        "messages": prompt.to_openai(),
+                        "max_tokens": self.sampling_params[0].max_new_tokens,
+                        "temperature": self.sampling_params[0].temperature,
+                        "top_p": self.sampling_params[0].top_p,
+                    },
+                }
+            )
+
+        # since the api only accepts up to 50,000 requests per batch job, we chunk into 50k chunks
+        BATCH_SIZE = 50_000
+        batches = [
+            batch_requests[i : i + BATCH_SIZE]
+            for i in range(0, len(batch_requests), BATCH_SIZE)
+        ]
+        batch_ids = []
+        for batch in tqdm(batches):
+            batch_id = self._submit_one_batch(batch)
+            batch_ids.append(batch_id)
+
+        print(f"Submitted {len(batches)} batch jobs.")
+        return batch_ids
+
+
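As a rough usage sketch of the client above (the model name is an assumed registry key and the prompts are arbitrary examples, neither confirmed by this diff):

# hypothetical example -- not part of the package
client = LLMClient.basic("gpt-4o-mini", max_new_tokens=256)
answers = client.process_prompts_sync(
    ["What is 2 + 2?", "Name a prime number greater than 10."],
    return_completions_only=True,  # list[str | None] instead of APIResponse objects
)

The optional cache passed to the constructor only needs to expose get(prompt) and put(prompt, response), which are the two calls made on it in process_prompts_async.
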
+def api_prompts_dry_run(
+    ids: Union[np.ndarray, list[int]],
+    prompts: list[Conversation],
+    models: Union[str, list[str]],
+    model_weights: list[float],
+    sampling_params: list[SamplingParams],
+    max_tokens_per_minute: int = 500_000,
+    max_requests_per_minute: int = 1_000,
+):
+    """
+    Count tokens and estimate costs for a batch of prompts.
+    """
+    results = []
+    for i, prompt in zip(ids, prompts):
+        # choose a model
+        model_idx = np.random.choice(range(len(models)), p=model_weights)
+        model = models[model_idx]
+
+        # dry run
+        input_tokens, output_tokens, min_cost, max_cost = prompt.dry_run(
+            model, sampling_params[model_idx].max_new_tokens
+        )
+        results.append(
+            {
+                "id": i,
+                "input_tokens": input_tokens,
+                "output_tokens": output_tokens,
+                "min_cost": min_cost,
+                "max_cost": max_cost,
+            }
+        )
+
+    combined_results: dict[str, Any] = {
+        "total_input_tokens": sum([r["input_tokens"] for r in results]),
+        "total_output_tokens": sum([r["output_tokens"] for r in results]),
+        "total_min_cost": sum([r["min_cost"] for r in results]),
+        "total_max_cost": sum([r["max_cost"] for r in results]),
+    }
+    minimum_time_tpm = combined_results["total_input_tokens"] / max_tokens_per_minute
+    maximum_time_tpm = (
+        combined_results["total_input_tokens"] + combined_results["total_output_tokens"]
+    ) / max_tokens_per_minute
+    minimum_time_rpm = len(prompts) / max_requests_per_minute
+
+    combined_results["minimum_time"] = max(minimum_time_tpm, minimum_time_rpm)
+    combined_results["maximum_time"] = max(maximum_time_tpm, minimum_time_rpm)
+    limiting_factor = None
+    if minimum_time_rpm > maximum_time_tpm:
+        limiting_factor = "requests"
+    elif minimum_time_rpm < minimum_time_tpm:
+        limiting_factor = "tokens"
+    else:
+        limiting_factor = "depends"
+    combined_results["limiting_factor"] = limiting_factor
+
+    return combined_results
+
+
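Because process_prompts_async routes dry_run=True through api_prompts_dry_run and returns its dict directly, a cost and time estimate can be obtained without any API calls. A hedged sketch (model name again assumed):

# hypothetical example -- not part of the package
estimate = LLMClient.basic("gpt-4o-mini").process_prompts_sync(
    ["Summarize this paragraph."],
    dry_run=True,
)
# keys per the function above: total_input_tokens, total_output_tokens,
# total_min_cost, total_max_cost, minimum_time, maximum_time, limiting_factor
print(estimate["limiting_factor"])
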
+async def process_api_prompts_async(
+    ids: Union[np.ndarray, list[int]],
+    prompts: list[Conversation],
+    models: Union[str, list[str]],
+    model_weights: list[float],
+    sampling_params: list[SamplingParams],
+    logprobs: bool,
+    top_logprobs: Optional[int],
+    max_attempts: int = 5,
+    max_tokens_per_minute: int = 500_000,
+    max_requests_per_minute: int = 1_000,
+    max_concurrent_requests: int = 1_000,
+    request_timeout: int = 30,
+    progress_bar: Optional[tqdm] = None,
+    use_qps: bool = False,
+    verbose: bool = False,
+):
+    """Processes API requests in parallel, throttling to stay under rate limits."""
+    # change ids to integer list
+    if isinstance(ids, np.ndarray):
+        ids = ids.tolist()  # pyright: ignore
+
+    # normalize weights
+    model_weights = [w / sum(model_weights) for w in model_weights]
+
+    # constants
+    seconds_to_pause_after_rate_limit_error = 5
+    # seconds_to_sleep_each_loop = 0.003  # so concurrent tasks can run
+    # calculate dynamically so we don't throttle RPM
+    seconds_to_sleep_each_loop = (60.0 * 0.9) / max_requests_per_minute
+
+    # initialize trackers
+    retry_queue = asyncio.Queue()
+    status_tracker = StatusTracker()
+    next_request = None  # variable to hold the next request to call
+
+    # initialize available capacity counts
+    # throttle over a 1 second window rather than minute,
+    # since some models limit RPS rather than RPM
+    if use_qps:
+        available_request_capacity = max_requests_per_minute / 60.0
+        available_token_capacity = max_tokens_per_minute
+    else:
+        available_request_capacity = max_requests_per_minute
+        available_token_capacity = max_tokens_per_minute
+    last_update_time = time.time()
+    last_pbar_update_time = time.time()
+
+    # initialize flags
+    prompts_not_finished = True
+
+    # validate models and initialize model weights
+    if isinstance(models, str):
+        models = [models]
+    if not isinstance(models, list):
+        raise ValueError("models must be a string or a list of model strings.")
+    for model in models:
+        if model not in registry:
+            raise ValueError(f"Model {model} not found in registry.")
+
+    if model_weights is None:
+        # if not given, spread requests evenly across models
+        model_weights = [1 / len(models) for _ in models]
+    elif len(model_weights) != len(models):
+        raise ValueError(
+            "model_weights must be None or a list of the same length as models."
+        )
+    elif sum(model_weights) != 1:
+        model_weights = [w / sum(model_weights) for w in model_weights]
+
+    prompts_iter = iter(zip(ids, prompts))
+    results: list[APIRequestBase] = []
+    while True:
+        # get next request (if one is not already waiting for capacity)
+        if next_request is None:
+            if not retry_queue.empty():
+                next_request = retry_queue.get_nowait()
+                print(f"Retrying request {next_request.task_id}.")
+            elif prompts_not_finished:
+                try:
+                    # get new request
+                    id, prompt = next(prompts_iter)
+                    # select model
+                    model_idx = np.random.choice(range(len(models)), p=model_weights)
+                    next_request = create_api_request(
+                        task_id=id,
+                        model_name=models[model_idx],
+                        prompt=prompt,
+                        request_timeout=request_timeout,
+                        attempts_left=max_attempts,
+                        status_tracker=status_tracker,
+                        retry_queue=retry_queue,
+                        results_arr=results,
+                        sampling_params=sampling_params[model_idx],
+                        logprobs=logprobs,
+                        top_logprobs=top_logprobs,
+                        pbar=progress_bar,
+                        all_model_names=models,
+                        all_sampling_params=sampling_params,
+                    )
+                    status_tracker.num_tasks_started += 1
+                    status_tracker.num_tasks_in_progress += 1
+                    results.append(next_request)
+
+                except StopIteration:
+                    prompts_not_finished = False
+                    if verbose:
+                        print("API requests finished, only retries remain.")
+
+        # update available capacity
+        current_time = time.time()
+        seconds_since_update = current_time - last_update_time
+        available_request_capacity = min(
+            available_request_capacity
+            + max_requests_per_minute * seconds_since_update / 60.0,
+            max_requests_per_minute,
+        )
+        available_token_capacity = min(
+            available_token_capacity
+            + max_tokens_per_minute * seconds_since_update / 60.0,
+            max_tokens_per_minute,
+        )
+        last_update_time = current_time
+
+        # if enough capacity available, call API
+        limiting_factor = None
+        if next_request:
+            next_request_tokens = next_request.num_tokens
+            request_available = available_request_capacity >= 1
+            tokens_available = available_token_capacity >= next_request_tokens
+            concurrent_request_available = (
+                status_tracker.num_tasks_in_progress < max_concurrent_requests
+            )
+            if request_available and tokens_available and concurrent_request_available:
+                # update counters
+                available_request_capacity -= 1
+                available_token_capacity -= next_request_tokens
+                next_request.attempts_left -= 1
+
+                # call API
+                asyncio.create_task(next_request.call_api())
+                next_request = None  # reset next_request to empty
+            else:
+                if not request_available:
+                    limiting_factor = "Requests"
+                elif not concurrent_request_available:
+                    limiting_factor = "Concurrent Requests"
+                elif not tokens_available:
+                    limiting_factor = "Tokens"
+
+        # update pbar status
+        if progress_bar and (current_time - last_pbar_update_time > 1):
+            last_pbar_update_time = current_time
+            progress_bar.set_postfix(
+                {
+                    "Token Capacity": f"{available_token_capacity/1_000:.1f}k",
+                    "Req. Capacity": f"{available_request_capacity:.1f}",
+                    "Reqs. in Progress": status_tracker.num_tasks_in_progress,
+                    "Limiting Factor": limiting_factor,
+                }
+            )
+
+        # if all tasks are finished, break
+        if status_tracker.num_tasks_in_progress == 0:
+            break
+
+        # main loop sleeps briefly so concurrent tasks can run
+        await asyncio.sleep(seconds_to_sleep_each_loop)
+
+        # if a rate limit error was hit recently, pause to cool down
+        seconds_since_rate_limit_error = (
+            time.time() - status_tracker.time_of_last_rate_limit_error
+        )
+        if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
+            remaining_seconds_to_pause = (
+                seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
+            )
+            # e.g., if the pause is 15 seconds and the last rate limit error was
+            # 5 seconds ago, sleep for the remaining 10 seconds
+            print(
+                f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
+            )
+            await asyncio.sleep(remaining_seconds_to_pause)
+
+    # after finishing, log final status
+    if status_tracker.num_tasks_failed > 0:
+        print(
+            f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed."
+        )
+    if status_tracker.num_rate_limit_errors > 0:
+        print(
+            f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
+        )
+    if verbose:
+        print(
+            f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
+        )
+
+    # deduplicate results by id
+    deduplicated = {}
+    for request in results:
+        if request.task_id not in deduplicated:
+            deduplicated[request.task_id] = request.result[-1]
+        else:
+            current_response: APIResponse = deduplicated[request.task_id]
+            # only replace if the current request has no completion and the new one does
+            if (
+                request.result[-1].completion is not None
+                and current_response.completion is None
+            ):
+                deduplicated[request.task_id] = request.result[-1]
+
+    output = list(deduplicated.values())
+    if verbose:
+        print(f"Returning {len(output)} unique results.")
+
+    return output
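
The scheduling loop above throttles with a token-bucket style rule: request and token capacity refill continuously in proportion to elapsed time, are capped at the per-minute limits, and a request is dispatched only when one unit of request capacity and the request's full token count are both available. A stripped-down sketch of just that rule, with simplified names that are not part of the package:

import time


class Bucket:
    """Continuously refilling capacity counter, capped at a per-minute limit."""

    def __init__(self, per_minute: float):
        self.per_minute = per_minute
        self.capacity = per_minute  # start full, as the loop above does
        self.last_update = time.time()

    def refill(self):
        now = time.time()
        elapsed = now - self.last_update
        # replenish in proportion to elapsed time, never exceeding the limit
        self.capacity = min(
            self.capacity + self.per_minute * elapsed / 60.0, self.per_minute
        )
        self.last_update = now


# mirroring the loop above: dispatch only when both buckets have room,
# then deduct one request and the request's token count
requests_bucket, tokens_bucket = Bucket(1_000), Bucket(500_000)
requests_bucket.refill()
tokens_bucket.refill()
next_request_tokens = 1_200  # example value
if requests_bucket.capacity >= 1 and tokens_bucket.capacity >= next_request_tokens:
    requests_bucket.capacity -= 1
    tokens_bucket.capacity -= next_request_tokens
    ...  # dispatch the request here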