lm-deluge 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lm-deluge might be problematic.

lm_deluge/client.py CHANGED
@@ -1,42 +1,147 @@
- import os
- import requests
  import asyncio
+ from typing import Any, Literal, Self, Sequence, overload
+
  import numpy as np
- import time
  import yaml
- from dataclasses import dataclass
- from typing import Sequence, overload, Literal, Any
- from tqdm.auto import tqdm
-
- from lm_deluge.prompt import Conversation, CachePattern
+ from pydantic import BaseModel
+ from pydantic.functional_validators import model_validator
+
+ from lm_deluge.api_requests.openai import stream_chat
+ from lm_deluge.batches import (
+     submit_batches_anthropic,
+     submit_batches_oa,
+     wait_for_batch_completion_async,
+ )
+ from lm_deluge.prompt import CachePattern, Conversation, prompts_to_conversations
  from lm_deluge.tool import Tool

- from .tracker import StatusTracker
- from .sampling_params import SamplingParams
- from .models import registry
- from .api_requests.base import APIResponse, APIRequestBase
  from .api_requests import create_api_request
+ from .api_requests.base import APIRequestBase, APIResponse, deduplicate_responses
+ from .config import SamplingParams
+ from .models import registry
+ from .tracker import StatusTracker
+
  # from .cache import LevelDBCache, SqliteCache

+
  # TODO: get completions as they finish, not all at once at the end.
  # relatedly, would be nice to cache them as they finish too.
-
  # TODO: add optional max_input_tokens to client so we can reject long prompts to prevent abuse
+ class LLMClient(BaseModel):
+     """
+     LLMClient abstracts all the fixed arguments to process_prompts_async, so you can create it
+     once and use it for more stuff without having to configure all the arguments.
+     Handles models, sampling params for each model, model weights, rate limits, etc.
+     """

+     model_names: list[str] = ["gpt-4.1-mini"]

- @dataclass
- class ClientConfig:
-     model_names: list[str]
-     max_requests_per_minute: int
-     max_tokens_per_minute: int
-     max_concurrent_requests: int
-     max_attempts: int
-     request_timeout: int
-     sampling_params: SamplingParams | list[SamplingParams]
-     model_weights: list[float] | Literal["uniform", "rate_limit"]
+     def __init__(self, model_name: str | list[str] | None = None, **kwargs):
+         if model_name is not None:
+             kwargs["model_names"] = model_name
+         super().__init__(**kwargs)
+
+     max_requests_per_minute: int = 1_000
+     max_tokens_per_minute: int = 100_000
+     max_concurrent_requests: int = 225
+     sampling_params: list[SamplingParams] = []
+     model_weights: list[float] | Literal["uniform", "dynamic"] = "uniform"
+     max_attempts: int = 5
+     request_timeout: int = 30
+     cache: Any = None
+     # sampling params - if provided, and sampling_params is not,
+     # these override the defaults
+     temperature: float = 0.75
+     top_p: float = 1.0
+     json_mode: bool = False
+     max_new_tokens: int = 512
+     reasoning_effort: Literal["low", "medium", "high", None] = None
      logprobs: bool = False
      top_logprobs: int | None = None
-     cache: Any = None
+
+     # NEW! Builder methods
+     def with_model(self, model: str):
+         self.model_names = [model]
+         return self
+
+     def with_models(self, models: list[str]):
+         self.model_names = models
+         return self
+
+     def with_limits(
+         self,
+         max_requests_per_minute: int | None = None,
+         max_tokens_per_minute: int | None = None,
+         max_concurrent_requests: int | None = None,
+     ):
+         if max_requests_per_minute:
+             self.max_requests_per_minute = max_requests_per_minute
+         if max_tokens_per_minute:
+             self.max_tokens_per_minute = max_tokens_per_minute
+         if max_concurrent_requests:
+             self.max_concurrent_requests = max_concurrent_requests
+
+     @property
+     def models(self):
+         return self.model_names # why? idk
+
+     @model_validator(mode="before")
+     @classmethod
+     def fix_lists(cls, data) -> "LLMClient":
+         if isinstance(data.get("model_names"), str):
+             data["model_names"] = [data["model_names"]]
+         if "sampling_params" not in data or len(data.get("sampling_params", [])) == 0:
+             data["sampling_params"] = [
+                 SamplingParams(
+                     temperature=data.get("temperature", 0.75),
+                     top_p=data.get("top_p", 1.0),
+                     json_mode=data.get("json_mode", False),
+                     max_new_tokens=data.get("max_new_tokens", 512),
+                     reasoning_effort=data.get("reasoning_effort", None),
+                     logprobs=data.get("logprobs", False),
+                     top_logprobs=data.get("top_logprobs", None),
+                 )
+             ]
+         return data
+
+     @model_validator(mode="after")
+     def validate_client(self) -> Self:
+         if isinstance(self.model_names, str):
+             self.model_names = [self.model_names]
+         if any(m not in registry for m in self.model_names):
+             raise ValueError("all model_names must be in registry")
+         if isinstance(self.sampling_params, SamplingParams):
+             self.sampling_params = [self.sampling_params for _ in self.model_names]
+         elif len(self.sampling_params) != len(self.model_names):
+             raise ValueError("# models and # sampling params must match")
+         if self.model_weights == "uniform":
+             self.model_weights = [1 / len(self.model_names) for _ in self.model_names]
+         elif self.model_weights == "dynamic":
+             raise NotImplementedError("dynamic model weights not implemented yet")
+         # normalize weights
+         self.model_weights = [w / sum(self.model_weights) for w in self.model_weights]
+
+         # Validate logprobs settings across all sampling params
+         if self.logprobs or any(sp.logprobs for sp in self.sampling_params):
+             print("Logprobs enabled.")
+             for sp in self.sampling_params:
+                 sp.logprobs = True
+                 # set top_logprobs for each sp if provided and not set
+                 if self.top_logprobs and not sp.top_logprobs:
+                     sp.top_logprobs = self.top_logprobs
+                 if sp.top_logprobs and not (0 <= sp.top_logprobs <= 20):
+                     raise ValueError("top_logprobs must be 0-20")
+                 if sp.top_logprobs and sp.max_new_tokens > 10:
+                     print(
+                         "WARNING: using top_logprobs can result in very large outputs. consider limiting max_new_tokens."
+                     )
+             if not all(
+                 registry[model].get("supports_logprobs") for model in self.models
+             ):
+                 raise ValueError(
+                     "logprobs can only be enabled if all models support it."
+                 )
+         return self

      @classmethod
      def from_dict(cls, config_dict: dict):
@@ -46,7 +151,7 @@ class ClientConfig:
              ]
          else:
              config_dict["sampling_params"] = SamplingParams(
-                 config_dict["sampling_params"]
+                 **config_dict["sampling_params"]
              )

          return cls(**config_dict)
@@ -56,184 +161,18 @@ class ClientConfig:
          config_dict = yaml.safe_load(open(file_path))
          return cls.from_dict(config_dict)

-     def to_dict(self):
-         if isinstance(self.sampling_params, list):
-             sp = [x.__dict__ for x in self.sampling_params]
-         else:
-             sp = self.sampling_params.__dict__
-
-         return {
-             "model_names": self.model_names,
-             "max_requests_per_minute": self.max_requests_per_minute,
-             "max_tokens_per_minute": self.max_tokens_per_minute,
-             "max_concurrent_requests": self.max_concurrent_requests,
-             "max_attempts": self.max_attempts,
-             "request_timeout": self.request_timeout,
-             "sampling_params": sp,
-             "model_weights": self.model_weights,
-             "logprobs": self.logprobs,
-             "top_logprobs": self.top_logprobs,
-         }
-
-
- class LLMClient:
-     """
-     LLMClient abstracts all the fixed arguments to process_prompts_async, so you can create it
-     once and use it for more stuff without having to configure all the arguments.
-     Handles models, sampling params for each model, model weights, rate limits, etc.
-     """
-
-     def __init__(
-         self,
-         model_names: list[str],
-         *,
-         max_requests_per_minute: int,
-         max_tokens_per_minute: int,
-         max_concurrent_requests: int,
-         sampling_params: SamplingParams | list[SamplingParams] = SamplingParams(),
-         model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
-         max_attempts: int = 5,
-         request_timeout: int = 30,
-         logprobs: bool = False,
-         top_logprobs: int | None = None,
-         use_qps: bool = False,
-         debug: bool = False,
-         cache: Any = None,
-     ):
-         self.models = model_names
-         if isinstance(sampling_params, SamplingParams):
-             self.sampling_params = [sampling_params for _ in model_names]
-         else:
-             if len(sampling_params) != len(model_names):
-                 raise ValueError(
-                     "If sampling_params is a list, it must have the same length as model_names."
-                 )
-             self.sampling_params = sampling_params
-         if model_weights == "uniform":
-             self.model_weights = [1 / len(model_names) for _ in model_names]
-         elif model_weights == "rate_limit":
-             rpms = [registry[model]["requests_per_minute"] for model in model_names]
-             self.model_weights = [rpm / sum(rpms) for rpm in rpms]
-         elif sum(model_weights) != 1:
-             self.model_weights = [w / sum(model_weights) for w in model_weights]
-         else:
-             self.model_weights = model_weights
-
-         self.logprobs = logprobs
-         self.top_logprobs = top_logprobs
-
-         # logprobs and top_logprobs are only supported for OpenAI models
-         if self.logprobs:
-             for model in self.models:
-                 if registry[model].get("supports_logprobs", False) is False:
-                     raise ValueError(
-                         "logprobs can only be enabled if all models support it."
-                     )
-             if self.top_logprobs is None:
-                 self.top_logprobs = 0 # will just return logprob of the chosen token
-             elif self.top_logprobs > 20 or self.top_logprobs < 0:
-                 raise ValueError("top_logprobs must be between 0 and 20.")
-             for sp in self.sampling_params:
-                 if sp.max_new_tokens > 10:
-                     print(
-                         "WARNING: using logprobs with large max_new_tokens can result in very large outputs. you may want to avoid saving these outputs to disk/db."
-                     )
-                     break
-         else:
-             self.top_logprobs = None
-
-         self.max_requests_per_minute = max_requests_per_minute
-         self.max_tokens_per_minute = max_tokens_per_minute
-         self.max_concurrent_requests = max_concurrent_requests
-         self.max_attempts = max_attempts
-         self.request_timeout = request_timeout
-         self.use_qps = use_qps
-         self.debug = (
-             debug # UNUSED/DEPRECATED i think? but dont want to break everything
-         )
-         self.cache = cache
-
      @classmethod
-     def from_config(cls, config: ClientConfig, cache: Any = None):
-         return cls(
-             model_names=config.model_names,
-             max_requests_per_minute=config.max_requests_per_minute,
-             max_tokens_per_minute=config.max_tokens_per_minute,
-             max_concurrent_requests=config.max_concurrent_requests,
-             sampling_params=config.sampling_params,
-             model_weights=config.model_weights,
-             max_attempts=config.max_attempts,
-             request_timeout=config.request_timeout,
-             cache=cache,
-         )
-
-     @classmethod
-     def from_yaml(cls, file_path: str, cache: Any = None):
-         return cls.from_config(ClientConfig.from_yaml(file_path), cache=cache)
-
-     @classmethod
-     def basic(
-         cls,
-         model: str | list[str],
-         max_requests_per_minute: int = 5_000,
-         max_tokens_per_minute: int = 1_000_000,
-         max_concurrent_requests: int = 1_000,
-         temperature: float = 0.75,
-         max_new_tokens: int = 1000,
-         reasoning_effort: Literal[None, "low", "medium", "high"] = None,
-         model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
-         logprobs: bool = False,
-         top_logprobs: int | None = None,
-         max_attempts: int = 5,
-         request_timeout: int = 30,
-         cache: Any = None,
-     ):
-         model_names = model if isinstance(model, list) else [model]
-         return cls(
-             model_names=model_names,
-             max_requests_per_minute=max_requests_per_minute,
-             max_tokens_per_minute=max_tokens_per_minute,
-             max_concurrent_requests=max_concurrent_requests,
-             sampling_params=SamplingParams(
-                 temperature=temperature,
-                 max_new_tokens=max_new_tokens,
-                 reasoning_effort=reasoning_effort,
-             ),
-             logprobs=logprobs,
-             top_logprobs=top_logprobs,
-             model_weights=model_weights,
-             max_attempts=max_attempts,
-             request_timeout=request_timeout,
-             cache=cache,
-         )
-
-     @property
-     def config(self):
-         return ClientConfig(
-             model_names=self.models,
-             model_weights=self.model_weights,
-             max_requests_per_minute=self.max_requests_per_minute,
-             max_tokens_per_minute=self.max_tokens_per_minute,
-             max_concurrent_requests=self.max_concurrent_requests,
-             max_attempts=self.max_attempts,
-             request_timeout=self.request_timeout,
-             sampling_params=self.sampling_params,
-             logprobs=self.logprobs,
-             top_logprobs=self.top_logprobs,
-         )
-
-     @overload
-     async def process_prompts_async(
-         self,
-         prompts: Sequence[str | list[dict] | Conversation],
-         *,
-         return_completions_only: bool,
-         show_progress: bool = ...,
-         dry_run: Literal[True],
-         verbose: bool = ...,
-         tools: list[Tool] | None = ...,
-         cache: CachePattern | None = ...,
-     ) -> dict[str, int]: ...
+     def basic(cls, model: str | list[str], **kwargs):
+         """
+         Doesn't do anything differently now, kept for backwards compat.
+         """
+         kwargs["model_names"] = model
+         return cls(**kwargs)
+
+     def _select_model(self):
+         assert isinstance(self.model_weights, list)
+         model_idx = np.random.choice(range(len(self.models)), p=self.model_weights)
+         return self.models[model_idx], self.sampling_params[model_idx]

      @overload
      async def process_prompts_async(
@@ -242,10 +181,12 @@ class LLMClient:
          *,
          return_completions_only: Literal[True],
          show_progress: bool = ...,
-         dry_run: bool = ...,
-         verbose: bool = ...,
          tools: list[Tool] | None = ...,
          cache: CachePattern | None = ...,
+         computer_use: bool = ...,
+         display_width: int = ...,
+         display_height: int = ...,
+         use_responses_api: bool = ...,
      ) -> list[str | None]: ...

      @overload
@@ -255,10 +196,12 @@ class LLMClient:
          *,
          return_completions_only: Literal[False] = ...,
          show_progress: bool = ...,
-         dry_run: bool = ...,
-         verbose: bool = ...,
          tools: list[Tool] | None = ...,
          cache: CachePattern | None = ...,
+         computer_use: bool = ...,
+         display_width: int = ...,
+         display_height: int = ...,
+         use_responses_api: bool = ...,
      ) -> list[APIResponse | None]: ...

      async def process_prompts_async(
@@ -267,23 +210,15 @@ class LLMClient:
          *,
          return_completions_only: bool = False,
          show_progress: bool = True,
-         dry_run: bool = False,
-         verbose: bool = False,
          tools: list[Tool] | None = None,
          cache: CachePattern | None = None,
+         computer_use: bool = False,
+         display_width: int = 1024,
+         display_height: int = 768,
+         use_responses_api: bool = False,
      ) -> list[APIResponse | None] | list[str | None] | dict[str, int]:
          # if prompts are not Conversations, convert them.
-         # can only handle strings for now
-         prompts = [ # type: ignore
-             p
-             if isinstance(p, Conversation)
-             else Conversation.user(p)
-             if isinstance(p, str)
-             else None
-             for p in prompts
-         ]
-         if any(p is None for p in prompts):
-             raise ValueError("All prompts must be valid.")
+         prompts = prompts_to_conversations(prompts)
          ids = np.arange(len(prompts))

          # if using cache, check for cached completions
@@ -298,10 +233,9 @@ class LLMClient:
              ), "Cache hit ids and results must be the same length."
              remaining_ids = np.array([i for i in ids if i not in cache_hit_ids])
              remaining_prompts = [prompts[i] for i in remaining_ids]
-             if verbose:
-                 print(f"{len(cache_hit_ids)} cache hits from previous completions.")
-                 print(f"{len(remaining_ids)} prompts remaining after cache hits.")
-                 print(f"Processing {len(remaining_prompts)} prompts.")
+             print(
+                 f"{len(cache_hit_ids)} cache hits; {len(remaining_ids)} prompts remaining."
+             )

          else:
              cache_hit_ids = []
@@ -311,48 +245,102 @@ class LLMClient:

          results: list[APIResponse | None] = [None for _ in range(len(prompts))]
          if len(remaining_prompts) > 0:
-             # set up progress bar
-             pbar = tqdm(total=len(prompts), disable=(not show_progress))
-
-             # update progress bar with cache hits
-             pbar.update(len(cache_hit_ids))
-             api_task = None
-             if dry_run:
-                 dry_run_results = api_prompts_dry_run(
-                     ids,
-                     prompts, # type: ignore -- fix later for dry running conversations
-                     self.models,
-                     self.model_weights,
-                     self.sampling_params,
-                     max_tokens_per_minute=self.max_tokens_per_minute,
-                     max_requests_per_minute=self.max_requests_per_minute,
-                 )
-                 print("Dry run results:")
-                 print(dry_run_results)
-                 return dry_run_results
-
-             api_task = asyncio.create_task(
-                 process_api_prompts_async(
-                     ids,
-                     prompts, # type: ignore -- fix later for dry running conversations
-                     self.models,
-                     self.model_weights,
-                     self.sampling_params,
-                     logprobs=self.logprobs,
-                     top_logprobs=self.top_logprobs,
-                     max_attempts=self.max_attempts,
-                     max_tokens_per_minute=self.max_tokens_per_minute,
-                     max_requests_per_minute=self.max_requests_per_minute,
-                     max_concurrent_requests=self.max_concurrent_requests,
-                     request_timeout=self.request_timeout,
-                     progress_bar=pbar,
-                     use_qps=self.use_qps,
-                     verbose=verbose,
-                     tools=tools,
-                     cache=cache,
-                 )
+             # Create StatusTracker with integrated progress bar
+             tracker = StatusTracker(
+                 max_requests_per_minute=self.max_requests_per_minute,
+                 max_tokens_per_minute=self.max_tokens_per_minute,
+                 max_concurrent_requests=self.max_concurrent_requests,
+                 use_progress_bar=show_progress,
+                 progress_bar_total=len(prompts),
+                 progress_bar_disable=not show_progress,
+                 use_rich=show_progress, # Disable Rich if progress is disabled
              )
-             api_results: list[APIResponse] = await api_task
+
+             # Initialize progress bar and update with cache hits
+             tracker.init_progress_bar()
+             if len(cache_hit_ids) > 0:
+                 tracker.update_pbar(len(cache_hit_ids))
+
+             if isinstance(ids, np.ndarray):
+                 ids = ids.tolist() # pyright: ignore
+
+             # calculate dynamically so we don't throttle RPM
+             seconds_to_sleep_each_loop = (60.0 * 0.9) / tracker.max_requests_per_minute
+             next_request = None # variable to hold the next request to call
+             prompts_not_finished = True
+             prompts_iter = iter(zip(ids, prompts))
+             requests: list[APIRequestBase] = []
+             assert tracker.retry_queue, "retry queue not initialized"
+             while True:
+                 # get next request (if one is not already waiting for capacity)
+                 retry_request = False
+                 if next_request is None:
+                     if not tracker.retry_queue.empty():
+                         next_request = tracker.retry_queue.get_nowait()
+                         retry_request = True
+                         print(f"Retrying request {next_request.task_id}.")
+                     elif prompts_not_finished:
+                         try:
+                             # get new request
+                             id, prompt = next(prompts_iter)
+                             # select model
+                             model, sampling_params = self._select_model()
+
+                             next_request = create_api_request(
+                                 task_id=id,
+                                 model_name=model,
+                                 prompt=prompt, # type: ignore
+                                 request_timeout=self.request_timeout,
+                                 attempts_left=self.max_attempts,
+                                 status_tracker=tracker,
+                                 results_arr=requests,
+                                 sampling_params=sampling_params,
+                                 all_model_names=self.models,
+                                 all_sampling_params=self.sampling_params,
+                                 tools=tools,
+                                 cache=cache,
+                                 computer_use=computer_use,
+                                 display_width=display_width,
+                                 display_height=display_height,
+                                 use_responses_api=use_responses_api,
+                             )
+                             requests.append(next_request)
+
+                         except StopIteration:
+                             prompts_not_finished = False
+                             # print("API requests finished, only retries remain.")
+
+                 # update available capacity
+                 tracker.update_capacity()
+
+                 # if enough capacity available, call API
+                 if next_request:
+                     next_request_tokens = next_request.num_tokens
+                     if tracker.check_capacity(next_request_tokens, retry=retry_request):
+                         tracker.set_limiting_factor(None)
+                         # call API (attempts_left will be decremented in handle_error if it fails)
+                         asyncio.create_task(next_request.call_api())
+                         next_request = None # reset next_request to empty
+                 # update pbar status
+                 tracker.update_pbar()
+
+                 # if all tasks are finished, break
+                 if tracker.num_tasks_in_progress == 0:
+                     break
+
+                 # main loop sleeps briefly so concurrent tasks can run
+                 await asyncio.sleep(seconds_to_sleep_each_loop)
+
+                 # if a rate limit error was hit recently, pause to cool down
+                 if tracker.seconds_to_pause > 0:
+                     await asyncio.sleep(tracker.seconds_to_pause)
+                     print(f"Pausing {tracker.seconds_to_pause}s to cool down.")
+
+             # after finishing, log final status
+             tracker.log_final_status()
+
+             # deduplicate results by id
+             api_results = deduplicate_responses(requests)
              for res in api_results:
                  results[res.id] = res
                  # set to cache if result has a completion
@@ -375,8 +363,6 @@ class LLMClient:
          *,
          return_completions_only: bool = False,
          show_progress=True,
-         dry_run: bool = False,
-         verbose: bool = False,
          tools: list[Tool] | None = None,
          cache: CachePattern | None = None,
      ):
385
371
  prompts=prompts,
386
372
  return_completions_only=return_completions_only,
387
373
  show_progress=show_progress,
388
- dry_run=dry_run,
389
- verbose=verbose,
390
374
  tools=tools,
391
375
  cache=cache,
392
376
  )
393
377
  )
394
378
 
395
- def _submit_one_batch(self, batch_requests: list):
396
- # save the file
397
- import pandas as pd
398
-
399
- pd.DataFrame(batch_requests).to_json(
400
- "openai_requests_temp.jsonl", orient="records", lines=True
401
- )
402
-
403
- # upload the file
404
- api_key = os.environ.get("OPENAI_API_KEY", None)
405
- if api_key is None:
406
- raise ValueError("OPENAI_API_KEY environment variable must be set.")
407
- url = "https://api.openai.com/v1/files"
408
- files = {
409
- "file": (
410
- "openai_requests_temp.jsonl",
411
- open("openai_requests_temp.jsonl", "rb"),
412
- ),
413
- }
414
- data = {
415
- "purpose": "batch",
416
- }
417
- headers = {
418
- "Authorization": f"Bearer {api_key}",
419
- }
420
- response = requests.post(url, files=files, data=data, headers=headers)
421
-
422
- file_id = None
423
- if response.status_code == 200:
424
- print("File uploaded successfully")
425
- data = response.json()
426
- file_id = data["id"]
427
-
428
- else:
429
- print("File upload failed")
430
- raise ValueError(f"Error uploading file: {response.text}")
431
-
432
- url = "https://api.openai.com/v1/batches"
433
- data = {
434
- "input_file_id": file_id,
435
- "endpoint": "/v1/chat/completions",
436
- "completion_window": "24h",
437
- }
438
- response = requests.post(url, json=data, headers=headers)
439
-
440
- batch_id = None
441
- if response.status_code == 200:
442
- data = response.json()
443
- batch_id = data["id"]
444
- print("Batch job started successfully: id = ", batch_id)
445
- return batch_id
446
- else:
447
- print("Batch job failed to start")
448
- raise ValueError(f"Error starting batch job: {response.text}")
379
+ async def stream(self, prompt: str | Conversation, tools: list[Tool] | None = None):
380
+ model, sampling_params = self._select_model()
381
+ if isinstance(prompt, str):
382
+ prompt = Conversation.user(prompt)
383
+ async for item in stream_chat(model, prompt, sampling_params, tools, None):
384
+ if isinstance(item, str):
385
+ print(item, end="", flush=True)
386
+ else:
387
+ # final item
388
+ return item
449
389
 
450
- def submit_batch_job(self, prompts: Sequence[str | list[dict] | Conversation]):
451
- # make sure 1) only 1 model is used, 2) it's an openai model, 3) it supports json mode
390
+ async def submit_batch_job(
391
+ self,
392
+ prompts: Sequence[str | list[dict] | Conversation],
393
+ *,
394
+ tools: list[Tool] | None = None,
395
+ cache: CachePattern | None = None,
396
+ ):
397
+ """Submit a batch job asynchronously, automatically detecting the provider based on model.
398
+
399
+ Args:
400
+ prompts: List of prompts to process
401
+ wait_for_completion: If True, poll until completion and return results
402
+ poll_interval: Seconds to wait between status checks when polling
403
+ tools: Optional tools to include in requests (Anthropic only)
404
+ cache: Optional cache pattern for requests (Anthropic only)
405
+
406
+ Returns: list of batch_ids
407
+ """
408
+ assert isinstance(self.sampling_params, list)
452
409
  if len(self.models) != 1:
453
410
  raise ValueError("Batch jobs can only be submitted with a single model.")
454
411
  model = self.models[0]
455
- if registry[model].get("api_spec", None) != "openai":
456
- raise ValueError("Batch jobs can only be submitted with OpenAI models.")
457
-
458
- # if prompts are strings, convert them to message lists
459
- prompts = [ # type: ignore
460
- p
461
- if isinstance(p, Conversation)
462
- else Conversation.user(p)
463
- if isinstance(p, str)
464
- else None
465
- for p in prompts
466
- ]
467
- if any(p is None for p in prompts):
468
- raise ValueError("All prompts must be valid.")
469
- ids = np.arange(len(prompts))
470
-
471
- # create file with requests to send to batch api
472
- batch_requests = []
473
- for id, prompt in zip(ids, prompts):
474
- assert isinstance(prompt, Conversation)
475
- batch_requests.append(
476
- {
477
- "custom_id": str(id),
478
- "method": "POST",
479
- "url": "/v1/chat/completions",
480
- "body": {
481
- "model": self.models[0],
482
- "messages": prompt.to_openai(),
483
- "max_tokens": self.sampling_params[0].max_new_tokens,
484
- "temperature": self.sampling_params[0].temperature,
485
- "top_p": self.sampling_params[0].top_p,
486
- },
487
- }
488
- )
489
-
490
- # since the api only accepts up to 50,000 requests per batch job, we chunk into 50k chunks
491
- BATCH_SIZE = 50_000
492
- batches = [
493
- batch_requests[i : i + BATCH_SIZE]
494
- for i in range(0, len(batch_requests), BATCH_SIZE)
495
- ]
496
- batch_ids = []
497
- for batch in tqdm(batches):
498
- batch_id = self._submit_one_batch(batch)
499
- batch_ids.append(batch_id)
500
-
501
- print(f"Submitted {len(batches)} batch jobs.")
502
- return batch_ids
503
-
504
-
505
- def api_prompts_dry_run(
506
- ids: np.ndarray | list[int],
507
- prompts: list[Conversation],
508
- models: str | list[str],
509
- model_weights: list[float],
510
- sampling_params: list[SamplingParams],
511
- max_tokens_per_minute: int = 500_000,
512
- max_requests_per_minute: int = 1_000,
513
- ):
514
- """
515
- Count tokens and estimate costs for a batch of prompts.
516
- """
517
- results = []
518
- for i, prompt in zip(ids, prompts):
519
- # choose a model
520
- model_idx = np.random.choice(range(len(models)), p=model_weights)
521
- model = models[model_idx]
522
-
523
- # dry run
524
- input_tokens, output_tokens, min_cost, max_cost = prompt.dry_run(
525
- model, sampling_params[model_idx].max_new_tokens
526
- )
527
- results.append(
528
- {
529
- "id": i,
530
- "input_tokens": input_tokens,
531
- "output_tokens": output_tokens,
532
- "min_cost": min_cost,
533
- "max_cost": max_cost,
534
- }
535
- )
536
-
537
- combined_results: dict[str, Any] = {
538
- "total_input_tokens": sum([r["input_tokens"] for r in results]),
539
- "total_output_tokens": sum([r["output_tokens"] for r in results]),
540
- "total_min_cost": sum([r["min_cost"] for r in results]),
541
- "total_max_cost": sum([r["max_cost"] for r in results]),
542
- }
543
- minimum_time_tpm = combined_results["total_input_tokens"] / max_tokens_per_minute
544
- maximum_time_tpm = (
545
- combined_results["total_input_tokens"] + combined_results["total_output_tokens"]
546
- ) / max_tokens_per_minute
547
- minimum_time_rpm = len(prompts) / max_requests_per_minute
548
-
549
- combined_results["minimum_time"] = max(minimum_time_tpm, minimum_time_rpm)
550
- combined_results["maximum_time"] = max(maximum_time_tpm, minimum_time_rpm)
551
- limiting_factor = None
552
- if minimum_time_rpm > maximum_time_tpm:
553
- limiting_factor = "requests"
554
- elif minimum_time_rpm < minimum_time_tpm:
555
- limiting_factor = "tokens"
556
- else:
557
- limiting_factor = "depends"
558
- combined_results["limiting_factor"] = limiting_factor
559
-
560
- return combined_results
561
-
562
-
563
- async def process_api_prompts_async(
564
- ids: np.ndarray | list[int],
565
- prompts: list[Conversation],
566
- models: str | list[str],
567
- model_weights: list[float],
568
- sampling_params: list[SamplingParams],
569
- logprobs: bool,
570
- top_logprobs: int | None,
571
- max_attempts: int = 5,
572
- max_tokens_per_minute: int = 500_000,
573
- max_requests_per_minute: int = 1_000,
574
- max_concurrent_requests: int = 1_000,
575
- request_timeout: int = 30,
576
- progress_bar: tqdm | None = None,
577
- use_qps: bool = False,
578
- verbose: bool = False,
579
- tools: list[Tool] | None = None,
580
- cache: CachePattern | None = None,
581
- ):
582
- """Processes API requests in parallel, throttling to stay under rate limits."""
583
- # change ids to integer list
584
- if isinstance(ids, np.ndarray):
585
- ids = ids.tolist() # pyright: ignore
586
-
587
- # normalize weights
588
- model_weights = [w / sum(model_weights) for w in model_weights]
589
-
590
- # constants
591
- seconds_to_pause_after_rate_limit_error = 5
592
- # seconds_to_sleep_each_loop = 0.003 # so concurrent tasks can run
593
- # calculate dynamically so we don't throttle RPM
594
- seconds_to_sleep_each_loop = (60.0 * 0.9) / max_requests_per_minute
595
-
596
- # initialize trackers
597
- retry_queue = asyncio.Queue()
598
- status_tracker = StatusTracker()
599
- next_request = None # variable to hold the next request to call
600
-
601
- # initialize available capacity counts
602
- # throttle over a 1 second window rather than minute,
603
- # since some models limit RPS rather than RPM
604
- if use_qps:
605
- available_request_capacity = max_requests_per_minute / 60.0
606
- available_token_capacity = max_tokens_per_minute
607
- else:
608
- available_request_capacity = max_requests_per_minute
609
- available_token_capacity = max_tokens_per_minute
610
- last_update_time = time.time()
611
- last_pbar_update_time = time.time()
612
-
613
- # initialize flags
614
- prompts_not_finished = True
615
-
616
- # initials model weights
617
- if isinstance(models, str):
618
- models = [models]
619
- if not isinstance(models, list):
620
- raise ValueError("models must be a string or a list of model strings.")
621
- for model in models:
622
- if model not in registry:
623
- raise ValueError(f"Model {model} not found in registry.")
624
-
625
- if model_weights is None:
626
- # if not given, spread requests evenly across models
627
- model_weights = [1 / len(models) for _ in models]
628
- elif len(model_weights) != len(models):
629
- raise ValueError(
630
- "model_weights must be None or a list of the same length as models."
631
- )
632
- elif sum(model_weights) != 1:
633
- model_weights = [w / sum(model_weights) for w in model_weights]
634
-
635
- prompts_iter = iter(zip(ids, prompts))
636
- results: list[APIRequestBase] = []
637
- while True:
638
- # get next request (if one is not already waiting for capacity)
639
- if next_request is None:
640
- if not retry_queue.empty():
641
- next_request = retry_queue.get_nowait()
642
- print(f"Retrying request {next_request.task_id}.")
643
- elif prompts_not_finished:
644
- try:
645
- # get new request
646
- id, prompt = next(prompts_iter)
647
- # select model
648
- model_idx = np.random.choice(range(len(models)), p=model_weights)
649
- next_request = create_api_request(
650
- task_id=id,
651
- model_name=models[model_idx],
652
- prompt=prompt,
653
- request_timeout=request_timeout,
654
- attempts_left=max_attempts,
655
- status_tracker=status_tracker,
656
- retry_queue=retry_queue,
657
- results_arr=results,
658
- sampling_params=sampling_params[model_idx],
659
- logprobs=logprobs,
660
- top_logprobs=top_logprobs,
661
- pbar=progress_bar,
662
- all_model_names=models,
663
- all_sampling_params=sampling_params,
664
- tools=tools,
665
- cache=cache,
666
- )
667
- status_tracker.num_tasks_started += 1
668
- status_tracker.num_tasks_in_progress += 1
669
- results.append(next_request)
670
-
671
- except StopIteration:
672
- prompts_not_finished = False
673
- if verbose:
674
- print("API requests finished, only retries remain.")
675
-
676
- # update available capacity
677
- current_time = time.time()
678
- seconds_since_update = current_time - last_update_time
679
- available_request_capacity = min(
680
- available_request_capacity
681
- + max_requests_per_minute * seconds_since_update / 60.0,
682
- max_requests_per_minute,
683
- )
684
- available_token_capacity = min(
685
- available_token_capacity
686
- + max_tokens_per_minute * seconds_since_update / 60.0,
687
- max_tokens_per_minute,
688
- )
689
- last_update_time = current_time
690
-
691
- # if enough capacity available, call API
692
- limiting_factor = None
693
- if next_request:
694
- next_request_tokens = next_request.num_tokens
695
- request_available = available_request_capacity >= 1
696
- tokens_available = available_token_capacity >= next_request_tokens
697
- concurrent_request_available = (
698
- status_tracker.num_tasks_in_progress < max_concurrent_requests
699
- )
700
- if request_available and tokens_available and concurrent_request_available:
701
- # update counters
702
- available_request_capacity -= 1
703
- available_token_capacity -= next_request_tokens
704
- next_request.attempts_left -= 1
705
-
706
- # call API
707
- asyncio.create_task(next_request.call_api())
708
- next_request = None # reset next_request to empty
709
- else:
710
- if not request_available:
711
- limiting_factor = "Requests"
712
- elif not concurrent_request_available:
713
- limiting_factor = "Concurrent Requests"
714
- elif not tokens_available:
715
- limiting_factor = "Tokens"
716
-
717
- # update pbar status
718
- if progress_bar and (current_time - last_pbar_update_time > 1):
719
- last_pbar_update_time = current_time
720
- progress_bar.set_postfix(
721
- {
722
- "Token Capacity": f"{available_token_capacity/1_000:.1f}k",
723
- "Req. Capacity": f"{available_request_capacity:.1f}",
724
- "Reqs. in Progress": status_tracker.num_tasks_in_progress,
725
- "Limiting Factor": limiting_factor,
726
- }
412
+ api_spec = registry[model].get("api_spec", None)
413
+
414
+ if api_spec == "openai":
415
+ return await submit_batches_oa(model, self.sampling_params[0], prompts)
416
+ elif api_spec == "anthropic":
417
+ return await submit_batches_anthropic(
418
+ model,
419
+ self.sampling_params[0],
420
+ prompts,
421
+ cache=cache,
727
422
  )
423
+ else:
424
+ raise ValueError(f"Batch processing not supported for API spec: {api_spec}")
728
425
 
729
- # if all tasks are finished, break
730
- if status_tracker.num_tasks_in_progress == 0:
731
- break
732
-
733
- # main loop sleeps briefly so concurrent tasks can run
734
- await asyncio.sleep(seconds_to_sleep_each_loop)
735
-
736
- # if a rate limit error was hit recently, pause to cool down
737
- remaining_seconds_to_pause = max(
738
- 0,
739
- seconds_to_pause_after_rate_limit_error
740
- - status_tracker.time_since_rate_limit_error,
741
- )
742
- if remaining_seconds_to_pause > 0:
743
- await asyncio.sleep(remaining_seconds_to_pause)
744
- print(f"Pausing {remaining_seconds_to_pause}s to cool down.")
745
-
746
- # after finishing, log final status
747
- status_tracker.log_final_status()
748
- if verbose:
749
- print(
750
- f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
426
+ async def wait_for_batch_job(
427
+ self, batch_ids: list[str], provider: Literal["anthropic", "openai"]
428
+ ):
429
+ return await wait_for_batch_completion_async(
430
+ batch_ids, provider, poll_interval=30
751
431
  )
752
432
 
753
- # deduplicate results by id
754
- deduplicated = {}
755
- for request in results:
756
- if request.task_id not in deduplicated:
757
- deduplicated[request.task_id] = request.result[-1]
758
- else:
759
- current_response: APIResponse = deduplicated[request.task_id]
760
- # only replace if the current request has no completion and the new one does
761
- if (
762
- request.result[-1].completion is not None
763
- and current_response.completion is None
764
- ):
765
- deduplicated[request.task_id] = request.result[-1]
766
-
767
- output = list(deduplicated.values())
768
- if verbose:
769
- print(f"Returning {len(output)} unique results.")
770
433
 
771
- return output
434
+ # def api_prompts_dry_run(
435
+ # ids: np.ndarray | list[int],
436
+ # prompts: list[Conversation],
437
+ # models: str | list[str],
438
+ # model_weights: list[float],
439
+ # sampling_params: list[SamplingParams],
440
+ # max_tokens_per_minute: int = 500_000,
441
+ # max_requests_per_minute: int = 1_000,
442
+ # ):
443
+ # """
444
+ # Count tokens and estimate costs for a batch of prompts.
445
+ # """
446
+ # results = []
447
+ # for i, prompt in zip(ids, prompts):
448
+ # # choose a model
449
+ # model_idx = np.random.choice(range(len(models)), p=model_weights)
450
+ # model = models[model_idx]
451
+
452
+ # # dry run
453
+ # input_tokens, output_tokens, min_cost, max_cost = prompt.dry_run(
454
+ # model, sampling_params[model_idx].max_new_tokens
455
+ # )
456
+ # results.append(
457
+ # {
458
+ # "id": i,
459
+ # "input_tokens": input_tokens,
460
+ # "output_tokens": output_tokens,
461
+ # "min_cost": min_cost,
462
+ # "max_cost": max_cost,
463
+ # }
464
+ # )
465
+
466
+ # combined_results: dict[str, Any] = {
467
+ # "total_input_tokens": sum([r["input_tokens"] for r in results]),
468
+ # "total_output_tokens": sum([r["output_tokens"] for r in results]),
469
+ # "total_min_cost": sum([r["min_cost"] for r in results]),
470
+ # "total_max_cost": sum([r["max_cost"] for r in results]),
471
+ # }
472
+ # minimum_time_tpm = combined_results["total_input_tokens"] / max_tokens_per_minute
473
+ # maximum_time_tpm = (
474
+ # combined_results["total_input_tokens"] + combined_results["total_output_tokens"]
475
+ # ) / max_tokens_per_minute
476
+ # minimum_time_rpm = len(prompts) / max_requests_per_minute
477
+
478
+ # combined_results["minimum_time"] = max(minimum_time_tpm, minimum_time_rpm)
479
+ # combined_results["maximum_time"] = max(maximum_time_tpm, minimum_time_rpm)
480
+ # limiting_factor = None
481
+ # if minimum_time_rpm > maximum_time_tpm:
482
+ # limiting_factor = "requests"
483
+ # elif minimum_time_rpm < minimum_time_tpm:
484
+ # limiting_factor = "tokens"
485
+ # else:
486
+ # limiting_factor = "depends"
487
+ # combined_results["limiting_factor"] = limiting_factor
488
+
489
+ # return combined_results
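
For orientation, here is a minimal usage sketch of the new pydantic-based client introduced in 0.0.14, based only on the code visible in this diff. The model name, limits, and prompts are illustrative placeholders, the import goes through lm_deluge.client because a package-root re-export is not shown here, and running it would require the relevant provider API key in the environment.

import asyncio

from lm_deluge.client import LLMClient

async def main():
    # Fields like temperature and max_new_tokens are folded into a SamplingParams
    # entry by the "before" model validator shown in the diff.
    client = LLMClient(model_names=["gpt-4.1-mini"], temperature=0.5, max_new_tokens=256)

    # Builder-style tweak; with_limits mutates in place and (per the diff) does not
    # return self, so it is called on its own line rather than chained.
    client.with_limits(max_requests_per_minute=500)

    # String prompts are converted to Conversations via prompts_to_conversations.
    completions = await client.process_prompts_async(
        ["What is 2 + 2?", "Name a prime number greater than 10."],
        return_completions_only=True,
    )
    print(completions)

asyncio.run(main())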