lm-deluge 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lm-deluge might be problematic. Click here for more details.

lm_deluge/client.py CHANGED
@@ -1,42 +1,140 @@
1
- import os
2
- import requests
3
1
  import asyncio
2
+ from typing import Any, Literal, Self, Sequence, overload
3
+
4
4
  import numpy as np
5
- import time
6
5
  import yaml
7
- from dataclasses import dataclass
8
- from typing import Sequence, overload, Literal, Any
9
- from tqdm.auto import tqdm
10
-
11
- from lm_deluge.prompt import Conversation
6
+ from pydantic import BaseModel
7
+ from pydantic.functional_validators import model_validator
8
+
9
+ from lm_deluge.batches import (
10
+ submit_batches_anthropic,
11
+ submit_batches_oa,
12
+ wait_for_batch_completion_async,
13
+ )
14
+ from lm_deluge.prompt import CachePattern, Conversation, prompts_to_conversations
12
15
  from lm_deluge.tool import Tool
13
16
 
14
- from .tracker import StatusTracker
15
- from .sampling_params import SamplingParams
16
- from .models import registry
17
- from .api_requests.base import APIResponse, APIRequestBase
18
17
  from .api_requests import create_api_request
18
+ from .api_requests.base import APIRequestBase, APIResponse, deduplicate_responses
19
+ from .config import SamplingParams
20
+ from .models import registry
21
+ from .tracker import StatusTracker
22
+
19
23
  # from .cache import LevelDBCache, SqliteCache
20
24
 
25
+
21
26
  # TODO: get completions as they finish, not all at once at the end.
22
27
  # relatedly, would be nice to cache them as they finish too.
23
-
24
28
  # TODO: add optional max_input_tokens to client so we can reject long prompts to prevent abuse
29
+ class LLMClient(BaseModel):
30
+ """
31
+ LLMClient abstracts all the fixed arguments to process_prompts_async, so you can create it
32
+ once and use it for more stuff without having to configure all the arguments.
33
+ Handles models, sampling params for each model, model weights, rate limits, etc.
34
+ """
25
35
 
26
-
27
- @dataclass
28
- class ClientConfig:
29
- model_names: list[str]
30
- max_requests_per_minute: int
31
- max_tokens_per_minute: int
32
- max_concurrent_requests: int
33
- max_attempts: int
34
- request_timeout: int
35
- sampling_params: SamplingParams | list[SamplingParams]
36
- model_weights: list[float] | Literal["uniform", "rate_limit"]
36
+ model_names: list[str] = ["gpt-4.1-mini"]
37
+ max_requests_per_minute: int = 1_000
38
+ max_tokens_per_minute: int = 100_000
39
+ max_concurrent_requests: int = 225
40
+ sampling_params: list[SamplingParams] = []
41
+ model_weights: list[float] | Literal["uniform", "dynamic"] = "uniform"
42
+ max_attempts: int = 5
43
+ request_timeout: int = 30
44
+ cache: Any = None
45
+ # sampling params - if provided, and sampling_params is not,
46
+ # these override the defaults
47
+ temperature: float = 0.75
48
+ top_p: float = 1.0
49
+ json_mode: bool = False
50
+ max_new_tokens: int = 512
51
+ reasoning_effort: Literal["low", "medium", "high", None] = None
37
52
  logprobs: bool = False
38
53
  top_logprobs: int | None = None
39
- cache: Any = None
54
+
55
+ # NEW! Builder methods
56
+ def with_model(self, model: str):
57
+ self.model_names = [model]
58
+ return self
59
+
60
+ def with_models(self, models: list[str]):
61
+ self.model_names = models
62
+ return self
63
+
64
+ def with_limits(
65
+ self,
66
+ max_requests_per_minute: int | None = None,
67
+ max_tokens_per_minute: int | None = None,
68
+ max_concurrent_requests: int | None = None,
69
+ ):
70
+ if max_requests_per_minute:
71
+ self.max_requests_per_minute = max_requests_per_minute
72
+ if max_tokens_per_minute:
73
+ self.max_tokens_per_minute = max_tokens_per_minute
74
+ if max_concurrent_requests:
75
+ self.max_concurrent_requests = max_concurrent_requests
76
+
77
+ @property
78
+ def models(self):
79
+ return self.model_names # why? idk
80
+
81
+ @model_validator(mode="before")
82
+ @classmethod
83
+ def fix_lists(cls, data) -> "LLMClient":
84
+ if isinstance(data["model_names"], str):
85
+ data["model_names"] = [data["model_names"]]
86
+ if "sampling_params" not in data or len(data.get("sampling_params", [])) == 0:
87
+ data["sampling_params"] = [
88
+ SamplingParams(
89
+ temperature=data.get("temperature", 0.75),
90
+ top_p=data.get("top_p", 1.0),
91
+ json_mode=data.get("json_mode", False),
92
+ max_new_tokens=data.get("max_new_tokens", 512),
93
+ reasoning_effort=data.get("reasoning_effort", None),
94
+ logprobs=data.get("logprobs", False),
95
+ top_logprobs=data.get("top_logprobs", None),
96
+ )
97
+ ]
98
+ return data
99
+
100
+ @model_validator(mode="after")
101
+ def validate_client(self) -> Self:
102
+ if isinstance(self.model_names, str):
103
+ self.model_names = [self.model_names]
104
+ if any(m not in registry for m in self.model_names):
105
+ raise ValueError("all model_names must be in registry")
106
+ if isinstance(self.sampling_params, SamplingParams):
107
+ self.sampling_params = [self.sampling_params for _ in self.model_names]
108
+ elif len(self.sampling_params) != len(self.model_names):
109
+ raise ValueError("# models and # sampling params must match")
110
+ if self.model_weights == "uniform":
111
+ self.model_weights = [1 / len(self.model_names) for _ in self.model_names]
112
+ elif self.model_weights == "dynamic":
113
+ raise NotImplementedError("dynamic model weights not implemented yet")
114
+ # normalize weights
115
+ self.model_weights = [w / sum(self.model_weights) for w in self.model_weights]
116
+
117
+ # Validate logprobs settings across all sampling params
118
+ if self.logprobs or any(sp.logprobs for sp in self.sampling_params):
119
+ print("Logprobs enabled.")
120
+ for sp in self.sampling_params:
121
+ sp.logprobs = True
122
+ # set top_logprobs for each sp if provided and not set
123
+ if self.top_logprobs and not sp.top_logprobs:
124
+ sp.top_logprobs = self.top_logprobs
125
+ if sp.top_logprobs and not (0 <= sp.top_logprobs <= 20):
126
+ raise ValueError("top_logprobs must be 0-20")
127
+ if sp.top_logprobs and sp.max_new_tokens > 10:
128
+ print(
129
+ "WARNING: using top_logprobs can result in very large outputs. consider limiting max_new_tokens."
130
+ )
131
+ if not all(
132
+ registry[model].get("supports_logprobs") for model in self.models
133
+ ):
134
+ raise ValueError(
135
+ "logprobs can only be enabled if all models support it."
136
+ )
137
+ return self
40
138
 
