lm-deluge 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of lm-deluge has been flagged as potentially problematic.
- lm_deluge/__init__.py +11 -1
- lm_deluge/agent.py +0 -0
- lm_deluge/api_requests/anthropic.py +90 -58
- lm_deluge/api_requests/base.py +63 -180
- lm_deluge/api_requests/bedrock.py +34 -10
- lm_deluge/api_requests/common.py +2 -1
- lm_deluge/api_requests/mistral.py +6 -15
- lm_deluge/api_requests/openai.py +342 -50
- lm_deluge/api_requests/response.py +153 -0
- lm_deluge/batches.py +498 -0
- lm_deluge/client.py +354 -636
- lm_deluge/computer_use/anthropic_tools.py +75 -0
- lm_deluge/{sampling_params.py → config.py} +12 -4
- lm_deluge/embed.py +17 -11
- lm_deluge/file.py +149 -0
- lm_deluge/models.py +33 -0
- lm_deluge/prompt.py +156 -15
- lm_deluge/rerank.py +18 -12
- lm_deluge/tool.py +11 -1
- lm_deluge/tracker.py +214 -2
- lm_deluge/util/json.py +18 -1
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.14.dist-info}/METADATA +8 -5
- lm_deluge-0.0.14.dist-info/RECORD +44 -0
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.14.dist-info}/WHEEL +1 -1
- lm_deluge-0.0.12.dist-info/RECORD +0 -39
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.14.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.14.dist-info}/top_level.txt +0 -0
lm_deluge/client.py
CHANGED
@@ -1,42 +1,147 @@
-import os
-import requests
 import asyncio
+from typing import Any, Literal, Self, Sequence, overload
+
 import numpy as np
-import time
 import yaml
-from
-from
-
-
-from lm_deluge.
+from pydantic import BaseModel
+from pydantic.functional_validators import model_validator
+
+from lm_deluge.api_requests.openai import stream_chat
+from lm_deluge.batches import (
+    submit_batches_anthropic,
+    submit_batches_oa,
+    wait_for_batch_completion_async,
+)
+from lm_deluge.prompt import CachePattern, Conversation, prompts_to_conversations
 from lm_deluge.tool import Tool

-from .tracker import StatusTracker
-from .sampling_params import SamplingParams
-from .models import registry
-from .api_requests.base import APIResponse, APIRequestBase
 from .api_requests import create_api_request
+from .api_requests.base import APIRequestBase, APIResponse, deduplicate_responses
+from .config import SamplingParams
+from .models import registry
+from .tracker import StatusTracker
+
 # from .cache import LevelDBCache, SqliteCache

+
 # TODO: get completions as they finish, not all at once at the end.
 # relatedly, would be nice to cache them as they finish too.
-
 # TODO: add optional max_input_tokens to client so we can reject long prompts to prevent abuse
+class LLMClient(BaseModel):
+    """
+    LLMClient abstracts all the fixed arguments to process_prompts_async, so you can create it
+    once and use it for more stuff without having to configure all the arguments.
+    Handles models, sampling params for each model, model weights, rate limits, etc.
+    """

+    model_names: list[str] = ["gpt-4.1-mini"]

