lm-deluge 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lm-deluge might be problematic. Click here for more details.
- lm_deluge/__init__.py +9 -1
- lm_deluge/agent.py +0 -0
- lm_deluge/api_requests/anthropic.py +90 -58
- lm_deluge/api_requests/base.py +68 -39
- lm_deluge/api_requests/bedrock.py +34 -10
- lm_deluge/api_requests/common.py +2 -1
- lm_deluge/api_requests/mistral.py +6 -15
- lm_deluge/api_requests/openai.py +270 -44
- lm_deluge/batches.py +498 -0
- lm_deluge/client.py +368 -638
- lm_deluge/computer_use/anthropic_tools.py +75 -0
- lm_deluge/{sampling_params.py → config.py} +10 -3
- lm_deluge/embed.py +17 -11
- lm_deluge/models.py +33 -0
- lm_deluge/prompt.py +86 -6
- lm_deluge/rerank.py +18 -12
- lm_deluge/tool.py +11 -1
- lm_deluge/tracker.py +212 -2
- lm_deluge/util/json.py +18 -1
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.13.dist-info}/METADATA +5 -5
- lm_deluge-0.0.13.dist-info/RECORD +42 -0
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.13.dist-info}/WHEEL +1 -1
- lm_deluge-0.0.12.dist-info/RECORD +0 -39
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.13.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.12.dist-info → lm_deluge-0.0.13.dist-info}/top_level.txt +0 -0
lm_deluge/batches.py
ADDED
|
@@ -0,0 +1,498 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import time
|
|
4
|
+
import asyncio
|
|
5
|
+
import aiohttp
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from lm_deluge.prompt import CachePattern, Conversation, prompts_to_conversations
|
|
8
|
+
from lm_deluge.config import SamplingParams
|
|
9
|
+
from lm_deluge.models import APIModel
|
|
10
|
+
from typing import Sequence, Literal
|
|
11
|
+
from lm_deluge.api_requests.openai import _build_oa_chat_request
|
|
12
|
+
from lm_deluge.api_requests.anthropic import _build_anthropic_request
|
|
13
|
+
from rich.console import Console
|
|
14
|
+
from rich.live import Live
|
|
15
|
+
from rich.spinner import Spinner
|
|
16
|
+
from rich.table import Table
|
|
17
|
+
from rich.text import Text
|
|
18
|
+
from lm_deluge.models import registry
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _create_batch_status_display(
|
|
22
|
+
batch_id: str,
|
|
23
|
+
status: str,
|
|
24
|
+
elapsed: float,
|
|
25
|
+
counts: dict | None,
|
|
26
|
+
provider: str,
|
|
27
|
+
):
|
|
28
|
+
"""Create a unified status display for batch jobs."""
|
|
29
|
+
# Format elapsed time
|
|
30
|
+
hours = int(elapsed // 3600)
|
|
31
|
+
minutes = int((elapsed % 3600) // 60)
|
|
32
|
+
seconds = int(elapsed % 60)
|
|
33
|
+
|
|
34
|
+
if hours > 0:
|
|
35
|
+
elapsed_str = f"{hours}h {minutes}m {seconds}s"
|
|
36
|
+
elif minutes > 0:
|
|
37
|
+
elapsed_str = f"{minutes}m {seconds}s"
|
|
38
|
+
else:
|
|
39
|
+
elapsed_str = f"{seconds}s"
|
|
40
|
+
|
|
41
|
+
# Build progress text based on provider
|
|
42
|
+
progress_text = ""
|
|
43
|
+
if counts:
|
|
44
|
+
if provider == "openai":
|
|
45
|
+
total = counts.get("total", 0)
|
|
46
|
+
completed = counts.get("completed", 0)
|
|
47
|
+
failed = counts.get("failed", 0)
|
|
48
|
+
total_display = "?" if total == 0 else str(total)
|
|
49
|
+
progress_text = f" • {completed}/{total_display} done"
|
|
50
|
+
if failed > 0:
|
|
51
|
+
progress_text += f", {failed} failed"
|
|
52
|
+
elif provider == "anthropic":
|
|
53
|
+
total = (
|
|
54
|
+
counts.get("processing", 0)
|
|
55
|
+
+ counts.get("succeeded", 0)
|
|
56
|
+
+ counts.get("errored", 0)
|
|
57
|
+
)
|
|
58
|
+
succeeded = counts.get("succeeded", 0)
|
|
59
|
+
errored = counts.get("errored", 0)
|
|
60
|
+
total_display = "?" if total == 0 else str(total)
|
|
61
|
+
progress_text = f" • {succeeded}/{total_display} done"
|
|
62
|
+
if errored > 0:
|
|
63
|
+
progress_text += f", {errored} errors"
|
|
64
|
+
|
|
65
|
+
# Choose spinner color based on provider
|
|
66
|
+
spinner_style = "green" if provider == "openai" else "blue"
|
|
67
|
+
spinner = Spinner("dots", style=spinner_style, text="")
|
|
68
|
+
|
|
69
|
+
grid = Table.grid()
|
|
70
|
+
grid.add_column()
|
|
71
|
+
grid.add_column()
|
|
72
|
+
grid.add_row(
|
|
73
|
+
spinner,
|
|
74
|
+
Text(
|
|
75
|
+
f" Batch {batch_id} • {status} • {elapsed_str}{progress_text}",
|
|
76
|
+
style="white",
|
|
77
|
+
),
|
|
78
|
+
)
|
|
79
|
+
return grid
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _requests_to_jsonl(batch_requests: list[dict]) -> bytes:
    """Encode request dicts as UTF-8 JSONL: one JSON object per line."""
    return "".join(json.dumps(request) + "\n" for request in batch_requests).encode(
        "utf-8"
    )


async def submit_batch_oa(batch_requests: list[dict]):
    """Upload one chunk of requests to OpenAI and start a Batch API job for it.

    Args:
        batch_requests: Batch API request lines (dicts with ``custom_id``,
            ``method``, ``url`` and ``body``).

    Returns:
        str: The ID of the created batch job.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is unset, or if the file upload or
            batch creation returns a non-200 status.
    """
    api_key = os.environ.get("OPENAI_API_KEY", None)
    if api_key is None:
        raise ValueError("OPENAI_API_KEY environment variable must be set.")

    headers = {
        "Authorization": f"Bearer {api_key}",
    }

    # Serialize in memory rather than through a fixed on-disk temp file:
    # submit_batches_oa runs several of these calls concurrently, and a shared
    # file would be overwritten between one call's write and its upload.
    payload = _requests_to_jsonl(batch_requests)

    async with aiohttp.ClientSession() as session:
        # Upload the JSONL file with purpose "batch".
        data = aiohttp.FormData()
        data.add_field("purpose", "batch")
        data.add_field(
            "file",
            payload,
            filename="requests_temp.jsonl",
            content_type="application/json",
        )

        async with session.post(
            "https://api.openai.com/v1/files", data=data, headers=headers
        ) as response:
            if response.status != 200:
                text = await response.text()
                raise ValueError(f"Error uploading file: {text}")

            print("File uploaded successfully")
            response_data = await response.json()
            file_id = response_data["id"]

        # Create the batch job that consumes the uploaded file.
        batch_data = {
            "input_file_id": file_id,
            "endpoint": "/v1/chat/completions",
            "completion_window": "24h",
        }

        async with session.post(
            "https://api.openai.com/v1/batches", json=batch_data, headers=headers
        ) as response:
            if response.status != 200:
                text = await response.text()
                raise ValueError(f"Error starting batch job: {text}")

            response_data = await response.json()
            batch_id = response_data["id"]
            print("Batch job started successfully: id = ", batch_id)
            return batch_id
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
async def submit_batches_oa(
    model: str,
    sampling_params: SamplingParams,
    prompts: Sequence[str | list[dict] | Conversation],
):
    """Build OpenAI chat-completion batch requests and submit them as batch jobs.

    Each prompt becomes one request whose ``custom_id`` is its index. Requests
    are split into chunks of at most 50,000 (the per-job limit noted by the
    original implementation), and every chunk is submitted concurrently via
    ``submit_batch_oa``.

    Returns:
        The batch IDs, one per chunk, in chunk order.

    Raises:
        ValueError: If any prompt cannot be converted to a Conversation.
    """
    conversations = prompts_to_conversations(prompts)
    if any(conversation is None for conversation in conversations):
        raise ValueError("All prompts must be valid.")

    api_model = APIModel.from_registry(model)
    request_lines: list[dict] = []
    for index, conversation in enumerate(conversations):
        assert isinstance(conversation, Conversation)
        request_lines.append(
            {
                "custom_id": str(index),
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": _build_oa_chat_request(
                    api_model, conversation, [], sampling_params
                ),
            }
        )

    chunk_size = 50_000
    chunks = [
        request_lines[start : start + chunk_size]
        for start in range(0, len(request_lines), chunk_size)
    ]
    batch_ids = await asyncio.gather(
        *[asyncio.create_task(submit_batch_oa(chunk)) for chunk in chunks]
    )

    print(f"Submitted {len(chunks)} batch jobs.")
    return batch_ids
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
async def submit_batches_anthropic(
    model: str,
    sampling_params: SamplingParams,
    prompts: Sequence[str | list[dict] | Conversation],
    *,
    cache: CachePattern | None = None,
):
    """Submit prompts as one or more jobs to Anthropic's Message Batches API.

    Args:
        model: Registry name of the model; also used to look up ``api_base``.
        sampling_params: Sampling settings applied to every request.
        prompts: Prompts to process (strings, message lists, or Conversations).
        cache: Optional cache pattern passed to the request builder.

    Returns:
        list[str]: One batch ID per submitted chunk of up to 100,000 requests,
        in submission order. Each request's ``custom_id`` is its index in
        ``prompts``.

    Raises:
        ValueError: If any prompt cannot be converted to a Conversation, or if
            a batch-creation request returns a non-200 status.
    """
    conversations = prompts_to_conversations(prompts)
    # Same validation as submit_batches_oa, so a bad prompt yields a clear
    # ValueError instead of an assert (which is stripped under -O).
    if any(p is None for p in conversations):
        raise ValueError("All prompts must be valid.")

    # Resolve the model once rather than once per prompt.
    api_model = APIModel.from_registry(model)
    request_headers = None
    batch_requests = []
    for i, prompt in enumerate(conversations):
        assert isinstance(prompt, Conversation)
        request_body, request_headers = _build_anthropic_request(
            api_model, prompt, [], sampling_params, cache
        )
        batch_requests.append({"custom_id": str(i), "params": request_body})
    # NOTE(review): the headers from the last built request are reused for
    # every chunk; this assumes they do not vary per prompt — confirm.

    # Chunk into batches of 100k requests (Anthropic's limit).
    BATCH_SIZE = 100_000
    batches = [
        batch_requests[i : i + BATCH_SIZE]
        for i in range(0, len(batch_requests), BATCH_SIZE)
    ]
    url = f"{registry[model]['api_base']}/messages/batches"

    async with aiohttp.ClientSession() as session:

        async def submit_batch(batch: list[dict]) -> str:
            """POST one chunk of requests and return the new batch ID."""
            async with session.post(
                url, json={"requests": batch}, headers=request_headers
            ) as response:
                if response.status != 200:
                    text = await response.text()
                    raise ValueError(f"Error creating batch: {text}")

                batch_data = await response.json()
                batch_id = batch_data["id"]
                print(f"Anthropic batch job started successfully: id = {batch_id}")
                return batch_id

        batch_ids = await asyncio.gather(*(submit_batch(batch) for batch in batches))

    print(f"Submitted {len(batches)} batch jobs.")
    return batch_ids
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
async def wait_for_batch_completion_async(
|
|
244
|
+
batch_ids: list[str],
|
|
245
|
+
provider: Literal["openai", "anthropic"],
|
|
246
|
+
poll_interval: int = 30,
|
|
247
|
+
):
|
|
248
|
+
"""Wait for multiple batches to complete and return results asynchronously.
|
|
249
|
+
|
|
250
|
+
Args:
|
|
251
|
+
batch_ids: List of batch IDs to wait for
|
|
252
|
+
provider: Which provider the batches are from
|
|
253
|
+
poll_interval: Seconds to wait between status checks
|
|
254
|
+
|
|
255
|
+
Returns:
|
|
256
|
+
List of results for each batch
|
|
257
|
+
"""
|
|
258
|
+
tasks = []
|
|
259
|
+
for batch_id in batch_ids:
|
|
260
|
+
if provider == "openai":
|
|
261
|
+
task = _wait_for_openai_batch_completion_async(batch_id, poll_interval)
|
|
262
|
+
elif provider == "anthropic":
|
|
263
|
+
task = _wait_for_anthropic_batch_completion_async(batch_id, poll_interval)
|
|
264
|
+
else:
|
|
265
|
+
raise ValueError(f"Unsupported provider: {provider}")
|
|
266
|
+
tasks.append(task)
|
|
267
|
+
|
|
268
|
+
# Wait for all batches concurrently
|
|
269
|
+
results = await asyncio.gather(*tasks)
|
|
270
|
+
|
|
271
|
+
results = [compl for batch in results for compl in batch]
|
|
272
|
+
|
|
273
|
+
return results
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
async def _wait_for_anthropic_batch_completion_async(
    batch_id: str, poll_interval: int = 30
):
    """Poll an Anthropic message batch until it ends, then return its results.

    A rich live status line is redrawn every 100 ms while the batch status is
    requested every ``poll_interval`` seconds.

    Args:
        batch_id: ID returned when the batch was created.
        poll_interval: Seconds to sleep between status requests.

    Returns:
        The records returned by ``_retrieve_anthropic_batch_results_async``.

    Raises:
        ValueError: If ``ANTHROPIC_API_KEY`` is unset, a status request returns
            a non-200 status, or the batch reports ``canceled``/``expired``.
    """
    api_key = os.getenv("ANTHROPIC_API_KEY")
    # Fail fast with a clear message (as the OpenAI helpers do) rather than
    # sending a None header value.
    if api_key is None:
        raise ValueError("ANTHROPIC_API_KEY environment variable must be set.")

    headers = {
        "x-api-key": api_key,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }

    url = f"https://api.anthropic.com/v1/messages/batches/{batch_id}"
    console = Console()
    start_time = time.time()

    # Shared between the poller (writer) and the display task (reader).
    stop_display_event = asyncio.Event()
    current_status = {"status": "processing", "counts": None}

    async def display_updater():
        """Update display independently of polling."""
        with Live(console=console, refresh_per_second=10) as live:
            while not stop_display_event.is_set():
                elapsed = time.time() - start_time
                display = _create_batch_status_display(
                    batch_id,
                    current_status["status"],
                    elapsed,
                    current_status["counts"],
                    "anthropic",
                )
                live.update(display)
                await asyncio.sleep(0.1)  # Update every 100ms

    display_task = asyncio.create_task(display_updater())

    try:
        async with aiohttp.ClientSession() as session:
            while True:
                async with session.get(url, headers=headers) as response:
                    if response.status != 200:
                        text = await response.text()
                        raise ValueError(f"Error checking batch status: {text}")

                    batch_data = await response.json()

                current_status["status"] = batch_data["processing_status"]
                current_status["counts"] = batch_data.get("request_counts", {})

                if current_status["status"] == "ended":
                    break
                # NOTE(review): Anthropic documents processing_status values
                # as in_progress / canceling / ended, so this branch may never
                # fire — cancellation/expiry shows up in request_counts instead.
                if current_status["status"] in ["canceled", "expired"]:
                    raise ValueError(
                        f"Batch {batch_id} failed with status: {current_status['status']}"
                    )

                await asyncio.sleep(poll_interval)
    finally:
        # Single teardown point: stop the live display before anything else
        # is printed or an error propagates.
        stop_display_event.set()
        await display_task

    console.print(f"✅ Batch {batch_id} completed!", style="green bold")
    return await _retrieve_anthropic_batch_results_async(batch_id)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
async def _retrieve_anthropic_batch_results_async(batch_id: str):
    """Download and parse the results of an ended Anthropic message batch.

    Args:
        batch_id: ID of a batch whose processing has ended.

    Returns:
        list[dict]: One parsed JSONL record per line, sorted by integer
        ``custom_id`` so results line up with submission order.

    Raises:
        ValueError: If ``ANTHROPIC_API_KEY`` is unset or the results request
            returns a non-200 status.
    """
    api_key = os.getenv("ANTHROPIC_API_KEY")
    if api_key is None:
        raise ValueError("ANTHROPIC_API_KEY environment variable must be set.")

    headers = {
        "x-api-key": api_key,
        "anthropic-version": "2023-06-01",
    }

    url = f"https://api.anthropic.com/v1/messages/batches/{batch_id}/results"

    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers) as response:
            if response.status != 200:
                text = await response.text()
                raise ValueError(f"Error retrieving batch results: {text}")

            text = await response.text()

    # Parse JSONL; splitlines() handles CRLF endings and whitespace-only
    # lines are skipped instead of being fed to json.loads.
    results = [json.loads(line) for line in text.splitlines() if line.strip()]

    # Sort by custom_id to restore submission order.
    results.sort(key=lambda record: int(record["custom_id"]))
    return results
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
async def _retrieve_openai_batch_results_async(batch_id: str):
    """Download and parse the output file of a completed OpenAI batch.

    Args:
        batch_id: ID of the batch to read.

    Returns:
        list[dict]: One parsed JSONL record per output line, sorted by integer
        ``custom_id`` so results line up with submission order.
        NOTE(review): only the output file is read; requests that OpenAI
        records in a separate error file are not included, so the list can
        be shorter than the number of submitted requests.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is unset, a request returns a
            non-200 status, the batch is not ``completed``, or it has no
            output file.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key is None:
        raise ValueError("OPENAI_API_KEY environment variable must be set.")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    async with aiohttp.ClientSession() as session:
        # Get batch info.
        url = f"https://api.openai.com/v1/batches/{batch_id}"
        async with session.get(url, headers=headers) as response:
            if response.status != 200:
                text = await response.text()
                raise ValueError(f"Error retrieving batch: {text}")

            batch_data = await response.json()

        if batch_data["status"] != "completed":
            raise ValueError(
                f"Batch {batch_id} is not completed. Status: {batch_data['status']}"
            )

        # .get(): a missing key should produce the ValueError below, not KeyError.
        output_file_id = batch_data.get("output_file_id")
        if not output_file_id:
            raise ValueError(f"No output file available for batch {batch_id}")

        url = f"https://api.openai.com/v1/files/{output_file_id}/content"
        async with session.get(url, headers=headers) as response:
            if response.status != 200:
                text = await response.text()
                raise ValueError(f"Error retrieving batch results: {text}")

            text = await response.text()

    # Parse JSONL; splitlines() handles CRLF endings and whitespace-only
    # lines are skipped instead of being fed to json.loads.
    results = [json.loads(line) for line in text.splitlines() if line.strip()]

    # Sort by custom_id to restore submission order.
    results.sort(key=lambda record: int(record["custom_id"]))
    return results
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
async def _wait_for_openai_batch_completion_async(
    batch_id: str, poll_interval: int = 30
):
    """Block until an OpenAI batch reaches a terminal state, then fetch its output.

    While waiting, a rich live line shows status, elapsed time and request
    counts (redrawn every 100 ms); the batch itself is polled every
    ``poll_interval`` seconds.

    Returns:
        The records returned by ``_retrieve_openai_batch_results_async``.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is unset, a status request returns a
            non-200 status, or the batch ends as failed/expired/cancelled.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key is None:
        raise ValueError("OPENAI_API_KEY environment variable must be set.")

    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    status_url = f"https://api.openai.com/v1/batches/{batch_id}"
    console = Console()
    started_at = time.time()

    # The poller writes `state`; the display task reads it until `finished` is set.
    finished = asyncio.Event()
    state = {"status": "pending", "counts": None}

    async def refresh_display():
        """Redraw the live status line every 100 ms until `finished` is set."""
        with Live(console=console, refresh_per_second=10) as live:
            while not finished.is_set():
                live.update(
                    _create_batch_status_display(
                        batch_id,
                        state["status"],
                        time.time() - started_at,
                        state["counts"],
                        "openai",
                    )
                )
                await asyncio.sleep(0.1)

    async def stop_display():
        """Signal the display task to stop and wait for it to exit."""
        finished.set()
        await display_task

    display_task = asyncio.create_task(refresh_display())

    try:
        async with aiohttp.ClientSession() as session:
            while True:
                async with session.get(status_url, headers=request_headers) as response:
                    if response.status != 200:
                        error_body = await response.text()
                        raise ValueError(f"Error checking batch status: {error_body}")

                    payload = await response.json()
                    state["status"] = payload["status"]
                    state["counts"] = payload.get("request_counts", {})
                    current = state["status"]

                    if current == "completed":
                        await stop_display()
                        console.print(
                            f"✅ Batch {batch_id} completed!", style="green bold"
                        )
                        return await _retrieve_openai_batch_results_async(batch_id)
                    if current in ("failed", "expired", "cancelled"):
                        await stop_display()
                        raise ValueError(
                            f"Batch {batch_id} failed with status: {current}"
                        )

                await asyncio.sleep(poll_interval)
    finally:
        await stop_display()
|