41
139
  @classmethod
42
140
  def from_dict(cls, config_dict: dict):
@@ -46,7 +144,7 @@ class ClientConfig:
46
144
  ]
47
145
  else:
48
146
  config_dict["sampling_params"] = SamplingParams(
49
- config_dict["sampling_params"]
147
+ **config_dict["sampling_params"]
50
148
  )
51
149
 
52
150
  return cls(**config_dict)
@@ -56,183 +154,13 @@ class ClientConfig:
56
154
  config_dict = yaml.safe_load(open(file_path))
57
155
  return cls.from_dict(config_dict)
58
156
 
59
- def to_dict(self):
60
- if isinstance(self.sampling_params, list):
61
- sp = [x.__dict__ for x in self.sampling_params]
62
- else:
63
- sp = self.sampling_params.__dict__
64
-
65
- return {
66
- "model_names": self.model_names,
67
- "max_requests_per_minute": self.max_requests_per_minute,
68
- "max_tokens_per_minute": self.max_tokens_per_minute,
69
- "max_concurrent_requests": self.max_concurrent_requests,
70
- "max_attempts": self.max_attempts,
71
- "request_timeout": self.request_timeout,
72
- "sampling_params": sp,
73
- "model_weights": self.model_weights,
74
- "logprobs": self.logprobs,
75
- "top_logprobs": self.top_logprobs,
76
- }
77
-
78
-
79
- class LLMClient:
80
- """
81
- LLMClient abstracts all the fixed arguments to process_prompts_async, so you can create it
82
- once and use it for more stuff without having to configure all the arguments.
83
- Handles models, sampling params for each model, model weights, rate limits, etc.
84
- """
85
-
86
- def __init__(
87
- self,
88
- model_names: list[str],
89
- *,
90
- max_requests_per_minute: int,
91
- max_tokens_per_minute: int,
92
- max_concurrent_requests: int,
93
- sampling_params: SamplingParams | list[SamplingParams] = SamplingParams(),
94
- model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
95
- max_attempts: int = 5,
96
- request_timeout: int = 30,
97
- logprobs: bool = False,
98
- top_logprobs: int | None = None,
99
- use_qps: bool = False,
100
- debug: bool = False,
101
- cache: Any = None,
102
- ):
103
- self.models = model_names
104
- if isinstance(sampling_params, SamplingParams):
105
- self.sampling_params = [sampling_params for _ in model_names]
106
- else:
107
- if len(sampling_params) != len(model_names):
108
- raise ValueError(
109
- "If sampling_params is a list, it must have the same length as model_names."
110
- )
111
- self.sampling_params = sampling_params
112
- if model_weights == "uniform":
113
- self.model_weights = [1 / len(model_names) for _ in model_names]
114
- elif model_weights == "rate_limit":
115
- rpms = [registry[model]["requests_per_minute"] for model in model_names]
116
- self.model_weights = [rpm / sum(rpms) for rpm in rpms]
117
- elif sum(model_weights) != 1:
118
- self.model_weights = [w / sum(model_weights) for w in model_weights]
119
- else:
120
- self.model_weights = model_weights
121
-
122
- self.logprobs = logprobs
123
- self.top_logprobs = top_logprobs
124
-
125
- # logprobs and top_logprobs are only supported for OpenAI models
126
- if self.logprobs:
127
- for model in self.models:
128
- if registry[model].get("supports_logprobs", False) is False:
129
- raise ValueError(
130
- "logprobs can only be enabled if all models support it."
131
- )
132
- if self.top_logprobs is None:
133
- self.top_logprobs = 0 # will just return logprob of the chosen token
134
- elif self.top_logprobs > 20 or self.top_logprobs < 0:
135
- raise ValueError("top_logprobs must be between 0 and 20.")
136
- for sp in self.sampling_params:
137
- if sp.max_new_tokens > 10:
138
- print(
139
- "WARNING: using logprobs with large max_new_tokens can result in very large outputs. you may want to avoid saving these outputs to disk/db."
140
- )
141
- break
142
- else:
143
- self.top_logprobs = None
144
-
145
- self.max_requests_per_minute = max_requests_per_minute
146
- self.max_tokens_per_minute = max_tokens_per_minute
147
- self.max_concurrent_requests = max_concurrent_requests
148
- self.max_attempts = max_attempts
149
- self.request_timeout = request_timeout
150
- self.use_qps = use_qps
151
- self.debug = (
152
- debug # UNUSED/DEPRECATED i think? but dont want to break everything
153
- )
154
- self.cache = cache
155
-
156
- @classmethod
157
- def from_config(cls, config: ClientConfig, cache: Any = None):
158
- return cls(
159
- model_names=config.model_names,
160
- max_requests_per_minute=config.max_requests_per_minute,
161
- max_tokens_per_minute=config.max_tokens_per_minute,
162
- max_concurrent_requests=config.max_concurrent_requests,
163
- sampling_params=config.sampling_params,
164
- model_weights=config.model_weights,
165
- max_attempts=config.max_attempts,
166
- request_timeout=config.request_timeout,
167
- cache=cache,
168
- )
169
-
170
- @classmethod
171
- def from_yaml(cls, file_path: str, cache: Any = None):
172
- return cls.from_config(ClientConfig.from_yaml(file_path), cache=cache)
173
-
174
157
  @classmethod