-
-
-
-
-
-
-
-
-    sampling_params: SamplingParams
-    model_weights: list[float] | Literal["uniform", "
+    def __init__(self, model_name: str | list[str] | None = None, **kwargs):
+        if model_name is not None:
+            kwargs["model_names"] = model_name
+        super().__init__(**kwargs)
+
+    max_requests_per_minute: int = 1_000
+    max_tokens_per_minute: int = 100_000
+    max_concurrent_requests: int = 225
+    sampling_params: list[SamplingParams] = []
+    model_weights: list[float] | Literal["uniform", "dynamic"] = "uniform"
+    max_attempts: int = 5
+    request_timeout: int = 30
+    cache: Any = None
+    # sampling params - if provided, and sampling_params is not,
+    # these override the defaults
+    temperature: float = 0.75
+    top_p: float = 1.0
+    json_mode: bool = False
+    max_new_tokens: int = 512
+    reasoning_effort: Literal["low", "medium", "high", None] = None
     logprobs: bool = False
     top_logprobs: int | None = None
-
+
+    # NEW! Builder methods
+    def with_model(self, model: str):
+        self.model_names = [model]
+        return self
+
+    def with_models(self, models: list[str]):
+        self.model_names = models
+        return self
+
+    def with_limits(
+        self,
+        max_requests_per_minute: int | None = None,
+        max_tokens_per_minute: int | None = None,
+        max_concurrent_requests: int | None = None,
+    ):
+        if max_requests_per_minute:
+            self.max_requests_per_minute = max_requests_per_minute
+        if max_tokens_per_minute:
+            self.max_tokens_per_minute = max_tokens_per_minute
+        if max_concurrent_requests:
+            self.max_concurrent_requests = max_concurrent_requests
+
+    @property
+    def models(self):
+        return self.model_names  # why? idk
+
+    @model_validator(mode="before")
+    @classmethod
+    def fix_lists(cls, data) -> "LLMClient":
+        if isinstance(data.get("model_names"), str):
+            data["model_names"] = [data["model_names"]]
+        if "sampling_params" not in data or len(data.get("sampling_params", [])) == 0:
+            data["sampling_params"] = [
+                SamplingParams(
+                    temperature=data.get("temperature", 0.75),
+                    top_p=data.get("top_p", 1.0),
+                    json_mode=data.get("json_mode", False),
+                    max_new_tokens=data.get("max_new_tokens", 512),
+                    reasoning_effort=data.get("reasoning_effort", None),
+                    logprobs=data.get("logprobs", False),
+                    top_logprobs=data.get("top_logprobs", None),
+                )
+            ]
+        return data
+
+    @model_validator(mode="after")
+    def validate_client(self) -> Self:
+        if isinstance(self.model_names, str):
+            self.model_names = [self.model_names]
+        if any(m not in registry for m in self.model_names):
+            raise ValueError("all model_names must be in registry")
+        if isinstance(self.sampling_params, SamplingParams):
+            self.sampling_params = [self.sampling_params for _ in self.model_names]
+        elif len(self.sampling_params) != len(self.model_names):
+            raise ValueError("# models and # sampling params must match")
+        if self.model_weights == "uniform":
+            self.model_weights = [1 / len(self.model_names) for _ in self.model_names]
+        elif self.model_weights == "dynamic":
+            raise NotImplementedError("dynamic model weights not implemented yet")
+        # normalize weights
+        self.model_weights = [w / sum(self.model_weights) for w in self.model_weights]
+
+        # Validate logprobs settings across all sampling params
+        if self.logprobs or any(sp.logprobs for sp in self.sampling_params):
+            print("Logprobs enabled.")
+            for sp in self.sampling_params:
+                sp.logprobs = True
+                # set top_logprobs for each sp if provided and not set
+                if self.top_logprobs and not sp.top_logprobs:
+                    sp.top_logprobs = self.top_logprobs
+                if sp.top_logprobs and not (0 <= sp.top_logprobs <= 20):
+                    raise ValueError("top_logprobs must be 0-20")
+                if sp.top_logprobs and sp.max_new_tokens > 10:
+                    print(
+                        "WARNING: using top_logprobs can result in very large outputs. consider limiting max_new_tokens."
+                    )
+            if not all(
+                registry[model].get("supports_logprobs") for model in self.models
+            ):
+                raise ValueError(
+                    "logprobs can only be enabled if all models support it."
+                )
+        return self

     @classmethod
     def from_dict(cls, config_dict: dict):
@@ -46,7 +151,7 @@ class ClientConfig:
             ]
         else:
             config_dict["sampling_params"] = SamplingParams(
-                config_dict["sampling_params"]
+                **config_dict["sampling_params"]
             )

         return cls(**config_dict)
@@ -56,184 +161,18 @@ class ClientConfig:
         config_dict = yaml.safe_load(open(file_path))
         return cls.from_dict(config_dict)

