lm-deluge 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lm-deluge might be problematic.

lm_deluge/batches.py ADDED
@@ -0,0 +1,498 @@
+ import os
+ import json
+ import time
+ import asyncio
+ import aiohttp
+ import pandas as pd
+ from lm_deluge.prompt import CachePattern, Conversation, prompts_to_conversations
+ from lm_deluge.config import SamplingParams
+ from lm_deluge.models import APIModel, registry
+ from typing import Sequence, Literal
+ from lm_deluge.api_requests.openai import _build_oa_chat_request
+ from lm_deluge.api_requests.anthropic import _build_anthropic_request
+ from rich.console import Console
+ from rich.live import Live
+ from rich.spinner import Spinner
+ from rich.table import Table
+ from rich.text import Text
+
+
+ def _create_batch_status_display(
+     batch_id: str,
+     status: str,
+     elapsed: float,
+     counts: dict | None,
+     provider: str,
+ ):
+     """Create a unified status display for batch jobs."""
+     # Format elapsed time
+     hours = int(elapsed // 3600)
+     minutes = int((elapsed % 3600) // 60)
+     seconds = int(elapsed % 60)
+
+     if hours > 0:
+         elapsed_str = f"{hours}h {minutes}m {seconds}s"
+     elif minutes > 0:
+         elapsed_str = f"{minutes}m {seconds}s"
+     else:
+         elapsed_str = f"{seconds}s"
+
+     # Build progress text based on provider
+     progress_text = ""
+     if counts:
+         if provider == "openai":
+             total = counts.get("total", 0)
+             completed = counts.get("completed", 0)
+             failed = counts.get("failed", 0)
+             total_display = "?" if total == 0 else str(total)
+             progress_text = f" • {completed}/{total_display} done"
+             if failed > 0:
+                 progress_text += f", {failed} failed"
+         elif provider == "anthropic":
+             total = (
+                 counts.get("processing", 0)
+                 + counts.get("succeeded", 0)
+                 + counts.get("errored", 0)
+             )
+             succeeded = counts.get("succeeded", 0)
+             errored = counts.get("errored", 0)
+             total_display = "?" if total == 0 else str(total)
+             progress_text = f" • {succeeded}/{total_display} done"
+             if errored > 0:
+                 progress_text += f", {errored} errors"
+
+     # Choose spinner color based on provider
+     spinner_style = "green" if provider == "openai" else "blue"
+     spinner = Spinner("dots", style=spinner_style, text="")
+
+     grid = Table.grid()
+     grid.add_column()
+     grid.add_column()
+     grid.add_row(
+         spinner,
+         Text(
+             f" Batch {batch_id} • {status} • {elapsed_str}{progress_text}",
+             style="white",
+         ),
+     )
+     return grid
+
+
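+ # A rough sketch of what the helper renders (values are illustrative):
+ #     _create_batch_status_display(
+ #         "batch_abc123", "in_progress", 95.0,
+ #         {"total": 100, "completed": 40, "failed": 1}, "openai",
+ #     )
+ # yields a single-line Rich grid along the lines of:
+ #     ⠋  Batch batch_abc123 • in_progress • 1m 35s • 40/100 done, 1 failed
+
+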
+ async def submit_batch_oa(batch_requests: list[dict]):
+     """Submit one batch asynchronously."""
+     # Serialize the requests to JSONL in memory, so that concurrent
+     # submissions don't clobber a shared on-disk temp file
+     jsonl_payload = pd.DataFrame(batch_requests).to_json(orient="records", lines=True)
+
+     # upload the file
+     api_key = os.environ.get("OPENAI_API_KEY", None)
+     if api_key is None:
+         raise ValueError("OPENAI_API_KEY environment variable must be set.")
+
+     headers = {
+         "Authorization": f"Bearer {api_key}",
+     }
+
+     async with aiohttp.ClientSession() as session:
+         # Upload file
+         url = "https://api.openai.com/v1/files"
+         data = aiohttp.FormData()
+         data.add_field("purpose", "batch")
+         data.add_field(
+             "file",
+             jsonl_payload.encode("utf-8"),
+             filename="requests.jsonl",
+             content_type="application/json",
+         )
+
+         async with session.post(url, data=data, headers=headers) as response:
+             if response.status != 200:
+                 text = await response.text()
+                 raise ValueError(f"Error uploading file: {text}")
+
+             print("File uploaded successfully")
+             response_data = await response.json()
+             file_id = response_data["id"]
+
+         # Create batch
+         url = "https://api.openai.com/v1/batches"
+         batch_data = {
+             "input_file_id": file_id,
+             "endpoint": "/v1/chat/completions",
+             "completion_window": "24h",
+         }
+
+         async with session.post(url, json=batch_data, headers=headers) as response:
+             if response.status != 200:
+                 text = await response.text()
+                 raise ValueError(f"Error starting batch job: {text}")
+
+             response_data = await response.json()
+             batch_id = response_data["id"]
+             print(f"Batch job started successfully: id = {batch_id}")
+             return batch_id
+
+
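+ # Usage sketch: each request dict must already follow the Batch API's JSONL
+ # schema (the body below is a placeholder, not built by this function):
+ #     batch_id = await submit_batch_oa([
+ #         {"custom_id": "0", "method": "POST", "url": "/v1/chat/completions",
+ #          "body": {"model": "gpt-4o-mini", "messages": [...]}},
+ #     ])
+
+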
+ async def submit_batches_oa(
+     model: str,
+     sampling_params: SamplingParams,
+     prompts: Sequence[str | list[dict] | Conversation],
+ ):
+     """Submit chunked OpenAI batch jobs covering all prompts.
+
+     Returns: batch_ids (list[str])
+     """
+     # if prompts are strings, convert them to message lists
+     prompts = prompts_to_conversations(prompts)
+     if any(p is None for p in prompts):
+         raise ValueError("All prompts must be valid.")
+     ids = list(range(len(prompts)))
+
+     # create file with requests to send to batch api
+     batch_requests = []
+     model_obj = APIModel.from_registry(model)
+     for req_id, prompt in zip(ids, prompts):
+         assert isinstance(prompt, Conversation)
+         batch_requests.append(
+             {
+                 "custom_id": str(req_id),
+                 "method": "POST",
+                 "url": "/v1/chat/completions",
+                 "body": _build_oa_chat_request(model_obj, prompt, [], sampling_params),
+             }
+         )
+
+     # the API accepts at most 50,000 requests per batch job, so chunk into 50k batches
+     BATCH_SIZE = 50_000
+     batches = [
+         batch_requests[i : i + BATCH_SIZE]
+         for i in range(0, len(batch_requests), BATCH_SIZE)
+     ]
+     tasks = []
+     for batch in batches:
+         tasks.append(asyncio.create_task(submit_batch_oa(batch)))
+     batch_ids = await asyncio.gather(*tasks)
+
+     print(f"Submitted {len(batches)} batch jobs.")
+
+     return batch_ids
+
+
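+ # Chunking example (illustrative): 120,000 prompts become three concurrent
+ # batch jobs of 50,000, 50,000, and 20,000 requests. Assuming SamplingParams()
+ # defaults and a placeholder registry model name:
+ #     batch_ids = await submit_batches_oa("gpt-4o-mini", SamplingParams(), prompts)
+ #     # -> one batch id per 50k chunk
+
+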
+ async def submit_batches_anthropic(
+     model: str,
+     sampling_params: SamplingParams,
+     prompts: Sequence[str | list[dict] | Conversation],
+     *,
+     cache: CachePattern | None = None,
+ ):
+     """Submit batch jobs to Anthropic's Message Batches API.
+
+     Args:
+         model: Registry name of the model to use for every request
+         sampling_params: Sampling parameters applied to every request
+         prompts: List of prompts to process
+         cache: Optional cache pattern for requests
+
+     Returns: batch_ids (list[str])
+     """
+
+     # Convert prompts to Conversations
+     prompts = prompts_to_conversations(prompts)
+     # Create batch requests
+     request_headers = None
+     batch_requests = []
+     for i, prompt in enumerate(prompts):
+         assert isinstance(prompt, Conversation)
+         # Build request body
+         request_body, request_headers = _build_anthropic_request(
+             APIModel.from_registry(model), prompt, [], sampling_params, cache
+         )
+
+         batch_requests.append({"custom_id": str(i), "params": request_body})
+
+     # Chunk into batches of 100k requests (Anthropic's limit)
+     BATCH_SIZE = 100_000
+     batches = [
+         batch_requests[i : i + BATCH_SIZE]
+         for i in range(0, len(batch_requests), BATCH_SIZE)
+     ]
+     batch_tasks = []
+     async with aiohttp.ClientSession() as session:
+
+         async def submit_batch(data, url, headers):
+             async with session.post(url, json=data, headers=headers) as response:
+                 if response.status != 200:
+                     text = await response.text()
+                     raise ValueError(f"Error creating batch: {text}")
+
+                 batch_data = await response.json()
+                 batch_id = batch_data["id"]
+                 print(f"Anthropic batch job started successfully: id = {batch_id}")
+                 return batch_id
+
+         for batch in batches:
+             url = f"{registry[model]['api_base']}/messages/batches"
+             data = {"requests": batch}
+             batch_tasks.append(submit_batch(data, url, request_headers))
+
+         batch_ids = await asyncio.gather(*batch_tasks)
+
+     print(f"Submitted {len(batches)} batch jobs.")
+     return batch_ids
+
+
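+ # Usage sketch (model name and prompts are placeholders; cache is optional):
+ #     batch_ids = await submit_batches_anthropic(
+ #         "claude-3-haiku", SamplingParams(), ["Hello!", "Summarize this."]
+ #     )
+
+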
+ async def wait_for_batch_completion_async(
+     batch_ids: list[str],
+     provider: Literal["openai", "anthropic"],
+     poll_interval: int = 30,
+ ):
+     """Wait for multiple batches to complete and return results asynchronously.
+
+     Args:
+         batch_ids: List of batch IDs to wait for
+         provider: Which provider the batches are from
+         poll_interval: Seconds to wait between status checks
+
+     Returns:
+         A flat list of result records across all batches
+     """
+     tasks = []
+     for batch_id in batch_ids:
+         if provider == "openai":
+             task = _wait_for_openai_batch_completion_async(batch_id, poll_interval)
+         elif provider == "anthropic":
+             task = _wait_for_anthropic_batch_completion_async(batch_id, poll_interval)
+         else:
+             raise ValueError(f"Unsupported provider: {provider}")
+         tasks.append(task)
+
+     # Wait for all batches concurrently
+     results = await asyncio.gather(*tasks)
+
+     # Flatten the per-batch result lists into one list
+     results = [compl for batch in results for compl in batch]
+
+     return results
+
+
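+ # Combined usage sketch: submit, then block until every job finishes
+ # (placeholder arguments as above):
+ #     batch_ids = await submit_batches_oa(model, sampling_params, prompts)
+ #     results = await wait_for_batch_completion_async(
+ #         batch_ids, "openai", poll_interval=60
+ #     )
+
+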
+ async def _wait_for_anthropic_batch_completion_async(
+     batch_id: str, poll_interval: int = 30
+ ):
+     """Poll Anthropic batch until completion and return results asynchronously."""
+     api_key = os.getenv("ANTHROPIC_API_KEY")
+     if api_key is None:
+         raise ValueError("ANTHROPIC_API_KEY environment variable must be set.")
+     headers = {
+         "x-api-key": api_key,
+         "anthropic-version": "2023-06-01",
+         "content-type": "application/json",
+     }
+
+     url = f"https://api.anthropic.com/v1/messages/batches/{batch_id}"
+     console = Console()
+     start_time = time.time()
+
+     # Event to signal when to stop the display updater
+     stop_display_event = asyncio.Event()
+     current_status = {"status": "processing", "counts": None}
+
+     async def display_updater():
+         """Update display independently of polling."""
+         with Live(console=console, refresh_per_second=10) as live:
+             while not stop_display_event.is_set():
+                 elapsed = time.time() - start_time
+                 display = _create_batch_status_display(
+                     batch_id,
+                     current_status["status"],
+                     elapsed,
+                     current_status["counts"],
+                     "anthropic",
+                 )
+                 live.update(display)
+                 await asyncio.sleep(0.1)  # Update every 100ms
+
+     # Start display updater
+     display_task = asyncio.create_task(display_updater())
+
+     try:
+         async with aiohttp.ClientSession() as session:
+             while True:
+                 async with session.get(url, headers=headers) as response:
+                     if response.status != 200:
+                         text = await response.text()
+                         raise ValueError(f"Error checking batch status: {text}")
+
+                     batch_data = await response.json()
+                     current_status["status"] = batch_data["processing_status"]
+                     current_status["counts"] = batch_data.get("request_counts", {})
+
+                     if current_status["status"] == "ended":
+                         stop_display_event.set()
+                         await display_task
+                         console.print(
+                             f"✅ Batch {batch_id} completed!", style="green bold"
+                         )
+                         return await _retrieve_anthropic_batch_results_async(batch_id)
+                     elif current_status["status"] in ["canceled", "expired"]:
+                         stop_display_event.set()
+                         await display_task
+                         raise ValueError(
+                             f"Batch {batch_id} failed with status: {current_status['status']}"
+                         )
+
+                 await asyncio.sleep(poll_interval)
+     finally:
+         stop_display_event.set()
+         await display_task
+
+
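+ # The poller and the renderer are deliberately separate tasks: the loop above
+ # mutates current_status, while display_updater re-renders it every 100 ms so
+ # the spinner stays smooth between slow polls. The same pattern in miniature
+ # (names are illustrative):
+ #     stop = asyncio.Event()
+ #     state = {"status": "processing"}
+ #     async def redraw():
+ #         while not stop.is_set():
+ #             ...  # render `state` with Rich
+ #             await asyncio.sleep(0.1)
+ #     task = asyncio.create_task(redraw())
+ #     try:
+ #         ...  # poll the API, update `state`
+ #     finally:
+ #         stop.set()
+ #         await task
+
+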
+ async def _retrieve_anthropic_batch_results_async(batch_id: str):
+     """Retrieve results from completed Anthropic batch asynchronously."""
+     api_key = os.getenv("ANTHROPIC_API_KEY")
+     if api_key is None:
+         raise ValueError("ANTHROPIC_API_KEY environment variable must be set.")
+     headers = {
+         "x-api-key": api_key,
+         "anthropic-version": "2023-06-01",
+     }
+
+     url = f"https://api.anthropic.com/v1/messages/batches/{batch_id}/results"
+
+     async with aiohttp.ClientSession() as session:
+         async with session.get(url, headers=headers) as response:
+             if response.status != 200:
+                 text = await response.text()
+                 raise ValueError(f"Error retrieving batch results: {text}")
+
+             # Parse JSONL results
+             results = []
+             text = await response.text()
+             for line in text.strip().split("\n"):
+                 if line:
+                     result = json.loads(line)
+                     results.append(result)
+
+     # Sort by custom_id to maintain order
+     results.sort(key=lambda x: int(x["custom_id"]))
+
+     return results
+
+
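+ # Each line of the Anthropic results file decodes to roughly:
+ #     {"custom_id": "0", "result": {"type": "succeeded", "message": {...}}}
+ # where "type" may instead be "errored", "canceled", or "expired".
+
+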
+ async def _retrieve_openai_batch_results_async(batch_id: str):
+     """Retrieve results from OpenAI batch asynchronously."""
+     api_key = os.getenv("OPENAI_API_KEY")
+     if api_key is None:
+         raise ValueError("OPENAI_API_KEY environment variable must be set.")
+
+     headers = {
+         "Authorization": f"Bearer {api_key}",
+         "Content-Type": "application/json",
+     }
+
+     async with aiohttp.ClientSession() as session:
+         # Get batch info
+         url = f"https://api.openai.com/v1/batches/{batch_id}"
+         async with session.get(url, headers=headers) as response:
+             if response.status != 200:
+                 text = await response.text()
+                 raise ValueError(f"Error retrieving batch: {text}")
+
+             batch_data = await response.json()
+
+         if batch_data["status"] != "completed":
+             raise ValueError(
+                 f"Batch {batch_id} is not completed. Status: {batch_data['status']}"
+             )
+
+         # Get output file
+         output_file_id = batch_data["output_file_id"]
+         if not output_file_id:
+             raise ValueError(f"No output file available for batch {batch_id}")
+
+         url = f"https://api.openai.com/v1/files/{output_file_id}/content"
+         async with session.get(url, headers=headers) as response:
+             if response.status != 200:
+                 text = await response.text()
+                 raise ValueError(f"Error retrieving batch results: {text}")
+
+             # Parse JSONL results
+             results = []
+             text = await response.text()
+             for line in text.strip().split("\n"):
+                 if line:
+                     result = json.loads(line)
+                     results.append(result)
+
+     # Sort by custom_id to maintain order
+     results.sort(key=lambda x: int(x["custom_id"]))
+
+     return results
+
+
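+ # Each line of the OpenAI output file decodes to roughly:
+ #     {"id": "batch_req_...", "custom_id": "0", "error": null,
+ #      "response": {"status_code": 200, "body": {...chat completion...}}}
+
+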
+ async def _wait_for_openai_batch_completion_async(
+     batch_id: str, poll_interval: int = 30
+ ):
+     """Poll OpenAI batch until completion and return results asynchronously."""
+     api_key = os.getenv("OPENAI_API_KEY")
+     if api_key is None:
+         raise ValueError("OPENAI_API_KEY environment variable must be set.")
+
+     headers = {
+         "Authorization": f"Bearer {api_key}",
+         "Content-Type": "application/json",
+     }
+
+     url = f"https://api.openai.com/v1/batches/{batch_id}"
+     console = Console()
+     start_time = time.time()
+
+     # Event to signal when to stop the display updater
+     stop_display_event = asyncio.Event()
+     current_status = {"status": "pending", "counts": None}
+
+     async def display_updater():
+         """Update display independently of polling."""
+         with Live(console=console, refresh_per_second=10) as live:
+             while not stop_display_event.is_set():
+                 elapsed = time.time() - start_time
+                 display = _create_batch_status_display(
+                     batch_id,
+                     current_status["status"],
+                     elapsed,
+                     current_status["counts"],
+                     "openai",
+                 )
+                 live.update(display)
+                 await asyncio.sleep(0.1)  # Update every 100ms
+
+     # Start display updater
+     display_task = asyncio.create_task(display_updater())
+
+     try:
+         async with aiohttp.ClientSession() as session:
+             while True:
+                 async with session.get(url, headers=headers) as response:
+                     if response.status != 200:
+                         text = await response.text()
+                         raise ValueError(f"Error checking batch status: {text}")
+
+                     batch_data = await response.json()
+                     current_status["status"] = batch_data["status"]
+                     current_status["counts"] = batch_data.get("request_counts", {})
+
+                     if current_status["status"] == "completed":
+                         stop_display_event.set()
+                         await display_task
+                         console.print(
+                             f"✅ Batch {batch_id} completed!", style="green bold"
+                         )
+                         return await _retrieve_openai_batch_results_async(batch_id)
+                     elif current_status["status"] in [
+                         "failed",
+                         "expired",
+                         "cancelled",
+                     ]:
+                         stop_display_event.set()
+                         await display_task
+                         raise ValueError(
+                             f"Batch {batch_id} failed with status: {current_status['status']}"
+                         )
+
+                 await asyncio.sleep(poll_interval)
+     finally:
+         stop_display_event.set()
+         await display_task
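+
+
+ if __name__ == "__main__":
+     # Minimal end-to-end sketch. The model name, prompts, and SamplingParams()
+     # defaults are placeholder assumptions; requires OPENAI_API_KEY to be set.
+     async def _demo():
+         batch_ids = await submit_batches_oa(
+             "gpt-4o-mini",
+             SamplingParams(),
+             ["What is 2 + 2?", "Name a prime number greater than 10."],
+         )
+         results = await wait_for_batch_completion_async(batch_ids, "openai")
+         for record in results:
+             body = record["response"]["body"]
+             print(record["custom_id"], body["choices"][0]["message"]["content"])
+
+     asyncio.run(_demo())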