175
- def basic(
176
- cls,
177
- model: str | list[str],
178
- max_requests_per_minute: int = 5_000,
179
- max_tokens_per_minute: int = 1_000_000,
180
- max_concurrent_requests: int = 1_000,
181
- temperature: float = 0.75,
182
- max_new_tokens: int = 1000,
183
- reasoning_effort: Literal[None, "low", "medium", "high"] = None,
184
- model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
185
- logprobs: bool = False,
186
- top_logprobs: int | None = None,
187
- max_attempts: int = 5,
188
- request_timeout: int = 30,
189
- cache: Any = None,
190
- ):
191
- model_names = model if isinstance(model, list) else [model]
192
- return cls(
193
- model_names=model_names,
194
- max_requests_per_minute=max_requests_per_minute,
195
- max_tokens_per_minute=max_tokens_per_minute,
196
- max_concurrent_requests=max_concurrent_requests,
197
- sampling_params=SamplingParams(
198
- temperature=temperature,
199
- max_new_tokens=max_new_tokens,
200
- reasoning_effort=reasoning_effort,
201
- ),
202
- logprobs=logprobs,
203
- top_logprobs=top_logprobs,
204
- model_weights=model_weights,
205
- max_attempts=max_attempts,
206
- request_timeout=request_timeout,
207
- cache=cache,
208
- )
209
-
210
- @property
211
- def config(self):
212
- return ClientConfig(
213
- model_names=self.models,
214
- model_weights=self.model_weights,
215
- max_requests_per_minute=self.max_requests_per_minute,
216
- max_tokens_per_minute=self.max_tokens_per_minute,
217
- max_concurrent_requests=self.max_concurrent_requests,
218
- max_attempts=self.max_attempts,
219
- request_timeout=self.request_timeout,
220
- sampling_params=self.sampling_params,
221
- logprobs=self.logprobs,
222
- top_logprobs=self.top_logprobs,
223
- )
224
-
225
- @overload
226
- async def process_prompts_async(
227
- self,
228
- prompts: Sequence[str | list[dict] | Conversation],
229
- *,
230
- return_completions_only: bool,
231
- show_progress: bool = ...,
232
- dry_run: Literal[True],
233
- verbose: bool = ...,
234
- tools: list[Tool] | None = ...,
235
- ) -> dict[str, int]: ...
158
+ def basic(cls, model: str | list[str], **kwargs):
159
+ """
160
+ Doesn't do anything differently now, kept for backwards compat.
161
+ """
162
+ kwargs["model_names"] = model
163
+ return cls(**kwargs)
236
164
 
237
165
  @overload
238
166
  async def process_prompts_async(
@@ -241,9 +169,12 @@ class LLMClient:
241
169
  *,
242
170
  return_completions_only: Literal[True],
243
171
  show_progress: bool = ...,
244
- dry_run: bool = ...,
245
- verbose: bool = ...,
246
172
  tools: list[Tool] | None = ...,
173
+ cache: CachePattern | None = ...,
174
+ computer_use: bool = ...,
175
+ display_width: int = ...,
176
+ display_height: int = ...,
177
+ use_responses_api: bool = ...,
247
178
  ) -> list[str | None]: ...
248
179
 
249
180
  @overload
@@ -253,9 +184,12 @@ class LLMClient:
253
184
  *,
254
185
  return_completions_only: Literal[False] = ...,
255
186
  show_progress: bool = ...,
256
- dry_run: bool = ...,
257
- verbose: bool = ...,
258
187
  tools: list[Tool] | None = ...,
188
+ cache: CachePattern | None = ...,
189
+ computer_use: bool = ...,
190
+ display_width: int = ...,
191
+ display_height: int = ...,
192
+ use_responses_api: bool = ...,
259
193
  ) -> list[APIResponse | None]: ...
260
194
 
261
195
  async def process_prompts_async(
@@ -264,22 +198,15 @@ class LLMClient:
264
198
  *,
265
199
  return_completions_only: bool = False,
266
200
  show_progress: bool = True,
267
- dry_run: bool = False,
268
- verbose: bool = False,
269
201
  tools: list[Tool] | None = None,
202
+ cache: CachePattern | None = None,
203
+ computer_use: bool = False,
204
+ display_width: int = 1024,
205
+ display_height: int = 768,
206
+ use_responses_api: bool = False,
270
207
  ) -> list[APIResponse | None] | list[str | None] | dict[str, int]:
271
208
  # if prompts are not Conversations, convert them.
272
- # can only handle strings for now
273
- prompts = [ # type: ignore
274
- p
275
- if isinstance(p, Conversation)
276
- else Conversation.user(p)
277
- if isinstance(p, str)
278
- else None
279
- for p in prompts
280
- ]
281
- if any(p is None for p in prompts):
282
- raise ValueError("All prompts must be valid.")
209
+ prompts = prompts_to_conversations(prompts)
283
210
  ids = np.arange(len(prompts))
284
211
 
285
212
  # if using cache, check for cached completions