-    def to_dict(self):
-        if isinstance(self.sampling_params, list):
-            sp = [x.__dict__ for x in self.sampling_params]
-        else:
-            sp = self.sampling_params.__dict__
-
-        return {
-            "model_names": self.model_names,
-            "max_requests_per_minute": self.max_requests_per_minute,
-            "max_tokens_per_minute": self.max_tokens_per_minute,
-            "max_concurrent_requests": self.max_concurrent_requests,
-            "max_attempts": self.max_attempts,
-            "request_timeout": self.request_timeout,
-            "sampling_params": sp,
-            "model_weights": self.model_weights,
-            "logprobs": self.logprobs,
-            "top_logprobs": self.top_logprobs,
-        }
-
-
-class LLMClient:
-    """
-    LLMClient abstracts all the fixed arguments to process_prompts_async, so you can create it
-    once and use it for more stuff without having to configure all the arguments.
-    Handles models, sampling params for each model, model weights, rate limits, etc.
-    """
-
-    def __init__(
-        self,
-        model_names: list[str],
-        *,
-        max_requests_per_minute: int,
-        max_tokens_per_minute: int,
-        max_concurrent_requests: int,
-        sampling_params: SamplingParams | list[SamplingParams] = SamplingParams(),
-        model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
-        max_attempts: int = 5,
-        request_timeout: int = 30,
-        logprobs: bool = False,
-        top_logprobs: int | None = None,
-        use_qps: bool = False,
-        debug: bool = False,
-        cache: Any = None,
-    ):
-        self.models = model_names
-        if isinstance(sampling_params, SamplingParams):
-            self.sampling_params = [sampling_params for _ in model_names]
-        else:
-            if len(sampling_params) != len(model_names):
-                raise ValueError(
-                    "If sampling_params is a list, it must have the same length as model_names."
-                )
-            self.sampling_params = sampling_params
-        if model_weights == "uniform":
-            self.model_weights = [1 / len(model_names) for _ in model_names]
-        elif model_weights == "rate_limit":
-            rpms = [registry[model]["requests_per_minute"] for model in model_names]
-            self.model_weights = [rpm / sum(rpms) for rpm in rpms]
-        elif sum(model_weights) != 1:
-            self.model_weights = [w / sum(model_weights) for w in model_weights]
-        else:
-            self.model_weights = model_weights
-
-        self.logprobs = logprobs
-        self.top_logprobs = top_logprobs
-
-        # logprobs and top_logprobs are only supported for OpenAI models
-        if self.logprobs:
-            for model in self.models:
-                if registry[model].get("supports_logprobs", False) is False:
-                    raise ValueError(
-                        "logprobs can only be enabled if all models support it."
-                    )
-            if self.top_logprobs is None:
-                self.top_logprobs = 0  # will just return logprob of the chosen token
-            elif self.top_logprobs > 20 or self.top_logprobs < 0:
-                raise ValueError("top_logprobs must be between 0 and 20.")
-            for sp in self.sampling_params:
-                if sp.max_new_tokens > 10:
-                    print(
-                        "WARNING: using logprobs with large max_new_tokens can result in very large outputs. you may want to avoid saving these outputs to disk/db."
-                    )
-                    break
-        else:
-            self.top_logprobs = None
-
-        self.max_requests_per_minute = max_requests_per_minute
-        self.max_tokens_per_minute = max_tokens_per_minute
-        self.max_concurrent_requests = max_concurrent_requests
-        self.max_attempts = max_attempts
-        self.request_timeout = request_timeout
-        self.use_qps = use_qps
-        self.debug = (
-            debug  # UNUSED/DEPRECATED i think? but dont want to break everything
-        )
-        self.cache = cache
-
     @classmethod
-    def
-
-
-
-
-
-
-
-
-
-
-        )
-
-    @classmethod
-    def from_yaml(cls, file_path: str, cache: Any = None):
-        return cls.from_config(ClientConfig.from_yaml(file_path), cache=cache)
-
-    @classmethod
-    def basic(
-        cls,
-        model: str | list[str],
-        max_requests_per_minute: int = 5_000,
-        max_tokens_per_minute: int = 1_000_000,
-        max_concurrent_requests: int = 1_000,
-        temperature: float = 0.75,
-        max_new_tokens: int = 1000,
-        reasoning_effort: Literal[None, "low", "medium", "high"] = None,
-        model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
-        logprobs: bool = False,
-        top_logprobs: int | None = None,
-        max_attempts: int = 5,
-        request_timeout: int = 30,
-        cache: Any = None,
-    ):
-        model_names = model if isinstance(model, list) else [model]
-        return cls(
-            model_names=model_names,
-            max_requests_per_minute=max_requests_per_minute,
-            max_tokens_per_minute=max_tokens_per_minute,
-            max_concurrent_requests=max_concurrent_requests,
-            sampling_params=SamplingParams(
-                temperature=temperature,
-                max_new_tokens=max_new_tokens,
-                reasoning_effort=reasoning_effort,
-            ),
-            logprobs=logprobs,
-            top_logprobs=top_logprobs,
-            model_weights=model_weights,
-            max_attempts=max_attempts,
-            request_timeout=request_timeout,
-            cache=cache,
-        )
-
-    @property
-    def config(self):
-        return ClientConfig(
-            model_names=self.models,
-            model_weights=self.model_weights,
-            max_requests_per_minute=self.max_requests_per_minute,
-            max_tokens_per_minute=self.max_tokens_per_minute,
-            max_concurrent_requests=self.max_concurrent_requests,
-            max_attempts=self.max_attempts,
-            request_timeout=self.request_timeout,
-            sampling_params=self.sampling_params,
-            logprobs=self.logprobs,
-            top_logprobs=self.top_logprobs,
-        )
-
-    @overload
-    async def process_prompts_async(
-        self,
-        prompts: Sequence[str | list[dict] | Conversation],
-        *,
-        return_completions_only: bool,
-        show_progress: bool = ...,
-        dry_run: Literal[True],
-        verbose: bool = ...,
-        tools: list[Tool] | None = ...,
-        cache: CachePattern | None = ...,
-    ) -> dict[str, int]: ...
+    def basic(cls, model: str | list[str], **kwargs):
+        """
+        Doesn't do anything differently now, kept for backwards compat.
+        """
+        kwargs["model_names"] = model
+        return cls(**kwargs)
+
+    def _select_model(self):
+        assert isinstance(self.model_weights, list)
+        model_idx = np.random.choice(range(len(self.models)), p=self.model_weights)
+        return self.models[model_idx], self.sampling_params[model_idx]

     @overload
     async def process_prompts_async(
@@ -242,10 +181,12 @@ class LLMClient:
         *,
         return_completions_only: Literal[True],
         show_progress: bool = ...,
-        dry_run: bool = ...,
-        verbose: bool = ...,
         tools: list[Tool] | None = ...,
         cache: CachePattern | None = ...,
+        computer_use: bool = ...,
+        display_width: int = ...,
+        display_height: int = ...,
+        use_responses_api: bool = ...,
     ) -> list[str | None]: ...

     @overload
