lm-deluge 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lm-deluge might be problematic.

lm_deluge/client.py CHANGED
@@ -1,42 +1,140 @@
- import os
- import requests
  import asyncio
+ from typing import Any, Literal, Self, Sequence, overload
+
  import numpy as np
- import time
  import yaml
- from dataclasses import dataclass
- from typing import Sequence, overload, Literal, Any
- from tqdm.auto import tqdm
-
- from lm_deluge.prompt import Conversation, CachePattern
+ from pydantic import BaseModel
+ from pydantic.functional_validators import model_validator
+
+ from lm_deluge.batches import (
+ submit_batches_anthropic,
+ submit_batches_oa,
+ wait_for_batch_completion_async,
+ )
+ from lm_deluge.prompt import CachePattern, Conversation, prompts_to_conversations
  from lm_deluge.tool import Tool

- from .tracker import StatusTracker
- from .sampling_params import SamplingParams
- from .models import registry
- from .api_requests.base import APIResponse, APIRequestBase
  from .api_requests import create_api_request
+ from .api_requests.base import APIRequestBase, APIResponse, deduplicate_responses
+ from .config import SamplingParams
+ from .models import registry
+ from .tracker import StatusTracker
+
  # from .cache import LevelDBCache, SqliteCache

+
  # TODO: get completions as they finish, not all at once at the end.
  # relatedly, would be nice to cache them as they finish too.
-
  # TODO: add optional max_input_tokens to client so we can reject long prompts to prevent abuse
+ class LLMClient(BaseModel):
+ """
+ LLMClient abstracts all the fixed arguments to process_prompts_async, so you can create it
+ once and use it for more stuff without having to configure all the arguments.
+ Handles models, sampling params for each model, model weights, rate limits, etc.
+ """

-
- @dataclass
- class ClientConfig:
- model_names: list[str]
- max_requests_per_minute: int
- max_tokens_per_minute: int
- max_concurrent_requests: int
- max_attempts: int
- request_timeout: int
- sampling_params: SamplingParams | list[SamplingParams]
- model_weights: list[float] | Literal["uniform", "rate_limit"]
+ model_names: list[str] = ["gpt-4.1-mini"]
+ max_requests_per_minute: int = 1_000
+ max_tokens_per_minute: int = 100_000
+ max_concurrent_requests: int = 225
+ sampling_params: list[SamplingParams] = []
+ model_weights: list[float] | Literal["uniform", "dynamic"] = "uniform"
+ max_attempts: int = 5
+ request_timeout: int = 30
+ cache: Any = None
+ # sampling params - if provided, and sampling_params is not,
+ # these override the defaults
+ temperature: float = 0.75
+ top_p: float = 1.0
+ json_mode: bool = False
+ max_new_tokens: int = 512
+ reasoning_effort: Literal["low", "medium", "high", None] = None
  logprobs: bool = False
  top_logprobs: int | None = None
- cache: Any = None
+
+ # NEW! Builder methods
+ def with_model(self, model: str):
+ self.model_names = [model]
+ return self
+
+ def with_models(self, models: list[str]):
+ self.model_names = models
+ return self
+
+ def with_limits(
+ self,
+ max_requests_per_minute: int | None = None,
+ max_tokens_per_minute: int | None = None,
+ max_concurrent_requests: int | None = None,
+ ):
+ if max_requests_per_minute:
+ self.max_requests_per_minute = max_requests_per_minute
+ if max_tokens_per_minute:
+ self.max_tokens_per_minute = max_tokens_per_minute
+ if max_concurrent_requests:
+ self.max_concurrent_requests = max_concurrent_requests
+
+ @property
+ def models(self):
+ return self.model_names # why? idk
+
+ @model_validator(mode="before")
+ @classmethod
+ def fix_lists(cls, data) -> "LLMClient":
+ if isinstance(data["model_names"], str):
+ data["model_names"] = [data["model_names"]]
+ if "sampling_params" not in data or len(data.get("sampling_params", [])) == 0:
+ data["sampling_params"] = [
+ SamplingParams(
+ temperature=data.get("temperature", 0.75),
+ top_p=data.get("top_p", 1.0),
+ json_mode=data.get("json_mode", False),
+ max_new_tokens=data.get("max_new_tokens", 512),
+ reasoning_effort=data.get("reasoning_effort", None),
+ logprobs=data.get("logprobs", False),
+ top_logprobs=data.get("top_logprobs", None),
+ )
+ ]
+ return data
+
+ @model_validator(mode="after")
+ def validate_client(self) -> Self:
+ if isinstance(self.model_names, str):
+ self.model_names = [self.model_names]
+ if any(m not in registry for m in self.model_names):
+ raise ValueError("all model_names must be in registry")
+ if isinstance(self.sampling_params, SamplingParams):
+ self.sampling_params = [self.sampling_params for _ in self.model_names]
+ elif len(self.sampling_params) != len(self.model_names):
+ raise ValueError("# models and # sampling params must match")
+ if self.model_weights == "uniform":
+ self.model_weights = [1 / len(self.model_names) for _ in self.model_names]
+ elif self.model_weights == "dynamic":
+ raise NotImplementedError("dynamic model weights not implemented yet")
+ # normalize weights
+ self.model_weights = [w / sum(self.model_weights) for w in self.model_weights]
+
+ # Validate logprobs settings across all sampling params
+ if self.logprobs or any(sp.logprobs for sp in self.sampling_params):
+ print("Logprobs enabled.")
+ for sp in self.sampling_params:
+ sp.logprobs = True
+ # set top_logprobs for each sp if provided and not set
+ if self.top_logprobs and not sp.top_logprobs:
+ sp.top_logprobs = self.top_logprobs
+ if sp.top_logprobs and not (0 <= sp.top_logprobs <= 20):
+ raise ValueError("top_logprobs must be 0-20")
+ if sp.top_logprobs and sp.max_new_tokens > 10:
+ print(
+ "WARNING: using top_logprobs can result in very large outputs. consider limiting max_new_tokens."
+ )
+ if not all(
+ registry[model].get("supports_logprobs") for model in self.models
+ ):
+ raise ValueError(
+ "logprobs can only be enabled if all models support it."
+ )
+ return self

  @classmethod
  def from_dict(cls, config_dict: dict):
@@ -46,7 +144,7 @@ class ClientConfig:
  ]
  else:
  config_dict["sampling_params"] = SamplingParams(
- config_dict["sampling_params"]
+ **config_dict["sampling_params"]
  )

  return cls(**config_dict)
@@ -56,184 +154,13 @@ class ClientConfig:
  config_dict = yaml.safe_load(open(file_path))
  return cls.from_dict(config_dict)

- def to_dict(self):
- if isinstance(self.sampling_params, list):
- sp = [x.__dict__ for x in self.sampling_params]
- else:
- sp = self.sampling_params.__dict__
-
- return {
- "model_names": self.model_names,
- "max_requests_per_minute": self.max_requests_per_minute,
- "max_tokens_per_minute": self.max_tokens_per_minute,
- "max_concurrent_requests": self.max_concurrent_requests,
- "max_attempts": self.max_attempts,
- "request_timeout": self.request_timeout,
- "sampling_params": sp,
- "model_weights": self.model_weights,
- "logprobs": self.logprobs,
- "top_logprobs": self.top_logprobs,
- }
-
-
- class LLMClient:
- """
- LLMClient abstracts all the fixed arguments to process_prompts_async, so you can create it
- once and use it for more stuff without having to configure all the arguments.
- Handles models, sampling params for each model, model weights, rate limits, etc.
- """
-
- def __init__(
- self,
- model_names: list[str],
- *,
- max_requests_per_minute: int,
- max_tokens_per_minute: int,
- max_concurrent_requests: int,
- sampling_params: SamplingParams | list[SamplingParams] = SamplingParams(),
- model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
- max_attempts: int = 5,
- request_timeout: int = 30,
- logprobs: bool = False,
- top_logprobs: int | None = None,
- use_qps: bool = False,
- debug: bool = False,
- cache: Any = None,
- ):
- self.models = model_names
- if isinstance(sampling_params, SamplingParams):
- self.sampling_params = [sampling_params for _ in model_names]
- else:
- if len(sampling_params) != len(model_names):
- raise ValueError(
- "If sampling_params is a list, it must have the same length as model_names."
- )
- self.sampling_params = sampling_params
- if model_weights == "uniform":
- self.model_weights = [1 / len(model_names) for _ in model_names]
- elif model_weights == "rate_limit":
- rpms = [registry[model]["requests_per_minute"] for model in model_names]
- self.model_weights = [rpm / sum(rpms) for rpm in rpms]
- elif sum(model_weights) != 1:
- self.model_weights = [w / sum(model_weights) for w in model_weights]
- else:
- self.model_weights = model_weights
-
- self.logprobs = logprobs
- self.top_logprobs = top_logprobs
-
- # logprobs and top_logprobs are only supported for OpenAI models
- if self.logprobs:
- for model in self.models:
- if registry[model].get("supports_logprobs", False) is False:
- raise ValueError(
- "logprobs can only be enabled if all models support it."
- )
- if self.top_logprobs is None:
- self.top_logprobs = 0 # will just return logprob of the chosen token
- elif self.top_logprobs > 20 or self.top_logprobs < 0:
- raise ValueError("top_logprobs must be between 0 and 20.")
- for sp in self.sampling_params:
- if sp.max_new_tokens > 10:
- print(
- "WARNING: using logprobs with large max_new_tokens can result in very large outputs. you may want to avoid saving these outputs to disk/db."
- )
- break
- else:
- self.top_logprobs = None
-
- self.max_requests_per_minute = max_requests_per_minute
- self.max_tokens_per_minute = max_tokens_per_minute
- self.max_concurrent_requests = max_concurrent_requests
- self.max_attempts = max_attempts
- self.request_timeout = request_timeout
- self.use_qps = use_qps
- self.debug = (
- debug # UNUSED/DEPRECATED i think? but dont want to break everything
- )
- self.cache = cache
-
- @classmethod
- def from_config(cls, config: ClientConfig, cache: Any = None):
- return cls(
- model_names=config.model_names,
- max_requests_per_minute=config.max_requests_per_minute,
- max_tokens_per_minute=config.max_tokens_per_minute,
- max_concurrent_requests=config.max_concurrent_requests,
- sampling_params=config.sampling_params,
- model_weights=config.model_weights,
- max_attempts=config.max_attempts,
- request_timeout=config.request_timeout,
- cache=cache,
- )
-
  @classmethod
- def from_yaml(cls, file_path: str, cache: Any = None):
- return cls.from_config(ClientConfig.from_yaml(file_path), cache=cache)
-
- @classmethod
- def basic(
- cls,
- model: str | list[str],
- max_requests_per_minute: int = 5_000,
- max_tokens_per_minute: int = 1_000_000,
- max_concurrent_requests: int = 1_000,
- temperature: float = 0.75,
- max_new_tokens: int = 1000,
- reasoning_effort: Literal[None, "low", "medium", "high"] = None,
- model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
- logprobs: bool = False,
- top_logprobs: int | None = None,
- max_attempts: int = 5,
- request_timeout: int = 30,
- cache: Any = None,
- ):
- model_names = model if isinstance(model, list) else [model]
- return cls(
- model_names=model_names,
- max_requests_per_minute=max_requests_per_minute,
- max_tokens_per_minute=max_tokens_per_minute,
- max_concurrent_requests=max_concurrent_requests,
- sampling_params=SamplingParams(
- temperature=temperature,
- max_new_tokens=max_new_tokens,
- reasoning_effort=reasoning_effort,
- ),
- logprobs=logprobs,
- top_logprobs=top_logprobs,
- model_weights=model_weights,
- max_attempts=max_attempts,
- request_timeout=request_timeout,
- cache=cache,
- )
-
- @property
- def config(self):
- return ClientConfig(
- model_names=self.models,
- model_weights=self.model_weights,
- max_requests_per_minute=self.max_requests_per_minute,
- max_tokens_per_minute=self.max_tokens_per_minute,
- max_concurrent_requests=self.max_concurrent_requests,
- max_attempts=self.max_attempts,
- request_timeout=self.request_timeout,
- sampling_params=self.sampling_params,
- logprobs=self.logprobs,
- top_logprobs=self.top_logprobs,
- )
-
- @overload
- async def process_prompts_async(
- self,
- prompts: Sequence[str | list[dict] | Conversation],
- *,
- return_completions_only: bool,
- show_progress: bool = ...,
- dry_run: Literal[True],
- verbose: bool = ...,
- tools: list[Tool] | None = ...,
- cache: CachePattern | None = ...,
- ) -> dict[str, int]: ...
+ def basic(cls, model: str | list[str], **kwargs):
+ """
+ Doesn't do anything differently now, kept for backwards compat.
+ """
+ kwargs["model_names"] = model
+ return cls(**kwargs)

  @overload
  async def process_prompts_async(
@@ -242,10 +169,12 @@ class LLMClient:
  *,
  return_completions_only: Literal[True],
  show_progress: bool = ...,
- dry_run: bool = ...,
- verbose: bool = ...,
  tools: list[Tool] | None = ...,
  cache: CachePattern | None = ...,
+ computer_use: bool = ...,
+ display_width: int = ...,
+ display_height: int = ...,
+ use_responses_api: bool = ...,
  ) -> list[str | None]: ...

  @overload
@@ -255,10 +184,12 @@ class LLMClient:
  *,
  return_completions_only: Literal[False] = ...,
  show_progress: bool = ...,
- dry_run: bool = ...,
- verbose: bool = ...,
  tools: list[Tool] | None = ...,
  cache: CachePattern | None = ...,
+ computer_use: bool = ...,
+ display_width: int = ...,
+ display_height: int = ...,
+ use_responses_api: bool = ...,
  ) -> list[APIResponse | None]: ...

  async def process_prompts_async(
@@ -267,23 +198,15 @@ class LLMClient:
  *,
  return_completions_only: bool = False,
  show_progress: bool = True,
- dry_run: bool = False,
- verbose: bool = False,
  tools: list[Tool] | None = None,
  cache: CachePattern | None = None,
+ computer_use: bool = False,
+ display_width: int = 1024,
+ display_height: int = 768,
+ use_responses_api: bool = False,
  ) -> list[APIResponse | None] | list[str | None] | dict[str, int]:
  # if prompts are not Conversations, convert them.
- # can only handle strings for now
- prompts = [ # type: ignore
- p
- if isinstance(p, Conversation)
- else Conversation.user(p)
- if isinstance(p, str)
- else None
- for p in prompts
- ]
- if any(p is None for p in prompts):
- raise ValueError("All prompts must be valid.")
+ prompts = prompts_to_conversations(prompts)
  ids = np.arange(len(prompts))

  # if using cache, check for cached completions
@@ -298,10 +221,9 @@ class LLMClient:
  ), "Cache hit ids and results must be the same length."
  remaining_ids = np.array([i for i in ids if i not in cache_hit_ids])
  remaining_prompts = [prompts[i] for i in remaining_ids]
- if verbose:
- print(f"{len(cache_hit_ids)} cache hits from previous completions.")
- print(f"{len(remaining_ids)} prompts remaining after cache hits.")
- print(f"Processing {len(remaining_prompts)} prompts.")
+ print(
+ f"{len(cache_hit_ids)} cache hits; {len(remaining_ids)} prompts remaining."
+ )

  else:
  cache_hit_ids = []
@@ -311,48 +233,137 @@ class LLMClient:

  results: list[APIResponse | None] = [None for _ in range(len(prompts))]
  if len(remaining_prompts) > 0:
- # set up progress bar
- pbar = tqdm(total=len(prompts), disable=(not show_progress))
-
- # update progress bar with cache hits
- pbar.update(len(cache_hit_ids))
- api_task = None
- if dry_run:
- dry_run_results = api_prompts_dry_run(
- ids,
- prompts, # type: ignore -- fix later for dry running conversations
- self.models,
- self.model_weights,
- self.sampling_params,
- max_tokens_per_minute=self.max_tokens_per_minute,
- max_requests_per_minute=self.max_requests_per_minute,
- )
- print("Dry run results:")
- print(dry_run_results)
- return dry_run_results
-
- api_task = asyncio.create_task(
- process_api_prompts_async(
- ids,
- prompts, # type: ignore -- fix later for dry running conversations
- self.models,
- self.model_weights,
- self.sampling_params,
- logprobs=self.logprobs,
- top_logprobs=self.top_logprobs,
- max_attempts=self.max_attempts,
- max_tokens_per_minute=self.max_tokens_per_minute,
- max_requests_per_minute=self.max_requests_per_minute,
- max_concurrent_requests=self.max_concurrent_requests,
- request_timeout=self.request_timeout,
- progress_bar=pbar,
- use_qps=self.use_qps,
- verbose=verbose,
- tools=tools,
- cache=cache,
- )
+ # Create StatusTracker with integrated progress bar
+ tracker = StatusTracker(
+ max_requests_per_minute=self.max_requests_per_minute,
+ max_tokens_per_minute=self.max_tokens_per_minute,
+ max_concurrent_requests=self.max_concurrent_requests,
+ use_progress_bar=show_progress,
+ progress_bar_total=len(prompts),
+ progress_bar_disable=not show_progress,
+ use_rich=show_progress, # Disable Rich if progress is disabled
  )
- api_results: list[APIResponse] = await api_task
+
+ # Initialize progress bar and update with cache hits
+ tracker.init_progress_bar()
+ if len(cache_hit_ids) > 0:
+ tracker.update_pbar(len(cache_hit_ids))
+
+ # api_task = asyncio.create_task(
+ # process_api_prompts_async(
+ # ids,
+ # prompts, # type: ignore -- fix later for dry running conversations
+ # self.models,
+ # self.model_weights, # type: ignore
+ # self.sampling_params, # type: ignore
+ # max_attempts=self.max_attempts,
+ # max_concurrent_requests=self.max_concurrent_requests,
+ # request_timeout=self.request_timeout,
+ # status_tracker=tracker,
+ # tools=tools,
+ # cache=cache,
+ # computer_use=computer_use,
+ # display_width=display_width,
+ # display_height=display_height,
+ # use_responses_api=use_responses_api,
+ # )
+ # )
+ # async def process_api_prompts_async(
+
+ # models: str | list[str],
+ # model_weights: list[float],
+ # sampling_params: list[SamplingParams],
+ # max_attempts: int = 5,
+ # max_concurrent_requests: int = 1_000,
+ # request_timeout: int = 30,
+ # status_tracker: StatusTracker | None = None,
+ # tools: list[Tool] | None = None,
+ # cache: CachePattern | None = None,
+ # computer_use: bool = False,
+ # display_width: int = 1024,
+ # display_height: int = 768,
+ # use_responses_api: bool = False,
+ # ):
+ if isinstance(ids, np.ndarray):
+ ids = ids.tolist() # pyright: ignore
+
+ # calculate dynamically so we don't throttle RPM
+ seconds_to_sleep_each_loop = (60.0 * 0.9) / tracker.max_requests_per_minute
+ next_request = None # variable to hold the next request to call
+ prompts_not_finished = True
+ prompts_iter = iter(zip(ids, prompts))
+ requests: list[APIRequestBase] = []
+ assert tracker.retry_queue, "retry queue not initialized"
+ while True:
+ # get next request (if one is not already waiting for capacity)
+ if next_request is None:
+ if not tracker.retry_queue.empty():
+ next_request = tracker.retry_queue.get_nowait()
+ print(f"Retrying request {next_request.task_id}.")
+ elif prompts_not_finished:
+ try:
+ # get new request
+ id, prompt = next(prompts_iter)
+ # select model
+ assert isinstance(self.model_weights, list)
+ model_idx = np.random.choice(
+ range(len(self.models)), p=self.model_weights
+ )
+ next_request = create_api_request(
+ task_id=id,
+ model_name=self.models[model_idx],
+ prompt=prompt, # type: ignore
+ request_timeout=self.request_timeout,
+ attempts_left=self.max_attempts,
+ status_tracker=tracker,
+ results_arr=requests,
+ sampling_params=self.sampling_params[model_idx],
+ all_model_names=self.models,
+ all_sampling_params=self.sampling_params,
+ tools=tools,
+ cache=cache,
+ computer_use=computer_use,
+ display_width=display_width,
+ display_height=display_height,
+ use_responses_api=use_responses_api,
+ )
+ requests.append(next_request)
+
+ except StopIteration:
+ prompts_not_finished = False
+ # print("API requests finished, only retries remain.")
+
+ # update available capacity
+ tracker.update_capacity()
+
+ # if enough capacity available, call API
+ if next_request:
+ next_request_tokens = next_request.num_tokens
+ if tracker.check_capacity(next_request_tokens):
+ tracker.set_limiting_factor(None)
+ next_request.attempts_left -= 1
+ # call API
+ asyncio.create_task(next_request.call_api())
+ next_request = None # reset next_request to empty
+ # update pbar status
+ tracker.update_pbar()
+
+ # if all tasks are finished, break
+ if tracker.num_tasks_in_progress == 0:
+ break
+
+ # main loop sleeps briefly so concurrent tasks can run
+ await asyncio.sleep(seconds_to_sleep_each_loop)
+
+ # if a rate limit error was hit recently, pause to cool down
+ if tracker.seconds_to_pause > 0:
+ await asyncio.sleep(tracker.seconds_to_pause)
+ print(f"Pausing {tracker.seconds_to_pause}s to cool down.")
+
+ # after finishing, log final status
+ tracker.log_final_status()
+ # deduplicate results by id
+ api_results = deduplicate_responses(requests)
  for res in api_results:
  results[res.id] = res
  # set to cache if result has a completion
@@ -375,8 +386,6 @@ class LLMClient:
  *,
  return_completions_only: bool = False,
  show_progress=True,
- dry_run: bool = False,
- verbose: bool = False,
  tools: list[Tool] | None = None,
  cache: CachePattern | None = None,
  ):
@@ -385,387 +394,108 @@ class LLMClient:
  prompts=prompts,
  return_completions_only=return_completions_only,
  show_progress=show_progress,
- dry_run=dry_run,
- verbose=verbose,
  tools=tools,
  cache=cache,
  )
  )

- def _submit_one_batch(self, batch_requests: list):
- # save the file
- import pandas as pd
-
- pd.DataFrame(batch_requests).to_json(
- "openai_requests_temp.jsonl", orient="records", lines=True
- )
-
- # upload the file
- api_key = os.environ.get("OPENAI_API_KEY", None)
- if api_key is None:
- raise ValueError("OPENAI_API_KEY environment variable must be set.")
- url = "https://api.openai.com/v1/files"
- files = {
- "file": (
- "openai_requests_temp.jsonl",
- open("openai_requests_temp.jsonl", "rb"),
- ),
- }
- data = {
- "purpose": "batch",
- }
- headers = {
- "Authorization": f"Bearer {api_key}",
- }
- response = requests.post(url, files=files, data=data, headers=headers)
-
- file_id = None
- if response.status_code == 200:
- print("File uploaded successfully")
- data = response.json()
- file_id = data["id"]
-
- else:
- print("File upload failed")
- raise ValueError(f"Error uploading file: {response.text}")
-
- url = "https://api.openai.com/v1/batches"
- data = {
- "input_file_id": file_id,
- "endpoint": "/v1/chat/completions",
- "completion_window": "24h",
- }
- response = requests.post(url, json=data, headers=headers)
-
- batch_id = None
- if response.status_code == 200:
- data = response.json()
- batch_id = data["id"]
- print("Batch job started successfully: id = ", batch_id)
- return batch_id
- else:
- print("Batch job failed to start")
- raise ValueError(f"Error starting batch job: {response.text}")
-
- def submit_batch_job(self, prompts: Sequence[str | list[dict] | Conversation]):
- # make sure 1) only 1 model is used, 2) it's an openai model, 3) it supports json mode
+ async def submit_batch_job(
+ self,
+ prompts: Sequence[str | list[dict] | Conversation],
+ *,
+ tools: list[Tool] | None = None,
+ cache: CachePattern | None = None,
+ ):
+ """Submit a batch job asynchronously, automatically detecting the provider based on model.
+
+ Args:
+ prompts: List of prompts to process
+ wait_for_completion: If True, poll until completion and return results
+ poll_interval: Seconds to wait between status checks when polling
+ tools: Optional tools to include in requests (Anthropic only)
+ cache: Optional cache pattern for requests (Anthropic only)
+
+ Returns: list of batch_ids
+ """
+ assert isinstance(self.sampling_params, list)
  if len(self.models) != 1:
  raise ValueError("Batch jobs can only be submitted with a single model.")
  model = self.models[0]
- if registry[model].get("api_spec", None) != "openai":
- raise ValueError("Batch jobs can only be submitted with OpenAI models.")
-
- # if prompts are strings, convert them to message lists
- prompts = [ # type: ignore
- p
- if isinstance(p, Conversation)
- else Conversation.user(p)
- if isinstance(p, str)
- else None
- for p in prompts
- ]
- if any(p is None for p in prompts):
- raise ValueError("All prompts must be valid.")
- ids = np.arange(len(prompts))
-
- # create file with requests to send to batch api
- batch_requests = []
- for id, prompt in zip(ids, prompts):
- assert isinstance(prompt, Conversation)
- batch_requests.append(
- {
- "custom_id": str(id),
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": self.models[0],
- "messages": prompt.to_openai(),
- "max_tokens": self.sampling_params[0].max_new_tokens,
- "temperature": self.sampling_params[0].temperature,
- "top_p": self.sampling_params[0].top_p,
- },
- }
- )
-
- # since the api only accepts up to 50,000 requests per batch job, we chunk into 50k chunks
- BATCH_SIZE = 50_000
- batches = [
- batch_requests[i : i + BATCH_SIZE]
- for i in range(0, len(batch_requests), BATCH_SIZE)
- ]
- batch_ids = []
- for batch in tqdm(batches):
- batch_id = self._submit_one_batch(batch)
- batch_ids.append(batch_id)
-
- print(f"Submitted {len(batches)} batch jobs.")
- return batch_ids
-
-
- def api_prompts_dry_run(
- ids: np.ndarray | list[int],
- prompts: list[Conversation],
- models: str | list[str],
- model_weights: list[float],
- sampling_params: list[SamplingParams],
- max_tokens_per_minute: int = 500_000,
- max_requests_per_minute: int = 1_000,
- ):
- """
- Count tokens and estimate costs for a batch of prompts.
- """
- results = []
- for i, prompt in zip(ids, prompts):
- # choose a model
- model_idx = np.random.choice(range(len(models)), p=model_weights)
- model = models[model_idx]
-
- # dry run
- input_tokens, output_tokens, min_cost, max_cost = prompt.dry_run(
- model, sampling_params[model_idx].max_new_tokens
- )
- results.append(
- {
- "id": i,
- "input_tokens": input_tokens,
- "output_tokens": output_tokens,
- "min_cost": min_cost,
- "max_cost": max_cost,
- }
- )
-
- combined_results: dict[str, Any] = {
- "total_input_tokens": sum([r["input_tokens"] for r in results]),
- "total_output_tokens": sum([r["output_tokens"] for r in results]),
- "total_min_cost": sum([r["min_cost"] for r in results]),
- "total_max_cost": sum([r["max_cost"] for r in results]),
- }
- minimum_time_tpm = combined_results["total_input_tokens"] / max_tokens_per_minute
- maximum_time_tpm = (
- combined_results["total_input_tokens"] + combined_results["total_output_tokens"]
- ) / max_tokens_per_minute
- minimum_time_rpm = len(prompts) / max_requests_per_minute
-
- combined_results["minimum_time"] = max(minimum_time_tpm, minimum_time_rpm)
- combined_results["maximum_time"] = max(maximum_time_tpm, minimum_time_rpm)
- limiting_factor = None
- if minimum_time_rpm > maximum_time_tpm:
- limiting_factor = "requests"
- elif minimum_time_rpm < minimum_time_tpm:
- limiting_factor = "tokens"
- else:
- limiting_factor = "depends"
- combined_results["limiting_factor"] = limiting_factor
-
- return combined_results
-
-
- async def process_api_prompts_async(
- ids: np.ndarray | list[int],
- prompts: list[Conversation],
- models: str | list[str],
- model_weights: list[float],
- sampling_params: list[SamplingParams],
- logprobs: bool,
- top_logprobs: int | None,
- max_attempts: int = 5,
- max_tokens_per_minute: int = 500_000,
- max_requests_per_minute: int = 1_000,
- max_concurrent_requests: int = 1_000,
- request_timeout: int = 30,
- progress_bar: tqdm | None = None,
- use_qps: bool = False,
- verbose: bool = False,
- tools: list[Tool] | None = None,
- cache: CachePattern | None = None,
- ):
- """Processes API requests in parallel, throttling to stay under rate limits."""
- # change ids to integer list
- if isinstance(ids, np.ndarray):
- ids = ids.tolist() # pyright: ignore
-
- # normalize weights
- model_weights = [w / sum(model_weights) for w in model_weights]
-
- # constants
- seconds_to_pause_after_rate_limit_error = 5
- # seconds_to_sleep_each_loop = 0.003 # so concurrent tasks can run
- # calculate dynamically so we don't throttle RPM
- seconds_to_sleep_each_loop = (60.0 * 0.9) / max_requests_per_minute
-
- # initialize trackers
- retry_queue = asyncio.Queue()
- status_tracker = StatusTracker()
- next_request = None # variable to hold the next request to call
-
- # initialize available capacity counts
- # throttle over a 1 second window rather than minute,
- # since some models limit RPS rather than RPM
- if use_qps:
- available_request_capacity = max_requests_per_minute / 60.0
- available_token_capacity = max_tokens_per_minute
- else:
- available_request_capacity = max_requests_per_minute
- available_token_capacity = max_tokens_per_minute
- last_update_time = time.time()
- last_pbar_update_time = time.time()
-
- # initialize flags
- prompts_not_finished = True
-
- # initials model weights
- if isinstance(models, str):
- models = [models]
- if not isinstance(models, list):
- raise ValueError("models must be a string or a list of model strings.")
- for model in models:
- if model not in registry:
- raise ValueError(f"Model {model} not found in registry.")
-
- if model_weights is None:
- # if not given, spread requests evenly across models
- model_weights = [1 / len(models) for _ in models]
- elif len(model_weights) != len(models):
- raise ValueError(
- "model_weights must be None or a list of the same length as models."
- )
- elif sum(model_weights) != 1:
- model_weights = [w / sum(model_weights) for w in model_weights]
-
- prompts_iter = iter(zip(ids, prompts))
- results: list[APIRequestBase] = []
- while True:
- # get next request (if one is not already waiting for capacity)
- if next_request is None:
- if not retry_queue.empty():
- next_request = retry_queue.get_nowait()
- print(f"Retrying request {next_request.task_id}.")
- elif prompts_not_finished:
- try:
- # get new request
- id, prompt = next(prompts_iter)
- # select model
- model_idx = np.random.choice(range(len(models)), p=model_weights)
- next_request = create_api_request(
- task_id=id,
- model_name=models[model_idx],
- prompt=prompt,
- request_timeout=request_timeout,
- attempts_left=max_attempts,
- status_tracker=status_tracker,
- retry_queue=retry_queue,
- results_arr=results,
- sampling_params=sampling_params[model_idx],
- logprobs=logprobs,
- top_logprobs=top_logprobs,
- pbar=progress_bar,
- all_model_names=models,
- all_sampling_params=sampling_params,
- tools=tools,
- cache=cache,
- )
- status_tracker.num_tasks_started += 1
- status_tracker.num_tasks_in_progress += 1
- results.append(next_request)
-
- except StopIteration:
- prompts_not_finished = False
- if verbose:
- print("API requests finished, only retries remain.")
-
- # update available capacity
- current_time = time.time()
- seconds_since_update = current_time - last_update_time
- available_request_capacity = min(
- available_request_capacity
- + max_requests_per_minute * seconds_since_update / 60.0,
- max_requests_per_minute,
- )
- available_token_capacity = min(
- available_token_capacity
- + max_tokens_per_minute * seconds_since_update / 60.0,
- max_tokens_per_minute,
- )
- last_update_time = current_time
-
- # if enough capacity available, call API
- limiting_factor = None
- if next_request:
- next_request_tokens = next_request.num_tokens
- request_available = available_request_capacity >= 1
- tokens_available = available_token_capacity >= next_request_tokens
- concurrent_request_available = (
- status_tracker.num_tasks_in_progress < max_concurrent_requests
- )
- if request_available and tokens_available and concurrent_request_available:
- # update counters
- available_request_capacity -= 1
- available_token_capacity -= next_request_tokens
- next_request.attempts_left -= 1
-
- # call API
- asyncio.create_task(next_request.call_api())
- next_request = None # reset next_request to empty
- else:
- if not request_available:
- limiting_factor = "Requests"
- elif not concurrent_request_available:
- limiting_factor = "Concurrent Requests"
- elif not tokens_available:
- limiting_factor = "Tokens"
-
- # update pbar status
- if progress_bar and (current_time - last_pbar_update_time > 1):
- last_pbar_update_time = current_time
- progress_bar.set_postfix(
- {
- "Token Capacity": f"{available_token_capacity/1_000:.1f}k",
- "Req. Capacity": f"{available_request_capacity:.1f}",
- "Reqs. in Progress": status_tracker.num_tasks_in_progress,
- "Limiting Factor": limiting_factor,
- }
+ api_spec = registry[model].get("api_spec", None)
+
+ if api_spec == "openai":
+ return await submit_batches_oa(model, self.sampling_params[0], prompts)
+ elif api_spec == "anthropic":
+ return await submit_batches_anthropic(
+ model,
+ self.sampling_params[0],
+ prompts,
+ cache=cache,
  )
+ else:
+ raise ValueError(f"Batch processing not supported for API spec: {api_spec}")

- # if all tasks are finished, break
- if status_tracker.num_tasks_in_progress == 0:
- break
-
- # main loop sleeps briefly so concurrent tasks can run
- await asyncio.sleep(seconds_to_sleep_each_loop)
-
- # if a rate limit error was hit recently, pause to cool down
- remaining_seconds_to_pause = max(
- 0,
- seconds_to_pause_after_rate_limit_error
- - status_tracker.time_since_rate_limit_error,
- )
- if remaining_seconds_to_pause > 0:
- await asyncio.sleep(remaining_seconds_to_pause)
- print(f"Pausing {remaining_seconds_to_pause}s to cool down.")
-
- # after finishing, log final status
- status_tracker.log_final_status()
- if verbose:
- print(
- f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
+ async def wait_for_batch_job(
+ self, batch_ids: list[str], provider: Literal["anthropic", "openai"]
+ ):
+ return await wait_for_batch_completion_async(
+ batch_ids, provider, poll_interval=30
  )

- # deduplicate results by id
- deduplicated = {}
- for request in results:
- if request.task_id not in deduplicated:
- deduplicated[request.task_id] = request.result[-1]
- else:
- current_response: APIResponse = deduplicated[request.task_id]
- # only replace if the current request has no completion and the new one does
- if (
- request.result[-1].completion is not None
- and current_response.completion is None
- ):
- deduplicated[request.task_id] = request.result[-1]
-
- output = list(deduplicated.values())
- if verbose:
- print(f"Returning {len(output)} unique results.")

- return output
+ # def api_prompts_dry_run(
+ # ids: np.ndarray | list[int],
+ # prompts: list[Conversation],
+ # models: str | list[str],
+ # model_weights: list[float],
+ # sampling_params: list[SamplingParams],
+ # max_tokens_per_minute: int = 500_000,
+ # max_requests_per_minute: int = 1_000,
+ # ):
+ # """
+ # Count tokens and estimate costs for a batch of prompts.
+ # """
+ # results = []
+ # for i, prompt in zip(ids, prompts):
+ # # choose a model
+ # model_idx = np.random.choice(range(len(models)), p=model_weights)
+ # model = models[model_idx]
+
+ # # dry run
+ # input_tokens, output_tokens, min_cost, max_cost = prompt.dry_run(
+ # model, sampling_params[model_idx].max_new_tokens
+ # )
+ # results.append(
+ # {
+ # "id": i,
+ # "input_tokens": input_tokens,
+ # "output_tokens": output_tokens,
+ # "min_cost": min_cost,
+ # "max_cost": max_cost,
+ # }
+ # )
+
+ # combined_results: dict[str, Any] = {
+ # "total_input_tokens": sum([r["input_tokens"] for r in results]),
+ # "total_output_tokens": sum([r["output_tokens"] for r in results]),
+ # "total_min_cost": sum([r["min_cost"] for r in results]),
+ # "total_max_cost": sum([r["max_cost"] for r in results]),
+ # }
+ # minimum_time_tpm = combined_results["total_input_tokens"] / max_tokens_per_minute
+ # maximum_time_tpm = (
+ # combined_results["total_input_tokens"] + combined_results["total_output_tokens"]
+ # ) / max_tokens_per_minute
+ # minimum_time_rpm = len(prompts) / max_requests_per_minute
+
+ # combined_results["minimum_time"] = max(minimum_time_tpm, minimum_time_rpm)
+ # combined_results["maximum_time"] = max(maximum_time_tpm, minimum_time_rpm)
+ # limiting_factor = None
+ # if minimum_time_rpm > maximum_time_tpm:
+ # limiting_factor = "requests"
+ # elif minimum_time_rpm < minimum_time_tpm:
+ # limiting_factor = "tokens"
+ # else:
+ # limiting_factor = "depends"
+ # combined_results["limiting_factor"] = limiting_factor
+
+ # return combined_results
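
For orientation, a minimal usage sketch of the 0.0.13 client as it appears in this diff. Only the field names, builder methods, and the process_prompts_async / submit_batch_job / wait_for_batch_job signatures come from the code above; the model name, prompt strings, and the assumption that the configured model maps to the "openai" provider are illustrative placeholders, not documented behavior.

# Sketch only: based on the signatures visible in the diff above, not on package docs.
import asyncio

from lm_deluge.client import LLMClient

async def main():
    # LLMClient is now a pydantic model; scalar sampling fields (temperature,
    # max_new_tokens, ...) are folded into a default SamplingParams by the
    # "before" validator when sampling_params is not given.
    client = LLMClient(model_names=["gpt-4.1-mini"], max_new_tokens=256)
    client.with_limits(max_requests_per_minute=500)  # returns None, so not chained

    # With return_completions_only=True the overloads say this yields list[str | None].
    completions = await client.process_prompts_async(
        ["Hello!", "Summarize Hamlet in one sentence."],
        return_completions_only=True,
    )
    print(completions)

    # Batch mode: provider is inferred from the single configured model's api_spec.
    batch_ids = await client.submit_batch_job(["Translate 'hello' to French."])
    await client.wait_for_batch_job(batch_ids, provider="openai")  # provider assumed

asyncio.run(main())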