@@ -294,10 +221,9 @@ class LLMClient:
294
221
  ), "Cache hit ids and results must be the same length."
295
222
  remaining_ids = np.array([i for i in ids if i not in cache_hit_ids])
296
223
  remaining_prompts = [prompts[i] for i in remaining_ids]
297
- if verbose:
298
- print(f"{len(cache_hit_ids)} cache hits from previous completions.")
299
- print(f"{len(remaining_ids)} prompts remaining after cache hits.")
300
- print(f"Processing {len(remaining_prompts)} prompts.")
224
+ print(
225
+ f"{len(cache_hit_ids)} cache hits; {len(remaining_ids)} prompts remaining."
226
+ )
301
227
 
302
228
  else:
303
229
  cache_hit_ids = []
@@ -307,47 +233,137 @@ class LLMClient:
307
233
 
308
234
  results: list[APIResponse | None] = [None for _ in range(len(prompts))]
309
235
  if len(remaining_prompts) > 0:
310
- # set up progress bar
311
- pbar = tqdm(total=len(prompts), disable=(not show_progress))
312
-
313
- # update progress bar with cache hits
314
- pbar.update(len(cache_hit_ids))
315
- api_task = None
316
- if dry_run:
317
- dry_run_results = api_prompts_dry_run(
318
- ids,
319
- prompts, # type: ignore -- fix later for dry running conversations
320
- self.models,
321
- self.model_weights,
322
- self.sampling_params,
323
- max_tokens_per_minute=self.max_tokens_per_minute,
324
- max_requests_per_minute=self.max_requests_per_minute,
325
- )
326
- print("Dry run results:")
327
- print(dry_run_results)
328
- return dry_run_results
329
-
330
- api_task = asyncio.create_task(
331
- process_api_prompts_async(
332
- ids,
333
- prompts, # type: ignore -- fix later for dry running conversations
334
- self.models,
335
- self.model_weights,
336
- self.sampling_params,
337
- logprobs=self.logprobs,
338
- top_logprobs=self.top_logprobs,
339
- max_attempts=self.max_attempts,
340
- max_tokens_per_minute=self.max_tokens_per_minute,
341
- max_requests_per_minute=self.max_requests_per_minute,
342
- max_concurrent_requests=self.max_concurrent_requests,
343
- request_timeout=self.request_timeout,
344
- progress_bar=pbar,
345
- use_qps=self.use_qps,
346
- verbose=verbose,
347
- tools=tools,
348
- )
236
+ # Create StatusTracker with integrated progress bar
237
+ tracker = StatusTracker(
238
+ max_requests_per_minute=self.max_requests_per_minute,
239
+ max_tokens_per_minute=self.max_tokens_per_minute,
240
+ max_concurrent_requests=self.max_concurrent_requests,
241
+ use_progress_bar=show_progress,
242
+ progress_bar_total=len(prompts),
243
+ progress_bar_disable=not show_progress,
244
+ use_rich=show_progress, # Disable Rich if progress is disabled
349
245
  )
350
- api_results: list[APIResponse] = await api_task
246
+
247
+ # Initialize progress bar and update with cache hits
248
+ tracker.init_progress_bar()
249
+ if len(cache_hit_ids) > 0:
250
+ tracker.update_pbar(len(cache_hit_ids))
251
+
252
+ # api_task = asyncio.create_task(
253
+ # process_api_prompts_async(
254
+ # ids,
255
+ # prompts, # type: ignore -- fix later for dry running conversations
256
+ # self.models,
257
+ # self.model_weights, # type: ignore
258
+ # self.sampling_params, # type: ignore
259
+ # max_attempts=self.max_attempts,
260
+ # max_concurrent_requests=self.max_concurrent_requests,
261
+ # request_timeout=self.request_timeout,
262
+ # status_tracker=tracker,
263
+ # tools=tools,
264
+ # cache=cache,
265
+ # computer_use=computer_use,
266
+ # display_width=display_width,
267
+ # display_height=display_height,
268
+ # use_responses_api=use_responses_api,
269
+ # )
270
+ # )
271
+ # async def process_api_prompts_async(
272
+
273
+ # models: str | list[str],
274
+ # model_weights: list[float],
275
+ # sampling_params: list[SamplingParams],
276
+ # max_attempts: int = 5,
277
+ # max_concurrent_requests: int = 1_000,
278
+ # request_timeout: int = 30,
279
+ # status_tracker: StatusTracker | None = None,
280
+ # tools: list[Tool] | None = None,
281
+ # cache: CachePattern | None = None,
282
+ # computer_use: bool = False,
283
+ # display_width: int = 1024,
284
+ # display_height: int = 768,
285
+ # use_responses_api: bool = False,
286
+ # ):
287
+ if isinstance(ids, np.ndarray):
288
+ ids = ids.tolist() # pyright: ignore
289
+
290
+ # calculate dynamically so we don't throttle RPM
291
+ seconds_to_sleep_each_loop = (60.0 * 0.9) / tracker.max_requests_per_minute
292
+ next_request = None # variable to hold the next request to call
293
+ prompts_not_finished = True
294
+ prompts_iter = iter(zip(ids, prompts))
295
+ requests: list[APIRequestBase] = []
296
+ assert tracker.retry_queue, "retry queue not initialized"
297
+ while True:
298
+ # get next request (if one is not already waiting for capacity)
299
+ if next_request is None:
300
+ if not tracker.retry_queue.empty():
301
+ next_request = tracker.retry_queue.get_nowait()
302
+ print(f"Retrying request {next_request.task_id}.")
303
+ elif prompts_not_finished:
304
+ try:
305
+ # get new request
306
+ id, prompt = next(prompts_iter)
307
+ # select model
308
+ assert isinstance(self.model_weights, list)
309
+ model_idx = np.random.choice(
310
+ range(len(self.models)), p=self.model_weights
311
+ )
312
+ next_request = create_api_request(
313
+ task_id=id,
314
+ model_name=self.models[model_idx],
315
+ prompt=prompt, # type: ignore
316
+ request_timeout=self.request_timeout,
317
+ attempts_left=self.max_attempts,
318
+ status_tracker=tracker,
319
+ results_arr=requests,
320
+ sampling_params=self.sampling_params[model_idx],
321
+ all_model_names=self.models,
322
+ all_sampling_params=self.sampling_params,
323
+ tools=tools,
324
+ cache=cache,
325
+ computer_use=computer_use,
326
+ display_width=display_width,
327
+ display_height=display_height,
328
+ use_responses_api=use_responses_api,
329
+ )
330
+ requests.append(next_request)
331
+
332
+ except StopIteration:
333
+ prompts_not_finished = False
334
+ # print("API requests finished, only retries remain.")
335
+
336
+ # update available capacity
337
+ tracker.update_capacity()
338
+
339
+ # if enough capacity available, call API
340
+ if next_request:
341
+ next_request_tokens = next_request.num_tokens
342
+ if tracker.check_capacity(next_request_tokens):
343
+ tracker.set_limiting_factor(None)
344
+ next_request.attempts_left -= 1
345
+ # call API
346
+ asyncio.create_task(next_request.call_api())
347
+ next_request = None # reset next_request to empty
348
+ # update pbar status
349
+ tracker.update_pbar()
350
+
351
+ # if all tasks are finished, break
352
+ if tracker.num_tasks_in_progress == 0:
353
+ break
354
+
355
+ # main loop sleeps briefly so concurrent tasks can run
356
+ await asyncio.sleep(seconds_to_sleep_each_loop)
357
+
358
+ # if a rate limit error was hit recently, pause to cool down
359
+ if tracker.seconds_to_pause > 0:
360
+ await asyncio.sleep(tracker.seconds_to_pause)
361
+ print(f"Pausing {tracker.seconds_to_pause}s to cool down.")
362
+
363
+ # after finishing, log final status
364
+ tracker.log_final_status()
365
+ # deduplicate results by id
366
+ api_results = deduplicate_responses(requests)
351
367
  for res in api_results:
352
368
  results[res.id] = res
353
369
  # set to cache if result has a completion
@@ -370,393 +386,116 @@ class LLMClient:
370
386
  *,
371
387
  return_completions_only: bool = False,
372
388
  show_progress=True,
373
- dry_run: bool = False,
374
- verbose: bool = False,
375
389
  tools: list[Tool] | None = None,
390
+ cache: CachePattern | None = None,
376
391
  ):