@@ -255,10 +196,12 @@ class LLMClient:
         *,
         return_completions_only: Literal[False] = ...,
         show_progress: bool = ...,
-        dry_run: bool = ...,
-        verbose: bool = ...,
         tools: list[Tool] | None = ...,
         cache: CachePattern | None = ...,
+        computer_use: bool = ...,
+        display_width: int = ...,
+        display_height: int = ...,
+        use_responses_api: bool = ...,
     ) -> list[APIResponse | None]: ...

     async def process_prompts_async(
@@ -267,23 +210,15 @@ class LLMClient:
         *,
         return_completions_only: bool = False,
         show_progress: bool = True,
-        dry_run: bool = False,
-        verbose: bool = False,
         tools: list[Tool] | None = None,
         cache: CachePattern | None = None,
+        computer_use: bool = False,
+        display_width: int = 1024,
+        display_height: int = 768,
+        use_responses_api: bool = False,
     ) -> list[APIResponse | None] | list[str | None] | dict[str, int]:
         # if prompts are not Conversations, convert them.
-
-        prompts = [  # type: ignore
-            p
-            if isinstance(p, Conversation)
-            else Conversation.user(p)
-            if isinstance(p, str)
-            else None
-            for p in prompts
-        ]
-        if any(p is None for p in prompts):
-            raise ValueError("All prompts must be valid.")
+        prompts = prompts_to_conversations(prompts)
         ids = np.arange(len(prompts))

         # if using cache, check for cached completions
@@ -298,10 +233,9 @@ class LLMClient:
             ), "Cache hit ids and results must be the same length."
             remaining_ids = np.array([i for i in ids if i not in cache_hit_ids])
             remaining_prompts = [prompts[i] for i in remaining_ids]
-
-
-
-            print(f"Processing {len(remaining_prompts)} prompts.")
+            print(
+                f"{len(cache_hit_ids)} cache hits; {len(remaining_ids)} prompts remaining."
+            )

         else:
             cache_hit_ids = []
@@ -311,48 +245,102 @@ class LLMClient:

         results: list[APIResponse | None] = [None for _ in range(len(prompts))]
         if len(remaining_prompts) > 0:
-            #
-
-
-
-
-
-
-
-
-                prompts,  # type: ignore -- fix later for dry running conversations
-                self.models,
-                self.model_weights,
-                self.sampling_params,
-                max_tokens_per_minute=self.max_tokens_per_minute,
-                max_requests_per_minute=self.max_requests_per_minute,
-            )
-            print("Dry run results:")
-            print(dry_run_results)
-            return dry_run_results
-
-            api_task = asyncio.create_task(
-                process_api_prompts_async(
-                    ids,
-                    prompts,  # type: ignore -- fix later for dry running conversations
-                    self.models,
-                    self.model_weights,
-                    self.sampling_params,
-                    logprobs=self.logprobs,
-                    top_logprobs=self.top_logprobs,
-                    max_attempts=self.max_attempts,
-                    max_tokens_per_minute=self.max_tokens_per_minute,
-                    max_requests_per_minute=self.max_requests_per_minute,
-                    max_concurrent_requests=self.max_concurrent_requests,
-                    request_timeout=self.request_timeout,
-                    progress_bar=pbar,
-                    use_qps=self.use_qps,
-                    verbose=verbose,
-                    tools=tools,
-                    cache=cache,
-                )
+            # Create StatusTracker with integrated progress bar
+            tracker = StatusTracker(
+                max_requests_per_minute=self.max_requests_per_minute,
+                max_tokens_per_minute=self.max_tokens_per_minute,
+                max_concurrent_requests=self.max_concurrent_requests,
+                use_progress_bar=show_progress,
+                progress_bar_total=len(prompts),
+                progress_bar_disable=not show_progress,
+                use_rich=show_progress,  # Disable Rich if progress is disabled
             )
-
+
+            # Initialize progress bar and update with cache hits
+            tracker.init_progress_bar()
+            if len(cache_hit_ids) > 0:
+                tracker.update_pbar(len(cache_hit_ids))
+
+            if isinstance(ids, np.ndarray):
+                ids = ids.tolist()  # pyright: ignore
+
+            # calculate dynamically so we don't throttle RPM
+            seconds_to_sleep_each_loop = (60.0 * 0.9) / tracker.max_requests_per_minute
+            next_request = None  # variable to hold the next request to call
+            prompts_not_finished = True
+            prompts_iter = iter(zip(ids, prompts))
+            requests: list[APIRequestBase] = []
+            assert tracker.retry_queue, "retry queue not initialized"
+            while True:
+                # get next request (if one is not already waiting for capacity)
+                retry_request = False
+                if next_request is None:
+                    if not tracker.retry_queue.empty():
+                        next_request = tracker.retry_queue.get_nowait()
+                        retry_request = True
+                        print(f"Retrying request {next_request.task_id}.")
+                    elif prompts_not_finished:
+                        try:
+                            # get new request
+                            id, prompt = next(prompts_iter)
+                            # select model
+                            model, sampling_params = self._select_model()
+
+                            next_request = create_api_request(
+                                task_id=id,
+                                model_name=model,
+                                prompt=prompt,  # type: ignore
+                                request_timeout=self.request_timeout,
+                                attempts_left=self.max_attempts,
+                                status_tracker=tracker,
+                                results_arr=requests,
+                                sampling_params=sampling_params,
+                                all_model_names=self.models,
+                                all_sampling_params=self.sampling_params,
+                                tools=tools,
+                                cache=cache,
+                                computer_use=computer_use,
+                                display_width=display_width,
+                                display_height=display_height,
+                                use_responses_api=use_responses_api,
+                            )
+                            requests.append(next_request)
+
+                        except StopIteration:
+                            prompts_not_finished = False
+                            # print("API requests finished, only retries remain.")
+
+                # update available capacity
+                tracker.update_capacity()
+
+                # if enough capacity available, call API
+                if next_request:
+                    next_request_tokens = next_request.num_tokens
+                    if tracker.check_capacity(next_request_tokens, retry=retry_request):
+                        tracker.set_limiting_factor(None)
+                        # call API (attempts_left will be decremented in handle_error if it fails)
+                        asyncio.create_task(next_request.call_api())
+                        next_request = None  # reset next_request to empty
+                # update pbar status
+                tracker.update_pbar()
+
+                # if all tasks are finished, break
+                if tracker.num_tasks_in_progress == 0:
+                    break
+
+                # main loop sleeps briefly so concurrent tasks can run
+                await asyncio.sleep(seconds_to_sleep_each_loop)
+
+                # if a rate limit error was hit recently, pause to cool down
+                if tracker.seconds_to_pause > 0:
+                    await asyncio.sleep(tracker.seconds_to_pause)
+                    print(f"Pausing {tracker.seconds_to_pause}s to cool down.")
+
+            # after finishing, log final status
+            tracker.log_final_status()
+
+            # deduplicate results by id
+            api_results = deduplicate_responses(requests)
             for res in api_results:
                 results[res.id] = res
                 # set to cache if result has a completion
@@ -375,8 +363,6 @@ class LLMClient:
         *,
         return_completions_only: bool = False,
         show_progress=True,
-        dry_run: bool = False,
-        verbose: bool = False,
         tools: list[Tool] | None = None,
         cache: CachePattern | None = None,
     ):
@@ -385,387 +371,119 @@ class LLMClient:
                 prompts=prompts,
                 return_completions_only=return_completions_only,
                 show_progress=show_progress,
-                dry_run=dry_run,
-                verbose=verbose,
                 tools=tools,
                 cache=cache,
             )
         )

-    def
-
-
-
-
-
-
-
-
-
-        if api_key is None:
-            raise ValueError("OPENAI_API_KEY environment variable must be set.")
-        url = "https://api.openai.com/v1/files"
-        files = {
-            "file": (
-                "openai_requests_temp.jsonl",
-                open("openai_requests_temp.jsonl", "rb"),
-            ),
-        }
-        data = {
-            "purpose": "batch",
-        }
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-        }
-        response = requests.post(url, files=files, data=data, headers=headers)
-
-        file_id = None
-        if response.status_code == 200:
-            print("File uploaded successfully")
-            data = response.json()
-            file_id = data["id"]
-
-        else:
-            print("File upload failed")
-            raise ValueError(f"Error uploading file: {response.text}")
-
-        url = "https://api.openai.com/v1/batches"
-        data = {
-            "input_file_id": file_id,
-            "endpoint": "/v1/chat/completions",
-            "completion_window": "24h",
-        }
-        response = requests.post(url, json=data, headers=headers)
-
-        batch_id = None
-        if response.status_code == 200:
-            data = response.json()
-            batch_id = data["id"]
-            print("Batch job started successfully: id = ", batch_id)
-            return batch_id
-        else:
-            print("Batch job failed to start")
-            raise ValueError(f"Error starting batch job: {response.text}")
+    async def stream(self, prompt: str | Conversation, tools: list[Tool] | None = None):
+        model, sampling_params = self._select_model()
+        if isinstance(prompt, str):
+            prompt = Conversation.user(prompt)
+        async for item in stream_chat(model, prompt, sampling_params, tools, None):
+            if isinstance(item, str):
+                print(item, end="", flush=True)
+            else:
+                # final item
+                return item

