lm-deluge 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lm-deluge might be problematic.
- lm_deluge/__init__.py +9 -1
- lm_deluge/agent.py +0 -0
- lm_deluge/api_requests/anthropic.py +107 -60
- lm_deluge/api_requests/base.py +107 -54
- lm_deluge/api_requests/bedrock.py +59 -22
- lm_deluge/api_requests/common.py +2 -1
- lm_deluge/api_requests/mistral.py +20 -22
- lm_deluge/api_requests/openai.py +283 -51
- lm_deluge/batches.py +498 -0
- lm_deluge/client.py +373 -634
- lm_deluge/computer_use/anthropic_tools.py +75 -0
- lm_deluge/{sampling_params.py → config.py} +10 -3
- lm_deluge/embed.py +17 -11
- lm_deluge/models.py +33 -0
- lm_deluge/prompt.py +173 -7
- lm_deluge/rerank.py +18 -12
- lm_deluge/tool.py +11 -1
- lm_deluge/tracker.py +212 -2
- lm_deluge/usage.py +114 -0
- lm_deluge/util/json.py +18 -1
- {lm_deluge-0.0.11.dist-info → lm_deluge-0.0.13.dist-info}/METADATA +78 -20
- lm_deluge-0.0.13.dist-info/RECORD +42 -0
- {lm_deluge-0.0.11.dist-info → lm_deluge-0.0.13.dist-info}/WHEEL +1 -1
- lm_deluge-0.0.11.dist-info/RECORD +0 -38
- {lm_deluge-0.0.11.dist-info → lm_deluge-0.0.13.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.11.dist-info → lm_deluge-0.0.13.dist-info}/top_level.txt +0 -0
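The bulk of this release is the rewrite of lm_deluge/client.py shown below, where LLMClient becomes a pydantic BaseModel with flat config fields and builder methods. As orientation, here is a minimal usage sketch inferred from the added code in that diff; the `from lm_deluge import LLMClient` import path and the example prompt are assumptions, not something documented on this page.

# Sketch only: field names (model_names, max_new_tokens) and builder methods
# (with_limits, process_prompts_async) are taken from the client.py diff below.
import asyncio

from lm_deluge import LLMClient  # assumed export; see the __init__.py change above

async def main():
    client = LLMClient(model_names=["gpt-4.1-mini"], max_new_tokens=512).with_limits(
        max_requests_per_minute=500,
        max_tokens_per_minute=50_000,
    )
    # plain strings are converted to Conversations internally (prompts_to_conversations)
    completions = await client.process_prompts_async(
        ["Summarize the plot of Hamlet in one sentence."],
        return_completions_only=True,
    )
    print(completions[0])

asyncio.run(main())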
lm_deluge/client.py
CHANGED
@@ -1,42 +1,140 @@
-import os
-import requests
 import asyncio
+from typing import Any, Literal, Self, Sequence, overload
+
 import numpy as np
-import time
 import yaml
-from
-from
-
-
-
+from pydantic import BaseModel
+from pydantic.functional_validators import model_validator
+
+from lm_deluge.batches import (
+    submit_batches_anthropic,
+    submit_batches_oa,
+    wait_for_batch_completion_async,
+)
+from lm_deluge.prompt import CachePattern, Conversation, prompts_to_conversations
 from lm_deluge.tool import Tool

-from .tracker import StatusTracker
-from .sampling_params import SamplingParams
-from .models import registry
-from .api_requests.base import APIResponse, APIRequestBase
 from .api_requests import create_api_request
+from .api_requests.base import APIRequestBase, APIResponse, deduplicate_responses
+from .config import SamplingParams
+from .models import registry
+from .tracker import StatusTracker
+
 # from .cache import LevelDBCache, SqliteCache

+
 # TODO: get completions as they finish, not all at once at the end.
 # relatedly, would be nice to cache them as they finish too.
-
 # TODO: add optional max_input_tokens to client so we can reject long prompts to prevent abuse
+class LLMClient(BaseModel):
+    """
+    LLMClient abstracts all the fixed arguments to process_prompts_async, so you can create it
+    once and use it for more stuff without having to configure all the arguments.
+    Handles models, sampling params for each model, model weights, rate limits, etc.
+    """

-
-
-
-
-
-
-
-
-
-    sampling_params
-
+    model_names: list[str] = ["gpt-4.1-mini"]
+    max_requests_per_minute: int = 1_000
+    max_tokens_per_minute: int = 100_000
+    max_concurrent_requests: int = 225
+    sampling_params: list[SamplingParams] = []
+    model_weights: list[float] | Literal["uniform", "dynamic"] = "uniform"
+    max_attempts: int = 5
+    request_timeout: int = 30
+    cache: Any = None
+    # sampling params - if provided, and sampling_params is not,
+    # these override the defaults
+    temperature: float = 0.75
+    top_p: float = 1.0
+    json_mode: bool = False
+    max_new_tokens: int = 512
+    reasoning_effort: Literal["low", "medium", "high", None] = None
     logprobs: bool = False
     top_logprobs: int | None = None
-
+
+    # NEW! Builder methods
+    def with_model(self, model: str):
+        self.model_names = [model]
+        return self
+
+    def with_models(self, models: list[str]):
+        self.model_names = models
+        return self
+
+    def with_limits(
+        self,
+        max_requests_per_minute: int | None = None,
+        max_tokens_per_minute: int | None = None,
+        max_concurrent_requests: int | None = None,
+    ):
+        if max_requests_per_minute:
+            self.max_requests_per_minute = max_requests_per_minute
+        if max_tokens_per_minute:
+            self.max_tokens_per_minute = max_tokens_per_minute
+        if max_concurrent_requests:
+            self.max_concurrent_requests = max_concurrent_requests
+
+    @property
+    def models(self):
+        return self.model_names  # why? idk
+
+    @model_validator(mode="before")
+    @classmethod
+    def fix_lists(cls, data) -> "LLMClient":
+        if isinstance(data["model_names"], str):
+            data["model_names"] = [data["model_names"]]
+        if "sampling_params" not in data or len(data.get("sampling_params", [])) == 0:
+            data["sampling_params"] = [
+                SamplingParams(
+                    temperature=data.get("temperature", 0.75),
+                    top_p=data.get("top_p", 1.0),
+                    json_mode=data.get("json_mode", False),
+                    max_new_tokens=data.get("max_new_tokens", 512),
+                    reasoning_effort=data.get("reasoning_effort", None),
+                    logprobs=data.get("logprobs", False),
+                    top_logprobs=data.get("top_logprobs", None),
+                )
+            ]
+        return data
+
+    @model_validator(mode="after")
+    def validate_client(self) -> Self:
+        if isinstance(self.model_names, str):
+            self.model_names = [self.model_names]
+        if any(m not in registry for m in self.model_names):
+            raise ValueError("all model_names must be in registry")
+        if isinstance(self.sampling_params, SamplingParams):
+            self.sampling_params = [self.sampling_params for _ in self.model_names]
+        elif len(self.sampling_params) != len(self.model_names):
+            raise ValueError("# models and # sampling params must match")
+        if self.model_weights == "uniform":
+            self.model_weights = [1 / len(self.model_names) for _ in self.model_names]
+        elif self.model_weights == "dynamic":
+            raise NotImplementedError("dynamic model weights not implemented yet")
+        # normalize weights
+        self.model_weights = [w / sum(self.model_weights) for w in self.model_weights]
+
+        # Validate logprobs settings across all sampling params
+        if self.logprobs or any(sp.logprobs for sp in self.sampling_params):
+            print("Logprobs enabled.")
+            for sp in self.sampling_params:
+                sp.logprobs = True
+                # set top_logprobs for each sp if provided and not set
+                if self.top_logprobs and not sp.top_logprobs:
+                    sp.top_logprobs = self.top_logprobs
+                if sp.top_logprobs and not (0 <= sp.top_logprobs <= 20):
+                    raise ValueError("top_logprobs must be 0-20")
+                if sp.top_logprobs and sp.max_new_tokens > 10:
+                    print(
+                        "WARNING: using top_logprobs can result in very large outputs. consider limiting max_new_tokens."
+                    )
+            if not all(
+                registry[model].get("supports_logprobs") for model in self.models
+            ):
+                raise ValueError(
+                    "logprobs can only be enabled if all models support it."
+                )
+        return self

     @classmethod
     def from_dict(cls, config_dict: dict):
@@ -46,7 +144,7 @@ class ClientConfig:
             ]
         else:
             config_dict["sampling_params"] = SamplingParams(
-                config_dict["sampling_params"]
+                **config_dict["sampling_params"]
             )

         return cls(**config_dict)
@@ -56,183 +154,13 @@ class ClientConfig:
         config_dict = yaml.safe_load(open(file_path))
         return cls.from_dict(config_dict)

-    def to_dict(self):
-        if isinstance(self.sampling_params, list):
-            sp = [x.__dict__ for x in self.sampling_params]
-        else:
-            sp = self.sampling_params.__dict__
-
-        return {
-            "model_names": self.model_names,
-            "max_requests_per_minute": self.max_requests_per_minute,
-            "max_tokens_per_minute": self.max_tokens_per_minute,
-            "max_concurrent_requests": self.max_concurrent_requests,
-            "max_attempts": self.max_attempts,
-            "request_timeout": self.request_timeout,
-            "sampling_params": sp,
-            "model_weights": self.model_weights,
-            "logprobs": self.logprobs,
-            "top_logprobs": self.top_logprobs,
-        }
-
-
-class LLMClient:
-    """
-    LLMClient abstracts all the fixed arguments to process_prompts_async, so you can create it
-    once and use it for more stuff without having to configure all the arguments.
-    Handles models, sampling params for each model, model weights, rate limits, etc.
-    """
-
-    def __init__(
-        self,
-        model_names: list[str],
-        *,
-        max_requests_per_minute: int,
-        max_tokens_per_minute: int,
-        max_concurrent_requests: int,
-        sampling_params: SamplingParams | list[SamplingParams] = SamplingParams(),
-        model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
-        max_attempts: int = 5,
-        request_timeout: int = 30,
-        logprobs: bool = False,
-        top_logprobs: int | None = None,
-        use_qps: bool = False,
-        debug: bool = False,
-        cache: Any = None,
-    ):
-        self.models = model_names
-        if isinstance(sampling_params, SamplingParams):
-            self.sampling_params = [sampling_params for _ in model_names]
-        else:
-            if len(sampling_params) != len(model_names):
-                raise ValueError(
-                    "If sampling_params is a list, it must have the same length as model_names."
-                )
-            self.sampling_params = sampling_params
-        if model_weights == "uniform":
-            self.model_weights = [1 / len(model_names) for _ in model_names]
-        elif model_weights == "rate_limit":
-            rpms = [registry[model]["requests_per_minute"] for model in model_names]
-            self.model_weights = [rpm / sum(rpms) for rpm in rpms]
-        elif sum(model_weights) != 1:
-            self.model_weights = [w / sum(model_weights) for w in model_weights]
-        else:
-            self.model_weights = model_weights
-
-        self.logprobs = logprobs
-        self.top_logprobs = top_logprobs
-
-        # logprobs and top_logprobs are only supported for OpenAI models
-        if self.logprobs:
-            for model in self.models:
-                if registry[model].get("supports_logprobs", False) is False:
-                    raise ValueError(
-                        "logprobs can only be enabled if all models support it."
-                    )
-            if self.top_logprobs is None:
-                self.top_logprobs = 0  # will just return logprob of the chosen token
-            elif self.top_logprobs > 20 or self.top_logprobs < 0:
-                raise ValueError("top_logprobs must be between 0 and 20.")
-            for sp in self.sampling_params:
-                if sp.max_new_tokens > 10:
-                    print(
-                        "WARNING: using logprobs with large max_new_tokens can result in very large outputs. you may want to avoid saving these outputs to disk/db."
-                    )
-                    break
-        else:
-            self.top_logprobs = None
-
-        self.max_requests_per_minute = max_requests_per_minute
-        self.max_tokens_per_minute = max_tokens_per_minute
-        self.max_concurrent_requests = max_concurrent_requests
-        self.max_attempts = max_attempts
-        self.request_timeout = request_timeout
-        self.use_qps = use_qps
-        self.debug = (
-            debug  # UNUSED/DEPRECATED i think? but dont want to break everything
-        )
-        self.cache = cache
-
-    @classmethod
-    def from_config(cls, config: ClientConfig, cache: Any = None):
-        return cls(
-            model_names=config.model_names,
-            max_requests_per_minute=config.max_requests_per_minute,
-            max_tokens_per_minute=config.max_tokens_per_minute,
-            max_concurrent_requests=config.max_concurrent_requests,
-            sampling_params=config.sampling_params,
-            model_weights=config.model_weights,
-            max_attempts=config.max_attempts,
-            request_timeout=config.request_timeout,
-            cache=cache,
-        )
-
-    @classmethod
-    def from_yaml(cls, file_path: str, cache: Any = None):
-        return cls.from_config(ClientConfig.from_yaml(file_path), cache=cache)
-
     @classmethod
-    def basic(
-
-
-
-
-
-        temperature: float = 0.75,
-        max_new_tokens: int = 1000,
-        reasoning_effort: Literal[None, "low", "medium", "high"] = None,
-        model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
-        logprobs: bool = False,
-        top_logprobs: int | None = None,
-        max_attempts: int = 5,
-        request_timeout: int = 30,
-        cache: Any = None,
-    ):
-        model_names = model if isinstance(model, list) else [model]
-        return cls(
-            model_names=model_names,
-            max_requests_per_minute=max_requests_per_minute,
-            max_tokens_per_minute=max_tokens_per_minute,
-            max_concurrent_requests=max_concurrent_requests,
-            sampling_params=SamplingParams(
-                temperature=temperature,
-                max_new_tokens=max_new_tokens,
-                reasoning_effort=reasoning_effort,
-            ),
-            logprobs=logprobs,
-            top_logprobs=top_logprobs,
-            model_weights=model_weights,
-            max_attempts=max_attempts,
-            request_timeout=request_timeout,
-            cache=cache,
-        )
-
-    @property
-    def config(self):
-        return ClientConfig(
-            model_names=self.models,
-            model_weights=self.model_weights,
-            max_requests_per_minute=self.max_requests_per_minute,
-            max_tokens_per_minute=self.max_tokens_per_minute,
-            max_concurrent_requests=self.max_concurrent_requests,
-            max_attempts=self.max_attempts,
-            request_timeout=self.request_timeout,
-            sampling_params=self.sampling_params,
-            logprobs=self.logprobs,
-            top_logprobs=self.top_logprobs,
-        )
-
-    @overload
-    async def process_prompts_async(
-        self,
-        prompts: Sequence[str | list[dict] | Conversation],
-        *,
-        return_completions_only: bool,
-        show_progress: bool = ...,
-        dry_run: Literal[True],
-        verbose: bool = ...,
-        tools: list[Tool] | None = ...,
-    ) -> dict[str, int]: ...
+    def basic(cls, model: str | list[str], **kwargs):
+        """
+        Doesn't do anything differently now, kept for backwards compat.
+        """
+        kwargs["model_names"] = model
+        return cls(**kwargs)

     @overload
     async def process_prompts_async(
@@ -241,9 +169,12 @@ class LLMClient:
         *,
         return_completions_only: Literal[True],
         show_progress: bool = ...,
-        dry_run: bool = ...,
-        verbose: bool = ...,
         tools: list[Tool] | None = ...,
+        cache: CachePattern | None = ...,
+        computer_use: bool = ...,
+        display_width: int = ...,
+        display_height: int = ...,
+        use_responses_api: bool = ...,
     ) -> list[str | None]: ...

     @overload
@@ -253,9 +184,12 @@ class LLMClient:
         *,
         return_completions_only: Literal[False] = ...,
         show_progress: bool = ...,
-        dry_run: bool = ...,
-        verbose: bool = ...,
         tools: list[Tool] | None = ...,
+        cache: CachePattern | None = ...,
+        computer_use: bool = ...,
+        display_width: int = ...,
+        display_height: int = ...,
+        use_responses_api: bool = ...,
     ) -> list[APIResponse | None]: ...

     async def process_prompts_async(
@@ -264,22 +198,15 @@
         *,
         return_completions_only: bool = False,
         show_progress: bool = True,
-        dry_run: bool = False,
-        verbose: bool = False,
         tools: list[Tool] | None = None,
+        cache: CachePattern | None = None,
+        computer_use: bool = False,
+        display_width: int = 1024,
+        display_height: int = 768,
+        use_responses_api: bool = False,
     ) -> list[APIResponse | None] | list[str | None] | dict[str, int]:
         # if prompts are not Conversations, convert them.
-
-        prompts = [  # type: ignore
-            p
-            if isinstance(p, Conversation)
-            else Conversation.user(p)
-            if isinstance(p, str)
-            else None
-            for p in prompts
-        ]
-        if any(p is None for p in prompts):
-            raise ValueError("All prompts must be valid.")
+        prompts = prompts_to_conversations(prompts)
         ids = np.arange(len(prompts))

         # if using cache, check for cached completions
@@ -294,10 +221,9 @@
             ), "Cache hit ids and results must be the same length."
             remaining_ids = np.array([i for i in ids if i not in cache_hit_ids])
             remaining_prompts = [prompts[i] for i in remaining_ids]
-
-
-
-            print(f"Processing {len(remaining_prompts)} prompts.")
+            print(
+                f"{len(cache_hit_ids)} cache hits; {len(remaining_ids)} prompts remaining."
+            )

         else:
             cache_hit_ids = []
@@ -307,47 +233,137 @@

         results: list[APIResponse | None] = [None for _ in range(len(prompts))]
         if len(remaining_prompts) > 0:
-            #
-
-
-
-
-
-
-
-
-                prompts,  # type: ignore -- fix later for dry running conversations
-                self.models,
-                self.model_weights,
-                self.sampling_params,
-                max_tokens_per_minute=self.max_tokens_per_minute,
-                max_requests_per_minute=self.max_requests_per_minute,
-            )
-            print("Dry run results:")
-            print(dry_run_results)
-            return dry_run_results
-
-            api_task = asyncio.create_task(
-                process_api_prompts_async(
-                    ids,
-                    prompts,  # type: ignore -- fix later for dry running conversations
-                    self.models,
-                    self.model_weights,
-                    self.sampling_params,
-                    logprobs=self.logprobs,
-                    top_logprobs=self.top_logprobs,
-                    max_attempts=self.max_attempts,
-                    max_tokens_per_minute=self.max_tokens_per_minute,
-                    max_requests_per_minute=self.max_requests_per_minute,
-                    max_concurrent_requests=self.max_concurrent_requests,
-                    request_timeout=self.request_timeout,
-                    progress_bar=pbar,
-                    use_qps=self.use_qps,
-                    verbose=verbose,
-                    tools=tools,
-                )
+            # Create StatusTracker with integrated progress bar
+            tracker = StatusTracker(
+                max_requests_per_minute=self.max_requests_per_minute,
+                max_tokens_per_minute=self.max_tokens_per_minute,
+                max_concurrent_requests=self.max_concurrent_requests,
+                use_progress_bar=show_progress,
+                progress_bar_total=len(prompts),
+                progress_bar_disable=not show_progress,
+                use_rich=show_progress,  # Disable Rich if progress is disabled
             )
-
+
+            # Initialize progress bar and update with cache hits
+            tracker.init_progress_bar()
+            if len(cache_hit_ids) > 0:
+                tracker.update_pbar(len(cache_hit_ids))
+
+            # api_task = asyncio.create_task(
+            #     process_api_prompts_async(
+            #         ids,
+            #         prompts,  # type: ignore -- fix later for dry running conversations
+            #         self.models,
+            #         self.model_weights,  # type: ignore
+            #         self.sampling_params,  # type: ignore
+            #         max_attempts=self.max_attempts,
+            #         max_concurrent_requests=self.max_concurrent_requests,
+            #         request_timeout=self.request_timeout,
+            #         status_tracker=tracker,
+            #         tools=tools,
+            #         cache=cache,
+            #         computer_use=computer_use,
+            #         display_width=display_width,
+            #         display_height=display_height,
+            #         use_responses_api=use_responses_api,
+            #     )
+            # )
+            # async def process_api_prompts_async(
+
+            #     models: str | list[str],
+            #     model_weights: list[float],
+            #     sampling_params: list[SamplingParams],
+            #     max_attempts: int = 5,
+            #     max_concurrent_requests: int = 1_000,
+            #     request_timeout: int = 30,
+            #     status_tracker: StatusTracker | None = None,
+            #     tools: list[Tool] | None = None,
+            #     cache: CachePattern | None = None,
+            #     computer_use: bool = False,
+            #     display_width: int = 1024,
+            #     display_height: int = 768,
+            #     use_responses_api: bool = False,
+            # ):
+            if isinstance(ids, np.ndarray):
+                ids = ids.tolist()  # pyright: ignore
+
+            # calculate dynamically so we don't throttle RPM
+            seconds_to_sleep_each_loop = (60.0 * 0.9) / tracker.max_requests_per_minute
+            next_request = None  # variable to hold the next request to call
+            prompts_not_finished = True
+            prompts_iter = iter(zip(ids, prompts))
+            requests: list[APIRequestBase] = []
+            assert tracker.retry_queue, "retry queue not initialized"
+            while True:
+                # get next request (if one is not already waiting for capacity)
+                if next_request is None:
+                    if not tracker.retry_queue.empty():
+                        next_request = tracker.retry_queue.get_nowait()
+                        print(f"Retrying request {next_request.task_id}.")
+                    elif prompts_not_finished:
+                        try:
+                            # get new request
+                            id, prompt = next(prompts_iter)
+                            # select model
+                            assert isinstance(self.model_weights, list)
+                            model_idx = np.random.choice(
+                                range(len(self.models)), p=self.model_weights
+                            )
+                            next_request = create_api_request(
+                                task_id=id,
+                                model_name=self.models[model_idx],
+                                prompt=prompt,  # type: ignore
+                                request_timeout=self.request_timeout,
+                                attempts_left=self.max_attempts,
+                                status_tracker=tracker,
+                                results_arr=requests,
+                                sampling_params=self.sampling_params[model_idx],
+                                all_model_names=self.models,
+                                all_sampling_params=self.sampling_params,
+                                tools=tools,
+                                cache=cache,
+                                computer_use=computer_use,
+                                display_width=display_width,
+                                display_height=display_height,
+                                use_responses_api=use_responses_api,
+                            )
+                            requests.append(next_request)
+
+                        except StopIteration:
+                            prompts_not_finished = False
+                            # print("API requests finished, only retries remain.")
+
+                # update available capacity
+                tracker.update_capacity()
+
+                # if enough capacity available, call API
+                if next_request:
+                    next_request_tokens = next_request.num_tokens
+                    if tracker.check_capacity(next_request_tokens):
+                        tracker.set_limiting_factor(None)
+                        next_request.attempts_left -= 1
+                        # call API
+                        asyncio.create_task(next_request.call_api())
+                        next_request = None  # reset next_request to empty
+                # update pbar status
+                tracker.update_pbar()
+
+                # if all tasks are finished, break
+                if tracker.num_tasks_in_progress == 0:
+                    break
+
+                # main loop sleeps briefly so concurrent tasks can run
+                await asyncio.sleep(seconds_to_sleep_each_loop)
+
+                # if a rate limit error was hit recently, pause to cool down
+                if tracker.seconds_to_pause > 0:
+                    await asyncio.sleep(tracker.seconds_to_pause)
+                    print(f"Pausing {tracker.seconds_to_pause}s to cool down.")
+
+            # after finishing, log final status
+            tracker.log_final_status()
+            # deduplicate results by id
+            api_results = deduplicate_responses(requests)
             for res in api_results:
                 results[res.id] = res
                 # set to cache if result has a completion
@@ -370,393 +386,116 @@ class LLMClient:
         *,
         return_completions_only: bool = False,
         show_progress=True,
-        dry_run: bool = False,
-        verbose: bool = False,
         tools: list[Tool] | None = None,
+        cache: CachePattern | None = None,
     ):
         return asyncio.run(
             self.process_prompts_async(
                 prompts=prompts,
                 return_completions_only=return_completions_only,
                 show_progress=show_progress,
-                dry_run=dry_run,
-                verbose=verbose,
                 tools=tools,
+                cache=cache,
            )
        )

-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        data = {
-            "purpose": "batch",
-        }
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-        }
-        response = requests.post(url, files=files, data=data, headers=headers)
-
-        file_id = None
-        if response.status_code == 200:
-            print("File uploaded successfully")
-            data = response.json()
-            file_id = data["id"]
-
-        else:
-            print("File upload failed")
-            raise ValueError(f"Error uploading file: {response.text}")
-
-        url = "https://api.openai.com/v1/batches"
-        data = {
-            "input_file_id": file_id,
-            "endpoint": "/v1/chat/completions",
-            "completion_window": "24h",
-        }
-        response = requests.post(url, json=data, headers=headers)
-
-        batch_id = None
-        if response.status_code == 200:
-            data = response.json()
-            batch_id = data["id"]
-            print("Batch job started successfully: id = ", batch_id)
-            return batch_id
-        else:
-            print("Batch job failed to start")
-            raise ValueError(f"Error starting batch job: {response.text}")
-
-    def submit_batch_job(self, prompts: Sequence[str | list[dict] | Conversation]):
-        # make sure 1) only 1 model is used, 2) it's an openai model, 3) it supports json mode
+    async def submit_batch_job(
+        self,
+        prompts: Sequence[str | list[dict] | Conversation],
+        *,
+        tools: list[Tool] | None = None,
+        cache: CachePattern | None = None,
+    ):
+        """Submit a batch job asynchronously, automatically detecting the provider based on model.
+
+        Args:
+            prompts: List of prompts to process
+            wait_for_completion: If True, poll until completion and return results
+            poll_interval: Seconds to wait between status checks when polling
+            tools: Optional tools to include in requests (Anthropic only)
+            cache: Optional cache pattern for requests (Anthropic only)
+
+        Returns: list of batch_ids
+        """
+        assert isinstance(self.sampling_params, list)
         if len(self.models) != 1:
             raise ValueError("Batch jobs can only be submitted with a single model.")
         model = self.models[0]
-
-
-
-
-
-
-
-
-
-
-            for p in prompts
-        ]
-        if any(p is None for p in prompts):
-            raise ValueError("All prompts must be valid.")
-        ids = np.arange(len(prompts))
-
-        # create file with requests to send to batch api
-        batch_requests = []
-        for id, prompt in zip(ids, prompts):
-            assert isinstance(prompt, Conversation)
-            batch_requests.append(
-                {
-                    "custom_id": str(id),
-                    "method": "POST",
-                    "url": "/v1/chat/completions",
-                    "body": {
-                        "model": self.models[0],
-                        "messages": prompt.to_openai(),
-                        "max_tokens": self.sampling_params[0].max_new_tokens,
-                        "temperature": self.sampling_params[0].temperature,
-                        "top_p": self.sampling_params[0].top_p,
-                    },
-                }
+        api_spec = registry[model].get("api_spec", None)
+
+        if api_spec == "openai":
+            return await submit_batches_oa(model, self.sampling_params[0], prompts)
+        elif api_spec == "anthropic":
+            return await submit_batches_anthropic(
+                model,
+                self.sampling_params[0],
+                prompts,
+                cache=cache,
             )
+        else:
+            raise ValueError(f"Batch processing not supported for API spec: {api_spec}")

-
-
-
-
-
-        ]
-        batch_ids = []
-        for batch in tqdm(batches):
-            batch_id = self._submit_one_batch(batch)
-            batch_ids.append(batch_id)
-
-        print(f"Submitted {len(batches)} batch jobs.")
-        return batch_ids
-
-
-def api_prompts_dry_run(
-    ids: np.ndarray | list[int],
-    prompts: list[Conversation],
-    models: str | list[str],
-    model_weights: list[float],
-    sampling_params: list[SamplingParams],
-    max_tokens_per_minute: int = 500_000,
-    max_requests_per_minute: int = 1_000,
-):
-    """
-    Count tokens and estimate costs for a batch of prompts.
-    """
-    results = []
-    for i, prompt in zip(ids, prompts):
-        # choose a model
-        model_idx = np.random.choice(range(len(models)), p=model_weights)
-        model = models[model_idx]
-
-        # dry run
-        input_tokens, output_tokens, min_cost, max_cost = prompt.dry_run(
-            model, sampling_params[model_idx].max_new_tokens
-        )
-        results.append(
-            {
-                "id": i,
-                "input_tokens": input_tokens,
-                "output_tokens": output_tokens,
-                "min_cost": min_cost,
-                "max_cost": max_cost,
-            }
-        )
-
-    combined_results: dict[str, Any] = {
-        "total_input_tokens": sum([r["input_tokens"] for r in results]),
-        "total_output_tokens": sum([r["output_tokens"] for r in results]),
-        "total_min_cost": sum([r["min_cost"] for r in results]),
-        "total_max_cost": sum([r["max_cost"] for r in results]),
-    }
-    minimum_time_tpm = combined_results["total_input_tokens"] / max_tokens_per_minute
-    maximum_time_tpm = (
-        combined_results["total_input_tokens"] + combined_results["total_output_tokens"]
-    ) / max_tokens_per_minute
-    minimum_time_rpm = len(prompts) / max_requests_per_minute
-
-    combined_results["minimum_time"] = max(minimum_time_tpm, minimum_time_rpm)
-    combined_results["maximum_time"] = max(maximum_time_tpm, minimum_time_rpm)
-    limiting_factor = None
-    if minimum_time_rpm > maximum_time_tpm:
-        limiting_factor = "requests"
-    elif minimum_time_rpm < minimum_time_tpm:
-        limiting_factor = "tokens"
-    else:
-        limiting_factor = "depends"
-    combined_results["limiting_factor"] = limiting_factor
-
-    return combined_results
-
-
-async def process_api_prompts_async(
-    ids: np.ndarray | list[int],
-    prompts: list[Conversation],
-    models: str | list[str],
-    model_weights: list[float],
-    sampling_params: list[SamplingParams],
-    logprobs: bool,
-    top_logprobs: int | None,
-    max_attempts: int = 5,
-    max_tokens_per_minute: int = 500_000,
-    max_requests_per_minute: int = 1_000,
-    max_concurrent_requests: int = 1_000,
-    request_timeout: int = 30,
-    progress_bar: tqdm | None = None,
-    use_qps: bool = False,
-    verbose: bool = False,
-    tools: list[Tool] | None = None,
-):
-    """Processes API requests in parallel, throttling to stay under rate limits."""
-    # change ids to integer list
-    if isinstance(ids, np.ndarray):
-        ids = ids.tolist()  # pyright: ignore
-
-    # normalize weights
-    model_weights = [w / sum(model_weights) for w in model_weights]
-
-    # constants
-    seconds_to_pause_after_rate_limit_error = 5
-    # seconds_to_sleep_each_loop = 0.003  # so concurrent tasks can run
-    # calculate dynamically so we don't throttle RPM
-    seconds_to_sleep_each_loop = (60.0 * 0.9) / max_requests_per_minute
-
-    # initialize trackers
-    retry_queue = asyncio.Queue()
-    status_tracker = StatusTracker()
-    next_request = None  # variable to hold the next request to call
-
-    # initialize available capacity counts
-    # throttle over a 1 second window rather than minute,
-    # since some models limit RPS rather than RPM
-    if use_qps:
-        available_request_capacity = max_requests_per_minute / 60.0
-        available_token_capacity = max_tokens_per_minute
-    else:
-        available_request_capacity = max_requests_per_minute
-        available_token_capacity = max_tokens_per_minute
-    last_update_time = time.time()
-    last_pbar_update_time = time.time()
-
-    # initialize flags
-    prompts_not_finished = True
-
-    # initials model weights
-    if isinstance(models, str):
-        models = [models]
-    if not isinstance(models, list):
-        raise ValueError("models must be a string or a list of model strings.")
-    for model in models:
-        if model not in registry:
-            raise ValueError(f"Model {model} not found in registry.")
-
-    if model_weights is None:
-        # if not given, spread requests evenly across models
-        model_weights = [1 / len(models) for _ in models]
-    elif len(model_weights) != len(models):
-        raise ValueError(
-            "model_weights must be None or a list of the same length as models."
-        )
-    elif sum(model_weights) != 1:
-        model_weights = [w / sum(model_weights) for w in model_weights]
-
-    prompts_iter = iter(zip(ids, prompts))
-    results: list[APIRequestBase] = []
-    while True:
-        # get next request (if one is not already waiting for capacity)
-        if next_request is None:
-            if not retry_queue.empty():
-                next_request = retry_queue.get_nowait()
-                print(f"Retrying request {next_request.task_id}.")
-            elif prompts_not_finished:
-                try:
-                    # get new request
-                    id, prompt = next(prompts_iter)
-                    # select model
-                    model_idx = np.random.choice(range(len(models)), p=model_weights)
-                    next_request = create_api_request(
-                        task_id=id,
-                        model_name=models[model_idx],
-                        prompt=prompt,
-                        request_timeout=request_timeout,
-                        attempts_left=max_attempts,
-                        status_tracker=status_tracker,
-                        retry_queue=retry_queue,
-                        results_arr=results,
-                        sampling_params=sampling_params[model_idx],
-                        logprobs=logprobs,
-                        top_logprobs=top_logprobs,
-                        pbar=progress_bar,
-                        all_model_names=models,
-                        all_sampling_params=sampling_params,
-                        tools=tools,
-                    )
-                    status_tracker.num_tasks_started += 1
-                    status_tracker.num_tasks_in_progress += 1
-                    results.append(next_request)
-
-                except StopIteration:
-                    prompts_not_finished = False
-                    if verbose:
-                        print("API requests finished, only retries remain.")
-
-        # update available capacity
-        current_time = time.time()
-        seconds_since_update = current_time - last_update_time
-        available_request_capacity = min(
-            available_request_capacity
-            + max_requests_per_minute * seconds_since_update / 60.0,
-            max_requests_per_minute,
-        )
-        available_token_capacity = min(
-            available_token_capacity
-            + max_tokens_per_minute * seconds_since_update / 60.0,
-            max_tokens_per_minute,
-        )
-        last_update_time = current_time
-
-        # if enough capacity available, call API
-        limiting_factor = None
-        if next_request:
-            next_request_tokens = next_request.num_tokens
-            request_available = available_request_capacity >= 1
-            tokens_available = available_token_capacity >= next_request_tokens
-            concurrent_request_available = (
-                status_tracker.num_tasks_in_progress < max_concurrent_requests
-            )
-            if request_available and tokens_available and concurrent_request_available:
-                # update counters
-                available_request_capacity -= 1
-                available_token_capacity -= next_request_tokens
-                next_request.attempts_left -= 1
-
-                # call API
-                asyncio.create_task(next_request.call_api())
-                next_request = None  # reset next_request to empty
-            else:
-                if not request_available:
-                    limiting_factor = "Requests"
-                elif not concurrent_request_available:
-                    limiting_factor = "Concurrent Requests"
-                elif not tokens_available:
-                    limiting_factor = "Tokens"
-
-        # update pbar status
-        if progress_bar and (current_time - last_pbar_update_time > 1):
-            last_pbar_update_time = current_time
-            progress_bar.set_postfix(
-                {
-                    "Token Capacity": f"{available_token_capacity/1_000:.1f}k",
-                    "Req. Capacity": f"{available_request_capacity:.1f}",
-                    "Reqs. in Progress": status_tracker.num_tasks_in_progress,
-                    "Limiting Factor": limiting_factor,
-                }
-            )
-
-        # if all tasks are finished, break
-        if status_tracker.num_tasks_in_progress == 0:
-            break
-
-        # main loop sleeps briefly so concurrent tasks can run
-        await asyncio.sleep(seconds_to_sleep_each_loop)
-
-        # if a rate limit error was hit recently, pause to cool down
-        remaining_seconds_to_pause = max(
-            0,
-            seconds_to_pause_after_rate_limit_error
-            - status_tracker.time_since_rate_limit_error,
-        )
-        if remaining_seconds_to_pause > 0:
-            await asyncio.sleep(remaining_seconds_to_pause)
-            print(f"Pausing {remaining_seconds_to_pause}s to cool down.")
-
-    # after finishing, log final status
-    status_tracker.log_final_status()
-    if verbose:
-        print(
-            f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
+    async def wait_for_batch_job(
+        self, batch_ids: list[str], provider: Literal["anthropic", "openai"]
+    ):
+        return await wait_for_batch_completion_async(
+            batch_ids, provider, poll_interval=30
         )

-    # deduplicate results by id
-    deduplicated = {}
-    for request in results:
-        if request.task_id not in deduplicated:
-            deduplicated[request.task_id] = request.result[-1]
-        else:
-            current_response: APIResponse = deduplicated[request.task_id]
-            # only replace if the current request has no completion and the new one does
-            if (
-                request.result[-1].completion is not None
-                and current_response.completion is None
-            ):
-                deduplicated[request.task_id] = request.result[-1]
-
-    output = list(deduplicated.values())
-    if verbose:
-        print(f"Returning {len(output)} unique results.")

-
+    # def api_prompts_dry_run(
+    #     ids: np.ndarray | list[int],
+    #     prompts: list[Conversation],
+    #     models: str | list[str],
+    #     model_weights: list[float],
+    #     sampling_params: list[SamplingParams],
+    #     max_tokens_per_minute: int = 500_000,
+    #     max_requests_per_minute: int = 1_000,
+    # ):
+    #     """
+    #     Count tokens and estimate costs for a batch of prompts.
+    #     """
+    #     results = []
+    #     for i, prompt in zip(ids, prompts):
+    #         # choose a model
+    #         model_idx = np.random.choice(range(len(models)), p=model_weights)
+    #         model = models[model_idx]
+
+    #         # dry run
+    #         input_tokens, output_tokens, min_cost, max_cost = prompt.dry_run(
+    #             model, sampling_params[model_idx].max_new_tokens
+    #         )
+    #         results.append(
+    #             {
+    #                 "id": i,
+    #                 "input_tokens": input_tokens,
+    #                 "output_tokens": output_tokens,
+    #                 "min_cost": min_cost,
+    #                 "max_cost": max_cost,
+    #             }
+    #         )
+
+    #     combined_results: dict[str, Any] = {
+    #         "total_input_tokens": sum([r["input_tokens"] for r in results]),
+    #         "total_output_tokens": sum([r["output_tokens"] for r in results]),
+    #         "total_min_cost": sum([r["min_cost"] for r in results]),
+    #         "total_max_cost": sum([r["max_cost"] for r in results]),
+    #     }
+    #     minimum_time_tpm = combined_results["total_input_tokens"] / max_tokens_per_minute
+    #     maximum_time_tpm = (
+    #         combined_results["total_input_tokens"] + combined_results["total_output_tokens"]
+    #     ) / max_tokens_per_minute
+    #     minimum_time_rpm = len(prompts) / max_requests_per_minute
+
+    #     combined_results["minimum_time"] = max(minimum_time_tpm, minimum_time_rpm)
+    #     combined_results["maximum_time"] = max(maximum_time_tpm, minimum_time_rpm)
+    #     limiting_factor = None
+    #     if minimum_time_rpm > maximum_time_tpm:
+    #         limiting_factor = "requests"
+    #     elif minimum_time_rpm < minimum_time_tpm:
+    #         limiting_factor = "tokens"
+    #     else:
+    #         limiting_factor = "depends"
+    #     combined_results["limiting_factor"] = limiting_factor
+
+    #     return combined_results