377
392
  return asyncio.run(
378
393
  self.process_prompts_async(
379
394
  prompts=prompts,
380
395
  return_completions_only=return_completions_only,
381
396
  show_progress=show_progress,
382
- dry_run=dry_run,
383
- verbose=verbose,
384
397
  tools=tools,
398
+ cache=cache,
385
399
  )
386
400
  )
387
401
 
388
- def _submit_one_batch(self, batch_requests: list):
389
- # save the file
390
- import pandas as pd
391
-
392
- pd.DataFrame(batch_requests).to_json(
393
- "openai_requests_temp.jsonl", orient="records", lines=True
394
- )
395
-
396
- # upload the file
397
- api_key = os.environ.get("OPENAI_API_KEY", None)
398
- if api_key is None:
399
- raise ValueError("OPENAI_API_KEY environment variable must be set.")
400
- url = "https://api.openai.com/v1/files"
401
- files = {
402
- "file": (
403
- "openai_requests_temp.jsonl",
404
- open("openai_requests_temp.jsonl", "rb"),
405
- ),
406
- }
407
- data = {
408
- "purpose": "batch",
409
- }
410
- headers = {
411
- "Authorization": f"Bearer {api_key}",
412
- }
413
- response = requests.post(url, files=files, data=data, headers=headers)
414
-
415
- file_id = None
416
- if response.status_code == 200:
417
- print("File uploaded successfully")
418
- data = response.json()
419
- file_id = data["id"]
420
-
421
- else:
422
- print("File upload failed")
423
- raise ValueError(f"Error uploading file: {response.text}")
424
-
425
- url = "https://api.openai.com/v1/batches"
426
- data = {
427
- "input_file_id": file_id,
428
- "endpoint": "/v1/chat/completions",
429
- "completion_window": "24h",
430
- }
431
- response = requests.post(url, json=data, headers=headers)
432
-
433
- batch_id = None
434
- if response.status_code == 200:
435
- data = response.json()
436
- batch_id = data["id"]
437
- print("Batch job started successfully: id = ", batch_id)
438
- return batch_id
439
- else:
440
- print("Batch job failed to start")
441
- raise ValueError(f"Error starting batch job: {response.text}")
442
-
443
- def submit_batch_job(self, prompts: Sequence[str | list[dict] | Conversation]):
444
- # make sure 1) only 1 model is used, 2) it's an openai model, 3) it supports json mode
402
+ async def submit_batch_job(
403
+ self,
404
+ prompts: Sequence[str | list[dict] | Conversation],
405
+ *,
406
+ tools: list[Tool] | None = None,
407
+ cache: CachePattern | None = None,
408
+ ):
409
+ """Submit a batch job asynchronously, automatically detecting the provider based on model.
410
+
411
+ Args:
412
+ prompts: List of prompts to process
413
+ wait_for_completion: If True, poll until completion and return results
414
+ poll_interval: Seconds to wait between status checks when polling
415
+ tools: Optional tools to include in requests (Anthropic only)
416
+ cache: Optional cache pattern for requests (Anthropic only)
417
+
418
+ Returns: list of batch_ids
419
+ """
420
+ assert isinstance(self.sampling_params, list)
445
421
  if len(self.models) != 1:
446
422
  raise ValueError("Batch jobs can only be submitted with a single model.")
447
423
  model = self.models[0]
448
- if registry[model].get("api_spec", None) != "openai":
449
- raise ValueError("Batch jobs can only be submitted with OpenAI models.")
450
-
451
- # if prompts are strings, convert them to message lists
452
- prompts = [ # type: ignore
453
- p
454
- if isinstance(p, Conversation)
455
- else Conversation.user(p)
456
- if isinstance(p, str)
457
- else None
458
- for p in prompts
459
- ]
460
- if any(p is None for p in prompts):
461
- raise ValueError("All prompts must be valid.")
462
- ids = np.arange(len(prompts))
463
-
464
- # create file with requests to send to batch api
465
- batch_requests = []
466
- for id, prompt in zip(ids, prompts):
467
- assert isinstance(prompt, Conversation)
468
- batch_requests.append(
469
- {
470
- "custom_id": str(id),
471
- "method": "POST",
472
- "url": "/v1/chat/completions",
473
- "body": {
474
- "model": self.models[0],
475
- "messages": prompt.to_openai(),
476
- "max_tokens": self.sampling_params[0].max_new_tokens,
477
- "temperature": self.sampling_params[0].temperature,
478
- "top_p": self.sampling_params[0].top_p,
479
- },
480
- }
424
+ api_spec = registry[model].get("api_spec", None)
425
+
426
+ if api_spec == "openai":
427
+ return await submit_batches_oa(model, self.sampling_params[0], prompts)
428
+ elif api_spec == "anthropic":
429
+ return await submit_batches_anthropic(
430
+ model,
431
+ self.sampling_params[0],
432
+ prompts,
433
+ cache=cache,
481
434
  )
435
+ else:
436
+ raise ValueError(f"Batch processing not supported for API spec: {api_spec}")
482
437
 
483
- # since the api only accepts up to 50,000 requests per batch job, we chunk into 50k chunks
484
- BATCH_SIZE = 50_000
485
- batches = [
486
- batch_requests[i : i + BATCH_SIZE]
487
- for i in range(0, len(batch_requests), BATCH_SIZE)
488
- ]
489
- batch_ids = []
490
- for batch in tqdm(batches):
491
- batch_id = self._submit_one_batch(batch)
492
- batch_ids.append(batch_id)
493
-
494
- print(f"Submitted {len(batches)} batch jobs.")
495
- return batch_ids
496
-
497
-
498
- def api_prompts_dry_run(
499
- ids: np.ndarray | list[int],
500
- prompts: list[Conversation],
501
- models: str | list[str],
502
- model_weights: list[float],
503
- sampling_params: list[SamplingParams],
504
- max_tokens_per_minute: int = 500_000,
505
- max_requests_per_minute: int = 1_000,
506
- ):
507
- """
508
- Count tokens and estimate costs for a batch of prompts.
509
- """
510
- results = []
511
- for i, prompt in zip(ids, prompts):
512
- # choose a model
513
- model_idx = np.random.choice(range(len(models)), p=model_weights)
514
- model = models[model_idx]
515
-
516
- # dry run
517
- input_tokens, output_tokens, min_cost, max_cost = prompt.dry_run(
518
- model, sampling_params[model_idx].max_new_tokens
519
- )
520
- results.append(
521
- {
522
- "id": i,
523
- "input_tokens": input_tokens,
524
- "output_tokens": output_tokens,
525
- "min_cost": min_cost,
526
- "max_cost": max_cost,
527
- }
528
- )
529
-
530
- combined_results: dict[str, Any] = {
531
- "total_input_tokens": sum([r["input_tokens"] for r in results]),
532
- "total_output_tokens": sum([r["output_tokens"] for r in results]),
533
- "total_min_cost": sum([r["min_cost"] for r in results]),
534
- "total_max_cost": sum([r["max_cost"] for r in results]),
535
- }
536
- minimum_time_tpm = combined_results["total_input_tokens"] / max_tokens_per_minute
537
- maximum_time_tpm = (
538
- combined_results["total_input_tokens"] + combined_results["total_output_tokens"]
539
- ) / max_tokens_per_minute
540
- minimum_time_rpm = len(prompts) / max_requests_per_minute
541
-
542
- combined_results["minimum_time"] = max(minimum_time_tpm, minimum_time_rpm)
543
- combined_results["maximum_time"] = max(maximum_time_tpm, minimum_time_rpm)
544
- limiting_factor = None
545
- if minimum_time_rpm > maximum_time_tpm:
546
- limiting_factor = "requests"
547
- elif minimum_time_rpm < minimum_time_tpm:
548
- limiting_factor = "tokens"
549
- else:
550
- limiting_factor = "depends"
551
- combined_results["limiting_factor"] = limiting_factor
552
-
553
- return combined_results
554
-
555
-
556
async def process_api_prompts_async(
    ids: np.ndarray | list[int],
    prompts: list[Conversation],
    models: str | list[str],
    model_weights: list[float] | None,
    sampling_params: list[SamplingParams],
    logprobs: bool,
    top_logprobs: int | None,
    max_attempts: int = 5,
    max_tokens_per_minute: int = 500_000,
    max_requests_per_minute: int = 1_000,
    max_concurrent_requests: int = 1_000,
    request_timeout: int = 30,
    progress_bar: tqdm | None = None,
    use_qps: bool = False,
    verbose: bool = False,
    tools: list[Tool] | None = None,
) -> list[APIResponse]:
    """Process API requests in parallel, throttling to stay under rate limits.

    One request object is built per ``(id, prompt)`` pair and launched as an
    asyncio task once request, token and concurrency capacity all allow it.
    Requests placed on the retry queue are relaunched before new prompts are
    pulled.

    Args:
        ids: Task ids, paired positionally with ``prompts`` (``zip`` stops at
            the shorter of the two).
        prompts: Conversations to send.
        models: One model name or a list of names; each must be in ``registry``.
        model_weights: Relative sampling weights, one per model. ``None``
            spreads requests evenly. Weights are always normalised to sum to 1.
        sampling_params: Sampling parameters, indexed by the chosen model.
        logprobs: Forwarded to ``create_api_request``.
        top_logprobs: Forwarded to ``create_api_request``.
        max_attempts: Initial ``attempts_left`` of every request.
        max_tokens_per_minute: Token budget refill rate and ceiling.
        max_requests_per_minute: Request budget refill rate and ceiling; also
            sets the main loop's sleep interval.
        max_concurrent_requests: Upper bound on
            ``status_tracker.num_tasks_in_progress`` when launching.
        request_timeout: Forwarded to ``create_api_request``.
        progress_bar: Optional bar, forwarded to each request and given a
            status postfix roughly once per second.
        use_qps: Start with one second's worth of request capacity instead of
            a full minute's worth.
        verbose: Print extra progress messages.
        tools: Forwarded to ``create_api_request``.

    Returns:
        One response per distinct task id: the last entry of a request's
        ``result`` list, preferring an entry whose ``completion`` is not None.

    Raises:
        ValueError: If ``models`` is not a str/list, a model is not in the
            registry, ``model_weights`` has the wrong length, or the weights
            do not sum to a positive number.
    """
    # change ids to integer list
    if isinstance(ids, np.ndarray):
        ids = ids.tolist()  # pyright: ignore

    # validate the models first; the weights are checked against them below
    if isinstance(models, str):
        models = [models]
    if not isinstance(models, list):
        raise ValueError("models must be a string or a list of model strings.")
    for model in models:
        if model not in registry:
            raise ValueError(f"Model {model} not found in registry.")

    # resolve the weights exactly once, *after* the None check, so that
    # model_weights=None really means "uniform" instead of raising TypeError
    if model_weights is None:
        model_weights = [1 / len(models) for _ in models]
    elif len(model_weights) != len(models):
        raise ValueError(
            "model_weights must be None or a list of the same length as models."
        )
    else:
        total_weight = sum(model_weights)
        if total_weight <= 0:
            raise ValueError("model_weights must sum to a positive number.")
        # always normalise: comparing a float sum to 1 exactly is unreliable
        model_weights = [w / total_weight for w in model_weights]

    # constants
    seconds_to_pause_after_rate_limit_error = 5
    # calculated dynamically so the loop itself does not throttle RPM
    seconds_to_sleep_each_loop = (60.0 * 0.9) / max_requests_per_minute

    # initialize trackers
    retry_queue = asyncio.Queue()
    status_tracker = StatusTracker()
    next_request = None  # the request currently waiting for capacity
    # strong references to launched tasks; the event loop only keeps weak
    # ones, so an unreferenced task may be garbage-collected before it ends
    in_flight: set[asyncio.Task] = set()

    # initialize available capacity counts
    # NOTE(review): with use_qps only the *initial* request capacity is one
    # second's worth; the refill below is still capped per minute — confirm
    # whether the ceiling should be per second as well.
    if use_qps:
        available_request_capacity = max_requests_per_minute / 60.0
    else:
        available_request_capacity = max_requests_per_minute
    available_token_capacity = max_tokens_per_minute
    last_update_time = time.time()
    last_pbar_update_time = time.time()

    prompts_not_finished = True
    prompts_iter = iter(zip(ids, prompts))
    results: list[APIRequestBase] = []

    while True:
        # get next request (if one is not already waiting for capacity)
        if next_request is None:
            if not retry_queue.empty():
                next_request = retry_queue.get_nowait()
                print(f"Retrying request {next_request.task_id}.")
            elif prompts_not_finished:
                # keep the try body minimal so only exhaustion of the prompt
                # iterator is treated as "no more prompts"
                try:
                    task_id, prompt = next(prompts_iter)
                except StopIteration:
                    prompts_not_finished = False
                    if verbose:
                        print("API requests finished, only retries remain.")
                else:
                    # select model
                    model_idx = np.random.choice(range(len(models)), p=model_weights)
                    next_request = create_api_request(
                        task_id=task_id,
                        model_name=models[model_idx],
                        prompt=prompt,
                        request_timeout=request_timeout,
                        attempts_left=max_attempts,
                        status_tracker=status_tracker,
                        retry_queue=retry_queue,
                        results_arr=results,
                        sampling_params=sampling_params[model_idx],
                        logprobs=logprobs,
                        top_logprobs=top_logprobs,
                        pbar=progress_bar,
                        all_model_names=models,
                        all_sampling_params=sampling_params,
                        tools=tools,
                    )
                    status_tracker.num_tasks_started += 1
                    status_tracker.num_tasks_in_progress += 1
                    results.append(next_request)

        # update available capacity
        current_time = time.time()
        seconds_since_update = current_time - last_update_time
        available_request_capacity = min(
            available_request_capacity
            + max_requests_per_minute * seconds_since_update / 60.0,
            max_requests_per_minute,
        )
        available_token_capacity = min(
            available_token_capacity
            + max_tokens_per_minute * seconds_since_update / 60.0,
            max_tokens_per_minute,
        )
        last_update_time = current_time

        # if enough capacity available, call API
        limiting_factor = None
        if next_request:
            next_request_tokens = next_request.num_tokens
            request_available = available_request_capacity >= 1
            tokens_available = available_token_capacity >= next_request_tokens
            concurrent_request_available = (
                status_tracker.num_tasks_in_progress < max_concurrent_requests
            )
            if request_available and tokens_available and concurrent_request_available:
                # update counters
                available_request_capacity -= 1
                available_token_capacity -= next_request_tokens
                next_request.attempts_left -= 1

                # call API, holding a reference until the task completes
                task = asyncio.create_task(next_request.call_api())
                in_flight.add(task)
                task.add_done_callback(in_flight.discard)
                next_request = None  # reset next_request to empty
            elif not request_available:
                limiting_factor = "Requests"
            elif not concurrent_request_available:
                limiting_factor = "Concurrent Requests"
            elif not tokens_available:
                limiting_factor = "Tokens"

        # update pbar status (at most about once per second)
        if progress_bar and (current_time - last_pbar_update_time > 1):
            last_pbar_update_time = current_time
            progress_bar.set_postfix(
                {
                    "Token Capacity": f"{available_token_capacity/1_000:.1f}k",
                    "Req. Capacity": f"{available_request_capacity:.1f}",
                    "Reqs. in Progress": status_tracker.num_tasks_in_progress,
                    "Limiting Factor": limiting_factor,
                }
            )

        # if all tasks are finished, break
        # (num_tasks_in_progress is only incremented here; presumably the
        # request objects decrement it when they finish — TODO confirm)
        if status_tracker.num_tasks_in_progress == 0:
            break

        # main loop sleeps briefly so concurrent tasks can run
        await asyncio.sleep(seconds_to_sleep_each_loop)

        # if a rate limit error was hit recently, pause to cool down
        remaining_seconds_to_pause = max(
            0,
            seconds_to_pause_after_rate_limit_error
            - status_tracker.time_since_rate_limit_error,
        )
        if remaining_seconds_to_pause > 0:
            # announce the pause before sleeping, not after it is over
            print(f"Pausing {remaining_seconds_to_pause}s to cool down.")
            await asyncio.sleep(remaining_seconds_to_pause)

    # after finishing, log final status
    status_tracker.log_final_status()
    if verbose:
        print(
            f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
        )

    # deduplicate results by id
    deduplicated = {}
    for request in results:
        latest = request.result[-1]
        if request.task_id not in deduplicated:
            deduplicated[request.task_id] = latest
            continue
        current_response: APIResponse = deduplicated[request.task_id]
        # only replace if the stored response has no completion and the new one does
        if latest.completion is not None and current_response.completion is None:
            deduplicated[request.task_id] = latest

    output = list(deduplicated.values())
    if verbose:
        print(f"Returning {len(output)} unique results.")

    return output


# Method from the added (0.0.13) side of the diff. It takes `self`, but its
# enclosing class header lies outside this chunk, so it is shown here without
# class indentation.
async def wait_for_batch_job(
    self,
    batch_ids: list[str],
    provider: Literal["anthropic", "openai"],
    poll_interval: int = 30,
):
    """Wait for the given batch jobs and return the helper's result.

    Thin wrapper around ``wait_for_batch_completion_async``.

    Args:
        batch_ids: Identifiers of the submitted batches.
        provider: Which provider the batches were submitted to.
        poll_interval: Forwarded as ``poll_interval`` (presumably seconds
            between status checks — confirm against the helper). Defaults to
            30, the value that was previously hard-coded.

    Returns:
        Whatever ``wait_for_batch_completion_async`` returns.
    """
    return await wait_for_batch_completion_async(
        batch_ids, provider, poll_interval=poll_interval
    )
446
+ # def api_prompts_dry_run(
447
+ # ids: np.ndarray | list[int],
448
+ # prompts: list[Conversation],
449
+ # models: str | list[str],
450
+ # model_weights: list[float],
451
+ # sampling_params: list[SamplingParams],
452
+ # max_tokens_per_minute: int = 500_000,
453
+ # max_requests_per_minute: int = 1_000,
454
+ # ):
455
+ # """
456
+ # Count tokens and estimate costs for a batch of prompts.
457
+ # """
458
+ # results = []
459
+ # for i, prompt in zip(ids, prompts):
460
+ # # choose a model
461
+ # model_idx = np.random.choice(range(len(models)), p=model_weights)
462
+ # model = models[model_idx]
463
+
464
+ # # dry run
465
+ # input_tokens, output_tokens, min_cost, max_cost = prompt.dry_run(
466
+ # model, sampling_params[model_idx].max_new_tokens
467
+ # )
468
+ # results.append(
469
+ # {
470
+ # "id": i,
471
+ # "input_tokens": input_tokens,
472
+ # "output_tokens": output_tokens,
473
+ # "min_cost": min_cost,
474
+ # "max_cost": max_cost,
475
+ # }
476
+ # )
477
+
478
+ # combined_results: dict[str, Any] = {
479
+ # "total_input_tokens": sum([r["input_tokens"] for r in results]),
480
+ # "total_output_tokens": sum([r["output_tokens"] for r in results]),
481
+ # "total_min_cost": sum([r["min_cost"] for r in results]),
482
+ # "total_max_cost": sum([r["max_cost"] for r in results]),
483
+ # }
484
+ # minimum_time_tpm = combined_results["total_input_tokens"] / max_tokens_per_minute
485
+ # maximum_time_tpm = (
486
+ # combined_results["total_input_tokens"] + combined_results["total_output_tokens"]
487
+ # ) / max_tokens_per_minute
488
+ # minimum_time_rpm = len(prompts) / max_requests_per_minute
489
+
490
+ # combined_results["minimum_time"] = max(minimum_time_tpm, minimum_time_rpm)
491
+ # combined_results["maximum_time"] = max(maximum_time_tpm, minimum_time_rpm)
492
+ # limiting_factor = None
493
+ # if minimum_time_rpm > maximum_time_tpm:
494
+ # limiting_factor = "requests"
495
+ # elif minimum_time_rpm < minimum_time_tpm:
496
+ # limiting_factor = "tokens"
497
+ # else:
498
+ # limiting_factor = "depends"
499
+ # combined_results["limiting_factor"] = limiting_factor
500
+
501
+ # return combined_results