-    def submit_batch_job(
-
+    async def submit_batch_job(
+        self,
+        prompts: Sequence[str | list[dict] | Conversation],
+        *,
+        tools: list[Tool] | None = None,
+        cache: CachePattern | None = None,
+    ):
+        """Submit a batch job asynchronously, automatically detecting the provider based on model.
+
+        Args:
+            prompts: List of prompts to process
+            wait_for_completion: If True, poll until completion and return results
+            poll_interval: Seconds to wait between status checks when polling
+            tools: Optional tools to include in requests (Anthropic only)
+            cache: Optional cache pattern for requests (Anthropic only)
+
+        Returns: list of batch_ids
+        """
+        assert isinstance(self.sampling_params, list)
         if len(self.models) != 1:
             raise ValueError("Batch jobs can only be submitted with a single model.")
         model = self.models[0]
-
-
-
-
-
-
-
-
-
-
-            for p in prompts
-        ]
-        if any(p is None for p in prompts):
-            raise ValueError("All prompts must be valid.")
-        ids = np.arange(len(prompts))
-
-        # create file with requests to send to batch api
-        batch_requests = []
-        for id, prompt in zip(ids, prompts):
-            assert isinstance(prompt, Conversation)
-            batch_requests.append(
-                {
-                    "custom_id": str(id),
-                    "method": "POST",
-                    "url": "/v1/chat/completions",
-                    "body": {
-                        "model": self.models[0],
-                        "messages": prompt.to_openai(),
-                        "max_tokens": self.sampling_params[0].max_new_tokens,
-                        "temperature": self.sampling_params[0].temperature,
-                        "top_p": self.sampling_params[0].top_p,
-                    },
-                }
-            )
-
-        # since the api only accepts up to 50,000 requests per batch job, we chunk into 50k chunks
-        BATCH_SIZE = 50_000
-        batches = [
-            batch_requests[i : i + BATCH_SIZE]
-            for i in range(0, len(batch_requests), BATCH_SIZE)
-        ]
-        batch_ids = []
-        for batch in tqdm(batches):
-            batch_id = self._submit_one_batch(batch)
-            batch_ids.append(batch_id)
-
-        print(f"Submitted {len(batches)} batch jobs.")
-        return batch_ids
-
-
-def api_prompts_dry_run(
-    ids: np.ndarray | list[int],
-    prompts: list[Conversation],
-    models: str | list[str],
-    model_weights: list[float],
-    sampling_params: list[SamplingParams],
-    max_tokens_per_minute: int = 500_000,
-    max_requests_per_minute: int = 1_000,
-):
-    """
-    Count tokens and estimate costs for a batch of prompts.
-    """
-    results = []
-    for i, prompt in zip(ids, prompts):
-        # choose a model
-        model_idx = np.random.choice(range(len(models)), p=model_weights)
-        model = models[model_idx]
-
-        # dry run
-        input_tokens, output_tokens, min_cost, max_cost = prompt.dry_run(
-            model, sampling_params[model_idx].max_new_tokens
-        )
-        results.append(
-            {
-                "id": i,
-                "input_tokens": input_tokens,
-                "output_tokens": output_tokens,
-                "min_cost": min_cost,
-                "max_cost": max_cost,
-            }
-        )
-
-    combined_results: dict[str, Any] = {
-        "total_input_tokens": sum([r["input_tokens"] for r in results]),
-        "total_output_tokens": sum([r["output_tokens"] for r in results]),
-        "total_min_cost": sum([r["min_cost"] for r in results]),
-        "total_max_cost": sum([r["max_cost"] for r in results]),
-    }
-    minimum_time_tpm = combined_results["total_input_tokens"] / max_tokens_per_minute
-    maximum_time_tpm = (
-        combined_results["total_input_tokens"] + combined_results["total_output_tokens"]
-    ) / max_tokens_per_minute
-    minimum_time_rpm = len(prompts) / max_requests_per_minute
-
-    combined_results["minimum_time"] = max(minimum_time_tpm, minimum_time_rpm)
-    combined_results["maximum_time"] = max(maximum_time_tpm, minimum_time_rpm)
-    limiting_factor = None
-    if minimum_time_rpm > maximum_time_tpm:
-        limiting_factor = "requests"
-    elif minimum_time_rpm < minimum_time_tpm:
-        limiting_factor = "tokens"
-    else:
-        limiting_factor = "depends"
-    combined_results["limiting_factor"] = limiting_factor
-
-    return combined_results
-
-
-async def process_api_prompts_async(
-    ids: np.ndarray | list[int],
-    prompts: list[Conversation],
-    models: str | list[str],
-    model_weights: list[float],
-    sampling_params: list[SamplingParams],
-    logprobs: bool,
-    top_logprobs: int | None,
-    max_attempts: int = 5,
-    max_tokens_per_minute: int = 500_000,
-    max_requests_per_minute: int = 1_000,
-    max_concurrent_requests: int = 1_000,
-    request_timeout: int = 30,
-    progress_bar: tqdm | None = None,
-    use_qps: bool = False,
-    verbose: bool = False,
-    tools: list[Tool] | None = None,
-    cache: CachePattern | None = None,
-):
-    """Processes API requests in parallel, throttling to stay under rate limits."""
-    # change ids to integer list
-    if isinstance(ids, np.ndarray):
-        ids = ids.tolist()  # pyright: ignore
-
-    # normalize weights
-    model_weights = [w / sum(model_weights) for w in model_weights]
-
-    # constants
-    seconds_to_pause_after_rate_limit_error = 5
-    # seconds_to_sleep_each_loop = 0.003  # so concurrent tasks can run
-    # calculate dynamically so we don't throttle RPM
-    seconds_to_sleep_each_loop = (60.0 * 0.9) / max_requests_per_minute
-
-    # initialize trackers
-    retry_queue = asyncio.Queue()
-    status_tracker = StatusTracker()
-    next_request = None  # variable to hold the next request to call
-
-    # initialize available capacity counts
-    # throttle over a 1 second window rather than minute,
-    # since some models limit RPS rather than RPM
-    if use_qps:
-        available_request_capacity = max_requests_per_minute / 60.0
-        available_token_capacity = max_tokens_per_minute
-    else:
-        available_request_capacity = max_requests_per_minute
-        available_token_capacity = max_tokens_per_minute
-    last_update_time = time.time()
-    last_pbar_update_time = time.time()
-
-    # initialize flags
-    prompts_not_finished = True
-
-    # initials model weights
-    if isinstance(models, str):
-        models = [models]
-    if not isinstance(models, list):
-        raise ValueError("models must be a string or a list of model strings.")
-    for model in models:
-        if model not in registry:
-            raise ValueError(f"Model {model} not found in registry.")
-
-    if model_weights is None:
-        # if not given, spread requests evenly across models
-        model_weights = [1 / len(models) for _ in models]
-    elif len(model_weights) != len(models):
-        raise ValueError(
-            "model_weights must be None or a list of the same length as models."
-        )
-    elif sum(model_weights) != 1:
-        model_weights = [w / sum(model_weights) for w in model_weights]
-
-    prompts_iter = iter(zip(ids, prompts))
-    results: list[APIRequestBase] = []
-    while True:
-        # get next request (if one is not already waiting for capacity)
-        if next_request is None:
-            if not retry_queue.empty():
-                next_request = retry_queue.get_nowait()
-                print(f"Retrying request {next_request.task_id}.")
-            elif prompts_not_finished:
-                try:
-                    # get new request
-                    id, prompt = next(prompts_iter)
-                    # select model
-                    model_idx = np.random.choice(range(len(models)), p=model_weights)
-                    next_request = create_api_request(
-                        task_id=id,
-                        model_name=models[model_idx],
-                        prompt=prompt,
-                        request_timeout=request_timeout,
-                        attempts_left=max_attempts,
-                        status_tracker=status_tracker,
-                        retry_queue=retry_queue,
-                        results_arr=results,
-                        sampling_params=sampling_params[model_idx],
-                        logprobs=logprobs,
-                        top_logprobs=top_logprobs,
-                        pbar=progress_bar,
-                        all_model_names=models,
-                        all_sampling_params=sampling_params,
-                        tools=tools,
-                        cache=cache,
-                    )
-                    status_tracker.num_tasks_started += 1
-                    status_tracker.num_tasks_in_progress += 1
-                    results.append(next_request)
-
-                except StopIteration:
-                    prompts_not_finished = False
-                    if verbose:
-                        print("API requests finished, only retries remain.")
-
-        # update available capacity
-        current_time = time.time()
-        seconds_since_update = current_time - last_update_time
-        available_request_capacity = min(
-            available_request_capacity
-            + max_requests_per_minute * seconds_since_update / 60.0,
-            max_requests_per_minute,
-        )
-        available_token_capacity = min(
-            available_token_capacity
-            + max_tokens_per_minute * seconds_since_update / 60.0,
-            max_tokens_per_minute,
-        )
-        last_update_time = current_time
-
-        # if enough capacity available, call API
-        limiting_factor = None
-        if next_request:
-            next_request_tokens = next_request.num_tokens
-            request_available = available_request_capacity >= 1
-            tokens_available = available_token_capacity >= next_request_tokens
-            concurrent_request_available = (
-                status_tracker.num_tasks_in_progress < max_concurrent_requests
-            )
-            if request_available and tokens_available and concurrent_request_available:
-                # update counters
-                available_request_capacity -= 1
-                available_token_capacity -= next_request_tokens
-                next_request.attempts_left -= 1
-
-                # call API
-                asyncio.create_task(next_request.call_api())
-                next_request = None  # reset next_request to empty
-            else:
-                if not request_available:
-                    limiting_factor = "Requests"
-                elif not concurrent_request_available:
-                    limiting_factor = "Concurrent Requests"
-                elif not tokens_available:
-                    limiting_factor = "Tokens"
-
-        # update pbar status
-        if progress_bar and (current_time - last_pbar_update_time > 1):
-            last_pbar_update_time = current_time
-            progress_bar.set_postfix(
-                {
-                    "Token Capacity": f"{available_token_capacity/1_000:.1f}k",
-                    "Req. Capacity": f"{available_request_capacity:.1f}",
-                    "Reqs. in Progress": status_tracker.num_tasks_in_progress,
-                    "Limiting Factor": limiting_factor,
-                }
+        api_spec = registry[model].get("api_spec", None)
+
+        if api_spec == "openai":
+            return await submit_batches_oa(model, self.sampling_params[0], prompts)
+        elif api_spec == "anthropic":
+            return await submit_batches_anthropic(
+                model,
+                self.sampling_params[0],
+                prompts,
+                cache=cache,
             )
+        else:
+            raise ValueError(f"Batch processing not supported for API spec: {api_spec}")

-
-
-
-
-
-        await asyncio.sleep(seconds_to_sleep_each_loop)
-
-        # if a rate limit error was hit recently, pause to cool down
-        remaining_seconds_to_pause = max(
-            0,
-            seconds_to_pause_after_rate_limit_error
-            - status_tracker.time_since_rate_limit_error,
-        )
-        if remaining_seconds_to_pause > 0:
-            await asyncio.sleep(remaining_seconds_to_pause)
-            print(f"Pausing {remaining_seconds_to_pause}s to cool down.")
-
-    # after finishing, log final status
-    status_tracker.log_final_status()
-    if verbose:
-        print(
-            f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
+    async def wait_for_batch_job(
+        self, batch_ids: list[str], provider: Literal["anthropic", "openai"]
+    ):
+        return await wait_for_batch_completion_async(
+            batch_ids, provider, poll_interval=30
         )

-    # deduplicate results by id
-    deduplicated = {}
-    for request in results:
-        if request.task_id not in deduplicated:
-            deduplicated[request.task_id] = request.result[-1]
-        else:
-            current_response: APIResponse = deduplicated[request.task_id]
-            # only replace if the current request has no completion and the new one does
-            if (
-                request.result[-1].completion is not None
-                and current_response.completion is None
-            ):
-                deduplicated[request.task_id] = request.result[-1]
-
-    output = list(deduplicated.values())
-    if verbose:
-        print(f"Returning {len(output)} unique results.")

-
+# def api_prompts_dry_run(
+#     ids: np.ndarray | list[int],
+#     prompts: list[Conversation],
+#     models: str | list[str],
+#     model_weights: list[float],
+#     sampling_params: list[SamplingParams],
+#     max_tokens_per_minute: int = 500_000,
+#     max_requests_per_minute: int = 1_000,
+# ):
+#     """
+#     Count tokens and estimate costs for a batch of prompts.
+#     """
+#     results = []
+#     for i, prompt in zip(ids, prompts):
+#         # choose a model
+#         model_idx = np.random.choice(range(len(models)), p=model_weights)
+#         model = models[model_idx]
+
+#         # dry run
+#         input_tokens, output_tokens, min_cost, max_cost = prompt.dry_run(
+#             model, sampling_params[model_idx].max_new_tokens
+#         )
+#         results.append(
+#             {
+#                 "id": i,
+#                 "input_tokens": input_tokens,
+#                 "output_tokens": output_tokens,
+#                 "min_cost": min_cost,
+#                 "max_cost": max_cost,
+#             }
+#         )
+
+#     combined_results: dict[str, Any] = {
+#         "total_input_tokens": sum([r["input_tokens"] for r in results]),
+#         "total_output_tokens": sum([r["output_tokens"] for r in results]),
+#         "total_min_cost": sum([r["min_cost"] for r in results]),
+#         "total_max_cost": sum([r["max_cost"] for r in results]),
+#     }
+#     minimum_time_tpm = combined_results["total_input_tokens"] / max_tokens_per_minute
+#     maximum_time_tpm = (
+#         combined_results["total_input_tokens"] + combined_results["total_output_tokens"]
+#     ) / max_tokens_per_minute
+#     minimum_time_rpm = len(prompts) / max_requests_per_minute
+
+#     combined_results["minimum_time"] = max(minimum_time_tpm, minimum_time_rpm)
+#     combined_results["maximum_time"] = max(maximum_time_tpm, minimum_time_rpm)
+#     limiting_factor = None
+#     if minimum_time_rpm > maximum_time_tpm:
+#         limiting_factor = "requests"
+#     elif minimum_time_rpm < minimum_time_tpm:
+#         limiting_factor = "tokens"
+#     else:
+#         limiting_factor = "depends"
+#     combined_results["limiting_factor"] = limiting_factor
+
+#     return combined_results