cat-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cat_stack/__about__.py +10 -0
- cat_stack/__init__.py +128 -0
- cat_stack/_batch.py +1388 -0
- cat_stack/_category_analysis.py +348 -0
- cat_stack/_chunked.py +424 -0
- cat_stack/_embeddings.py +189 -0
- cat_stack/_formatter.py +169 -0
- cat_stack/_providers.py +1048 -0
- cat_stack/_tiebreaker.py +277 -0
- cat_stack/_utils.py +512 -0
- cat_stack/_web_fetch.py +194 -0
- cat_stack/calls/CoVe.py +287 -0
- cat_stack/calls/__init__.py +25 -0
- cat_stack/calls/all_calls.py +622 -0
- cat_stack/calls/image_CoVe.py +386 -0
- cat_stack/calls/image_stepback.py +210 -0
- cat_stack/calls/pdf_CoVe.py +386 -0
- cat_stack/calls/pdf_stepback.py +210 -0
- cat_stack/calls/stepback.py +180 -0
- cat_stack/calls/top_n.py +217 -0
- cat_stack/classify.py +682 -0
- cat_stack/explore.py +111 -0
- cat_stack/extract.py +218 -0
- cat_stack/image_functions.py +2078 -0
- cat_stack/images/circle.png +0 -0
- cat_stack/images/cube.png +0 -0
- cat_stack/images/diamond.png +0 -0
- cat_stack/images/overlapping_pentagons.png +0 -0
- cat_stack/images/rectangles.png +0 -0
- cat_stack/model_reference_list.py +94 -0
- cat_stack/pdf_functions.py +2087 -0
- cat_stack/summarize.py +290 -0
- cat_stack/text_functions.py +1358 -0
- cat_stack/text_functions_ensemble.py +3644 -0
- cat_stack-0.1.0.dist-info/METADATA +150 -0
- cat_stack-0.1.0.dist-info/RECORD +38 -0
- cat_stack-0.1.0.dist-info/WHEEL +4 -0
- cat_stack-0.1.0.dist-info/licenses/LICENSE +672 -0
cat_stack/_batch.py
ADDED
|
@@ -0,0 +1,1388 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Async batch inference for CatLLM.
|
|
3
|
+
|
|
4
|
+
Supports OpenAI, Anthropic, Google, Mistral, and xAI — all offer 50% cost
|
|
5
|
+
savings and higher rate limits compared to synchronous API calls.
|
|
6
|
+
|
|
7
|
+
All five providers follow the same conceptual pattern:
|
|
8
|
+
1. Package all requests as JSONL
|
|
9
|
+
2. Submit a batch job (file upload → job creation, or inline for Anthropic)
|
|
10
|
+
3. Poll until the job reaches a terminal state
|
|
11
|
+
4. Download and parse results
|
|
12
|
+
5. Return a DataFrame identical in format to the synchronous single-model path
|
|
13
|
+
|
|
14
|
+
Not supported: HuggingFace, Perplexity, Ollama (no batch API).
|
|
15
|
+
Ensemble mode: supported. Each model submits its own batch job concurrently.
|
|
16
|
+
Providers without batch API (HuggingFace, Perplexity, Ollama) fall back to
|
|
17
|
+
synchronous calls and are merged in with the batch results.
|
|
18
|
+
Not compatible: PDF/image input (text only).
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import io
|
|
22
|
+
import json
|
|
23
|
+
import os
|
|
24
|
+
import time
|
|
25
|
+
|
|
26
|
+
import requests
|
|
27
|
+
|
|
28
|
+
from ._providers import UnifiedLLMClient
|
|
29
|
+
from ._utils import extract_json
|
|
30
|
+
|
|
31
|
+
# =============================================================================
|
|
32
|
+
# Constants
|
|
33
|
+
# =============================================================================
|
|
34
|
+
|
|
35
|
+
BATCH_ENDPOINTS = {
|
|
36
|
+
"openai": {
|
|
37
|
+
"upload": "https://api.openai.com/v1/files",
|
|
38
|
+
"create": "https://api.openai.com/v1/batches",
|
|
39
|
+
"status": "https://api.openai.com/v1/batches/{job_id}",
|
|
40
|
+
"results": "https://api.openai.com/v1/files/{file_id}/content",
|
|
41
|
+
},
|
|
42
|
+
"anthropic": {
|
|
43
|
+
# No file upload — requests are sent inline at job creation
|
|
44
|
+
"create": "https://api.anthropic.com/v1/messages/batches",
|
|
45
|
+
"status": "https://api.anthropic.com/v1/messages/batches/{job_id}",
|
|
46
|
+
"results": "https://api.anthropic.com/v1/messages/batches/{job_id}/results",
|
|
47
|
+
},
|
|
48
|
+
"google": {
|
|
49
|
+
"upload": "https://generativelanguage.googleapis.com/upload/v1beta/files",
|
|
50
|
+
"create": "https://generativelanguage.googleapis.com/v1beta/models/{model}:batchGenerateContent",
|
|
51
|
+
"status": "https://generativelanguage.googleapis.com/v1beta/{job_name}",
|
|
52
|
+
"download": "https://generativelanguage.googleapis.com/download/v1beta/{file_name}:download",
|
|
53
|
+
},
|
|
54
|
+
"mistral": {
|
|
55
|
+
"upload": "https://api.mistral.ai/v1/files",
|
|
56
|
+
"create": "https://api.mistral.ai/v1/batch/jobs",
|
|
57
|
+
"status": "https://api.mistral.ai/v1/batch/jobs/{job_id}",
|
|
58
|
+
"results": "https://api.mistral.ai/v1/files/{file_id}/content",
|
|
59
|
+
},
|
|
60
|
+
"xai": {
|
|
61
|
+
"create": "https://api.x.ai/v1/batches",
|
|
62
|
+
"add": "https://api.x.ai/v1/batches/{job_id}/requests",
|
|
63
|
+
"status": "https://api.x.ai/v1/batches/{job_id}",
|
|
64
|
+
"results": "https://api.x.ai/v1/batches/{job_id}/results",
|
|
65
|
+
},
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
UNSUPPORTED_BATCH_PROVIDERS = {"huggingface", "huggingface-together", "perplexity", "ollama"}
|
|
69
|
+
|
|
70
|
+
# Terminal states per provider
|
|
71
|
+
_TERMINAL_STATES = {
|
|
72
|
+
"openai": {"completed", "failed", "expired", "cancelled"},
|
|
73
|
+
"anthropic": {"ended"},
|
|
74
|
+
"google": {
|
|
75
|
+
"BATCH_STATE_SUCCEEDED", "BATCH_STATE_FAILED", "BATCH_STATE_CANCELLED", "BATCH_STATE_EXPIRED",
|
|
76
|
+
"JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED",
|
|
77
|
+
},
|
|
78
|
+
"mistral": {"SUCCESS", "FAILED", "TIMEOUT_EXCEEDED", "CANCELLATION_REQUESTED"},
|
|
79
|
+
"xai": {"completed", "failed", "expired", "cancelled"},
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
_SUCCESS_STATES = {
|
|
83
|
+
"openai": {"completed"},
|
|
84
|
+
"anthropic": {"ended"}, # must check request_counts.failed separately
|
|
85
|
+
"google": {"BATCH_STATE_SUCCEEDED", "JOB_STATE_SUCCEEDED"},
|
|
86
|
+
"mistral": {"SUCCESS"},
|
|
87
|
+
"xai": {"completed"},
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# =============================================================================
|
|
92
|
+
# Exceptions
|
|
93
|
+
# =============================================================================
|
|
94
|
+
|
|
95
|
+
class BatchJobExpiredError(RuntimeError):
    """A batch job stopped before finishing because it expired, timed out, or was cancelled."""


class BatchJobFailedError(RuntimeError):
    """A batch job reached a terminal state that is neither success nor expiry/cancellation."""
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# =============================================================================
|
|
106
|
+
# Auth headers
|
|
107
|
+
# =============================================================================
|
|
108
|
+
|
|
109
|
+
def _get_batch_headers(provider: str, api_key: str) -> dict:
|
|
110
|
+
"""Return HTTP headers for the given provider's batch API."""
|
|
111
|
+
headers = {"Content-Type": "application/json"}
|
|
112
|
+
if provider == "openai":
|
|
113
|
+
headers["Authorization"] = f"Bearer {api_key}"
|
|
114
|
+
elif provider == "anthropic":
|
|
115
|
+
headers["x-api-key"] = api_key
|
|
116
|
+
headers["anthropic-version"] = "2023-06-01"
|
|
117
|
+
headers["anthropic-beta"] = "message-batches-2024-09-24"
|
|
118
|
+
elif provider == "google":
|
|
119
|
+
headers["x-goog-api-key"] = api_key
|
|
120
|
+
elif provider == "mistral":
|
|
121
|
+
headers["Authorization"] = f"Bearer {api_key}"
|
|
122
|
+
elif provider == "xai":
|
|
123
|
+
headers["Authorization"] = f"Bearer {api_key}"
|
|
124
|
+
return headers
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# =============================================================================
|
|
128
|
+
# JSONL line builders
|
|
129
|
+
# =============================================================================
|
|
130
|
+
|
|
131
|
+
def _build_jsonl_line(provider: str, custom_id: str, payload: dict, model: str) -> dict:
|
|
132
|
+
"""
|
|
133
|
+
Wrap a provider payload in the provider's batch JSONL envelope.
|
|
134
|
+
|
|
135
|
+
Returns a dict that will be serialized as one JSONL line.
|
|
136
|
+
"""
|
|
137
|
+
if provider == "openai":
|
|
138
|
+
return {
|
|
139
|
+
"custom_id": custom_id,
|
|
140
|
+
"method": "POST",
|
|
141
|
+
"url": "/v1/chat/completions",
|
|
142
|
+
"body": payload,
|
|
143
|
+
}
|
|
144
|
+
elif provider == "anthropic":
|
|
145
|
+
# Anthropic uses "params" key; no file upload needed
|
|
146
|
+
return {
|
|
147
|
+
"custom_id": custom_id,
|
|
148
|
+
"params": payload,
|
|
149
|
+
}
|
|
150
|
+
elif provider == "google":
|
|
151
|
+
# Google batch JSONL format: request payload + metadata with key
|
|
152
|
+
return {
|
|
153
|
+
"request": payload,
|
|
154
|
+
"metadata": {"key": custom_id},
|
|
155
|
+
}
|
|
156
|
+
elif provider == "mistral":
|
|
157
|
+
return {
|
|
158
|
+
"custom_id": custom_id,
|
|
159
|
+
"method": "POST",
|
|
160
|
+
"url": "/v1/chat/completions",
|
|
161
|
+
"body": payload,
|
|
162
|
+
}
|
|
163
|
+
elif provider == "xai":
|
|
164
|
+
# xAI requests are added one-by-one after batch creation; same OpenAI-compat format
|
|
165
|
+
return {
|
|
166
|
+
"custom_id": custom_id,
|
|
167
|
+
"method": "POST",
|
|
168
|
+
"url": "/v1/chat/completions",
|
|
169
|
+
"body": payload,
|
|
170
|
+
}
|
|
171
|
+
raise ValueError(f"Unsupported batch provider: {provider}")
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# =============================================================================
|
|
175
|
+
# File upload (OpenAI, Google, Mistral)
|
|
176
|
+
# =============================================================================
|
|
177
|
+
|
|
178
|
+
def _upload_jsonl(provider: str, api_key: str, jsonl_bytes: bytes, filename: str = "batch_requests.jsonl") -> str:
    """
    Store *jsonl_bytes* with the provider's Files API so a batch job can reference it.

    OpenAI and Mistral take a single multipart POST. Google uses its two-step
    resumable protocol: open a session, then upload and finalize in one chunk.

    Returns:
        The provider's file identifier: the ``id`` field for OpenAI/Mistral,
        or the file ``name`` (``uri`` as fallback) for Google.

    Raises:
        ValueError: for providers that do not upload batch input files.
        RuntimeError: if Google returns no session URL or no file name.
        requests.HTTPError: on a non-2xx upload response.
    """
    base_headers = _get_batch_headers(provider, api_key)
    # requests generates the multipart Content-Type (with boundary) itself.
    base_headers.pop("Content-Type", None)

    multipart_mime = {"openai": "application/jsonl", "mistral": "application/octet-stream"}
    if provider in multipart_mime:
        upload_resp = requests.post(
            BATCH_ENDPOINTS[provider]["upload"],
            headers=base_headers,
            files={"file": (filename, io.BytesIO(jsonl_bytes), multipart_mime[provider])},
            data={"purpose": "batch"},
            timeout=120,
        )
        upload_resp.raise_for_status()
        return upload_resp.json()["id"]

    if provider != "google":
        raise ValueError(f"Provider '{provider}' does not use file upload for batch")

    # Google, step 1: open a resumable upload session.
    start_resp = requests.post(
        BATCH_ENDPOINTS["google"]["upload"],
        headers={
            "x-goog-api-key": api_key,
            "X-Goog-Upload-Protocol": "resumable",
            "X-Goog-Upload-Command": "start",
            "X-Goog-Upload-Header-Content-Type": "application/jsonl",
            "X-Goog-Upload-Header-Content-Length": str(len(jsonl_bytes)),
            "Content-Type": "application/json",
        },
        data=json.dumps({"file": {"display_name": filename}}),
        timeout=60,
    )
    start_resp.raise_for_status()
    session_url = start_resp.headers.get("X-Goog-Upload-URL")
    if not session_url:
        raise RuntimeError("Google file upload: no session URL returned")

    # Step 2: send every byte in one chunk and finalize the upload.
    finish_resp = requests.post(
        session_url,
        headers={
            "X-Goog-Upload-Command": "upload, finalize",
            "X-Goog-Upload-Offset": "0",
            "Content-Type": "application/jsonl",
        },
        data=jsonl_bytes,
        timeout=120,
    )
    finish_resp.raise_for_status()
    file_info = finish_resp.json()
    # Metadata may be wrapped: {"file": {"name": "files/abc", ...}}
    file_obj = file_info.get("file", file_info)
    file_name = file_obj.get("name") or file_obj.get("uri")
    if not file_name:
        raise RuntimeError(f"Google file upload: could not extract file name. Response: {file_info}")
    return file_name
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# =============================================================================
|
|
243
|
+
# Batch job creation
|
|
244
|
+
# =============================================================================
|
|
245
|
+
|
|
246
|
+
def _create_batch_job(
|
|
247
|
+
provider: str,
|
|
248
|
+
api_key: str,
|
|
249
|
+
model: str,
|
|
250
|
+
file_id: str = None,
|
|
251
|
+
requests_list: list = None,
|
|
252
|
+
) -> str:
|
|
253
|
+
"""
|
|
254
|
+
Create a batch job and return the job ID.
|
|
255
|
+
|
|
256
|
+
Args:
|
|
257
|
+
provider: Provider name
|
|
258
|
+
api_key: API key
|
|
259
|
+
model: Model name (used for Mistral and xAI)
|
|
260
|
+
file_id: Uploaded file ID (OpenAI, Google, Mistral)
|
|
261
|
+
requests_list: Inline request list (Anthropic only)
|
|
262
|
+
|
|
263
|
+
Returns:
|
|
264
|
+
job_id string for polling
|
|
265
|
+
"""
|
|
266
|
+
headers = _get_batch_headers(provider, api_key)
|
|
267
|
+
|
|
268
|
+
if provider == "openai":
|
|
269
|
+
url = BATCH_ENDPOINTS["openai"]["create"]
|
|
270
|
+
body = {
|
|
271
|
+
"input_file_id": file_id,
|
|
272
|
+
"endpoint": "/v1/chat/completions",
|
|
273
|
+
"completion_window": "24h",
|
|
274
|
+
}
|
|
275
|
+
resp = requests.post(url, headers=headers, json=body, timeout=60)
|
|
276
|
+
resp.raise_for_status()
|
|
277
|
+
return resp.json()["id"]
|
|
278
|
+
|
|
279
|
+
elif provider == "anthropic":
|
|
280
|
+
url = BATCH_ENDPOINTS["anthropic"]["create"]
|
|
281
|
+
body = {"requests": requests_list}
|
|
282
|
+
resp = requests.post(url, headers=headers, json=body, timeout=60)
|
|
283
|
+
resp.raise_for_status()
|
|
284
|
+
return resp.json()["id"]
|
|
285
|
+
|
|
286
|
+
elif provider == "google":
|
|
287
|
+
url = BATCH_ENDPOINTS["google"]["create"].format(model=model)
|
|
288
|
+
# Google inline batch: requests are sent in the body (no file upload needed)
|
|
289
|
+
body = {
|
|
290
|
+
"batch": {
|
|
291
|
+
"display_name": f"cat_stack_batch_{int(time.time())}",
|
|
292
|
+
"input_config": {
|
|
293
|
+
"requests": {
|
|
294
|
+
"requests": [
|
|
295
|
+
{"request": line["request"], "metadata": line["metadata"]}
|
|
296
|
+
for line in (requests_list or [])
|
|
297
|
+
]
|
|
298
|
+
}
|
|
299
|
+
},
|
|
300
|
+
}
|
|
301
|
+
}
|
|
302
|
+
resp = requests.post(url, headers=headers, json=body, timeout=60)
|
|
303
|
+
resp.raise_for_status()
|
|
304
|
+
result = resp.json()
|
|
305
|
+
# Google returns the job name (e.g. "batches/abc123") — use as job_id
|
|
306
|
+
return result.get("name", result.get("id"))
|
|
307
|
+
|
|
308
|
+
elif provider == "mistral":
|
|
309
|
+
url = BATCH_ENDPOINTS["mistral"]["create"]
|
|
310
|
+
body = {
|
|
311
|
+
"input_files": [file_id],
|
|
312
|
+
"model": model,
|
|
313
|
+
"endpoint": "/v1/chat/completions",
|
|
314
|
+
}
|
|
315
|
+
resp = requests.post(url, headers=headers, json=body, timeout=60)
|
|
316
|
+
resp.raise_for_status()
|
|
317
|
+
return resp.json()["id"]
|
|
318
|
+
|
|
319
|
+
elif provider == "xai":
|
|
320
|
+
# Step 1: Create empty batch
|
|
321
|
+
url = BATCH_ENDPOINTS["xai"]["create"]
|
|
322
|
+
body = {"completion_window": "24h"}
|
|
323
|
+
resp = requests.post(url, headers=headers, json=body, timeout=60)
|
|
324
|
+
resp.raise_for_status()
|
|
325
|
+
job_id = resp.json()["id"]
|
|
326
|
+
|
|
327
|
+
# Step 2: Add all requests to the batch
|
|
328
|
+
add_url = BATCH_ENDPOINTS["xai"]["add"].format(job_id=job_id)
|
|
329
|
+
add_resp = requests.post(add_url, headers=headers, json=requests_list, timeout=120)
|
|
330
|
+
add_resp.raise_for_status()
|
|
331
|
+
return job_id
|
|
332
|
+
|
|
333
|
+
raise ValueError(f"Unsupported batch provider: {provider}")
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
# =============================================================================
|
|
337
|
+
# Polling
|
|
338
|
+
# =============================================================================
|
|
339
|
+
|
|
340
|
+
def _poll_backoff_seconds(attempt: int) -> int:
    """Backoff for the attempt-th consecutive transient failure: 60s doubling, capped at 300s."""
    return min(60 * (2 ** min(attempt, 4)), 300)


def _describe_batch_status(provider: str, status_data: dict) -> tuple:
    """Return ``(state, progress_str)`` extracted from one provider status response."""
    if provider == "openai":
        counts = status_data.get("request_counts") or {}
        return status_data.get("status", ""), (
            f"completed={counts.get('completed', '?')} "
            f"failed={counts.get('failed', '?')} "
            f"total={counts.get('total', '?')}"
        )
    if provider == "anthropic":
        counts = status_data.get("request_counts") or {}
        return status_data.get("processing_status", ""), (
            f"processing={counts.get('processing', '?')} "
            f"succeeded={counts.get('succeeded', '?')} "
            f"errored={counts.get('errored', '?')}"
        )
    if provider == "google":
        # State lives at metadata.state in the batchGenerateContent response.
        # Tolerate "metadata": null.
        state = (status_data.get("metadata") or {}).get("state", "") or status_data.get("state", "")
        return state, f"state={state}"
    if provider == "mistral":
        return status_data.get("status", ""), (
            f"succeeded={status_data.get('succeeded_requests', '?')} "
            f"failed={status_data.get('failed_requests', '?')} "
            f"total={status_data.get('total_requests', '?')}"
        )
    if provider == "xai":
        counts = status_data.get("request_counts") or {}
        return status_data.get("status", ""), (
            f"completed={counts.get('completed', '?')} "
            f"failed={counts.get('failed', '?')}"
        )
    return "", ""


def _poll_batch_job(
    provider: str,
    api_key: str,
    job_id: str,
    interval: float = 30.0,
    timeout: float = 86400.0,
) -> dict:
    """
    Poll the batch job until it reaches a terminal state.

    Prints a one-line status update each poll cycle. Transient failures
    (network errors, HTTP 429 and 5xx) are retried with exponential backoff
    (60s doubling, capped at 300s). Other HTTP errors, such as 401/403/404,
    are raised immediately because retrying them cannot succeed.

    Args:
        provider: Provider name (key of _TERMINAL_STATES).
        api_key: API key for the provider.
        job_id: ID returned by _create_batch_job (for Google, the batch name).
        interval: Seconds to wait between successful polls.
        timeout: Maximum total seconds to wait before giving up.

    Returns:
        The final status response dict (contains output file ID or job name for result retrieval).

    Raises:
        BatchJobExpiredError: If the job expired, timed out, or was cancelled.
        BatchJobFailedError: If the job terminated in any other failed state.
        TimeoutError: If timeout is reached before the job completes.
        requests.HTTPError: On a non-retryable HTTP error from the status endpoint.
    """
    headers = _get_batch_headers(provider, api_key)
    terminal = _TERMINAL_STATES[provider]
    success = _SUCCESS_STATES[provider]

    if provider == "google":
        status_url = BATCH_ENDPOINTS["google"]["status"].format(job_name=job_id)
    else:
        status_url = BATCH_ENDPOINTS[provider]["status"].format(job_id=job_id)

    start = time.time()
    attempt = 0  # consecutive transient failures; drives the backoff

    while True:
        elapsed = time.time() - start
        if elapsed >= timeout:
            raise TimeoutError(
                f"Batch job '{job_id}' did not complete within {timeout/3600:.1f}h. "
                f"Increase batch_timeout or switch to synchronous mode."
            )

        try:
            resp = requests.get(status_url, headers=headers, timeout=30)
        except requests.exceptions.RequestException as e:
            wait = _poll_backoff_seconds(attempt)
            print(f" [batch] Network error ({e}); retrying in {wait}s...")
            time.sleep(wait)
            attempt += 1
            continue

        if resp.status_code >= 500 or resp.status_code == 429:
            wait = _poll_backoff_seconds(attempt)
            label = "Rate limited" if resp.status_code == 429 else "Server error"
            print(f" [batch] {label} {resp.status_code}; retrying in {wait}s...")
            time.sleep(wait)
            attempt += 1
            continue

        # Any remaining 4xx (bad key, unknown job id, ...) is permanent.
        # Raise it now instead of retrying until `timeout` expires.
        resp.raise_for_status()

        attempt = 0
        status_data = resp.json()
        state, progress_str = _describe_batch_status(provider, status_data)
        print(f" [batch] {provider} | elapsed={elapsed:.0f}s | {progress_str} | state={state}")

        if state in terminal:
            if state in success:
                return status_data
            lowered = state.lower()
            if any(word in lowered for word in ("cancel", "timeout", "expire")):
                raise BatchJobExpiredError(
                    f"Batch job '{job_id}' expired/was cancelled (state: {state}). "
                    f"Job ID saved above — check provider dashboard for details."
                )
            raise BatchJobFailedError(
                f"Batch job '{job_id}' failed (state: {state}). "
                f"Check the provider dashboard for details."
            )

        time.sleep(interval)
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
# =============================================================================
|
|
460
|
+
# Result download
|
|
461
|
+
# =============================================================================
|
|
462
|
+
|
|
463
|
+
def _download_batch_results(
|
|
464
|
+
provider: str,
|
|
465
|
+
api_key: str,
|
|
466
|
+
job_id: str,
|
|
467
|
+
status_data: dict,
|
|
468
|
+
) -> str:
|
|
469
|
+
"""
|
|
470
|
+
Download completed batch results as raw JSONL text.
|
|
471
|
+
|
|
472
|
+
Args:
|
|
473
|
+
provider: Provider name
|
|
474
|
+
api_key: API key
|
|
475
|
+
job_id: Batch job ID
|
|
476
|
+
status_data: Final status dict from polling (contains output file references)
|
|
477
|
+
|
|
478
|
+
Returns:
|
|
479
|
+
Raw JSONL string (one JSON object per line)
|
|
480
|
+
"""
|
|
481
|
+
headers = _get_batch_headers(provider, api_key)
|
|
482
|
+
|
|
483
|
+
if provider == "openai":
|
|
484
|
+
output_file_id = status_data.get("output_file_id")
|
|
485
|
+
if not output_file_id:
|
|
486
|
+
raise RuntimeError("OpenAI batch: no output_file_id in completed status")
|
|
487
|
+
url = BATCH_ENDPOINTS["openai"]["results"].format(file_id=output_file_id)
|
|
488
|
+
headers_dl = dict(headers)
|
|
489
|
+
headers_dl.pop("Content-Type", None)
|
|
490
|
+
resp = requests.get(url, headers=headers_dl, timeout=120)
|
|
491
|
+
resp.raise_for_status()
|
|
492
|
+
return resp.text
|
|
493
|
+
|
|
494
|
+
elif provider == "anthropic":
|
|
495
|
+
url = BATCH_ENDPOINTS["anthropic"]["results"].format(job_id=job_id)
|
|
496
|
+
headers_dl = dict(headers)
|
|
497
|
+
headers_dl.pop("Content-Type", None)
|
|
498
|
+
resp = requests.get(url, headers=headers_dl, timeout=120, stream=True)
|
|
499
|
+
resp.raise_for_status()
|
|
500
|
+
return resp.text
|
|
501
|
+
|
|
502
|
+
elif provider == "google":
|
|
503
|
+
# Inline batch results live in the operation response:
|
|
504
|
+
# status_data["response"]["inlinedResponses"]["inlinedResponses"] → list of items
|
|
505
|
+
resp_outer = status_data.get("response", {})
|
|
506
|
+
inlined_wrapper = resp_outer.get("inlinedResponses", {})
|
|
507
|
+
if isinstance(inlined_wrapper, dict):
|
|
508
|
+
inlined = inlined_wrapper.get("inlinedResponses", [])
|
|
509
|
+
elif isinstance(inlined_wrapper, list):
|
|
510
|
+
inlined = inlined_wrapper
|
|
511
|
+
else:
|
|
512
|
+
inlined = []
|
|
513
|
+
if not inlined:
|
|
514
|
+
raise RuntimeError(
|
|
515
|
+
f"Google batch: no inlinedResponses in completed status. "
|
|
516
|
+
f"Status keys: {list(status_data.keys())}, "
|
|
517
|
+
f"response keys: {list(resp_outer.keys()) if isinstance(resp_outer, dict) else resp_outer}"
|
|
518
|
+
)
|
|
519
|
+
# Responses are NOT necessarily in order — use the metadata.key from each item,
|
|
520
|
+
# which preserves the original request key (e.g. "item-42") for correct mapping.
|
|
521
|
+
lines = [
|
|
522
|
+
json.dumps({"key": item.get("metadata", {}).get("key", f"item-{i}"), **item})
|
|
523
|
+
for i, item in enumerate(inlined)
|
|
524
|
+
]
|
|
525
|
+
return "\n".join(lines)
|
|
526
|
+
|
|
527
|
+
elif provider == "mistral":
|
|
528
|
+
output_file_id = status_data.get("output_file")
|
|
529
|
+
if not output_file_id:
|
|
530
|
+
raise RuntimeError("Mistral batch: no output_file in completed status")
|
|
531
|
+
url = BATCH_ENDPOINTS["mistral"]["results"].format(file_id=output_file_id)
|
|
532
|
+
headers_dl = dict(headers)
|
|
533
|
+
headers_dl.pop("Content-Type", None)
|
|
534
|
+
resp = requests.get(url, headers=headers_dl, timeout=120)
|
|
535
|
+
resp.raise_for_status()
|
|
536
|
+
return resp.text
|
|
537
|
+
|
|
538
|
+
elif provider == "xai":
|
|
539
|
+
url = BATCH_ENDPOINTS["xai"]["results"].format(job_id=job_id)
|
|
540
|
+
headers_dl = dict(headers)
|
|
541
|
+
headers_dl.pop("Content-Type", None)
|
|
542
|
+
resp = requests.get(url, headers=headers_dl, timeout=120)
|
|
543
|
+
resp.raise_for_status()
|
|
544
|
+
return resp.text
|
|
545
|
+
|
|
546
|
+
raise ValueError(f"Unsupported batch provider: {provider}")
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
# =============================================================================
|
|
550
|
+
# Result parsing
|
|
551
|
+
# =============================================================================
|
|
552
|
+
|
|
553
|
+
def _parse_batch_results(
|
|
554
|
+
provider: str,
|
|
555
|
+
raw_results: str,
|
|
556
|
+
custom_id_map: dict,
|
|
557
|
+
client: "UnifiedLLMClient",
|
|
558
|
+
parse_mode: str = "json",
|
|
559
|
+
) -> dict:
|
|
560
|
+
"""
|
|
561
|
+
Parse the downloaded JSONL results back into per-item strings.
|
|
562
|
+
|
|
563
|
+
Args:
|
|
564
|
+
provider: Provider name
|
|
565
|
+
raw_results: Raw JSONL text from the batch results download
|
|
566
|
+
custom_id_map: Dict mapping custom_id string → original item index
|
|
567
|
+
client: UnifiedLLMClient instance (used to call _parse_response())
|
|
568
|
+
parse_mode: "json" (default) runs extract_json() on responses,
|
|
569
|
+
"text" returns raw text as-is (for summarization)
|
|
570
|
+
|
|
571
|
+
Returns:
|
|
572
|
+
Dict mapping item_index → (result_str, error_or_None)
|
|
573
|
+
Missing items (job dropped them) get (None, "Missing from batch results").
|
|
574
|
+
"""
|
|
575
|
+
parsed_results = {}
|
|
576
|
+
|
|
577
|
+
for line in raw_results.splitlines():
|
|
578
|
+
line = line.strip()
|
|
579
|
+
if not line:
|
|
580
|
+
continue
|
|
581
|
+
try:
|
|
582
|
+
data = json.loads(line)
|
|
583
|
+
except json.JSONDecodeError:
|
|
584
|
+
continue
|
|
585
|
+
|
|
586
|
+
# Extract custom_id and the embedded response object
|
|
587
|
+
if provider == "openai":
|
|
588
|
+
custom_id = data.get("custom_id")
|
|
589
|
+
response_body = data.get("response", {}).get("body")
|
|
590
|
+
error_val = data.get("response", {}).get("error")
|
|
591
|
+
if error_val or response_body is None:
|
|
592
|
+
error_msg = str(error_val) if error_val else "No response body"
|
|
593
|
+
idx = custom_id_map.get(custom_id)
|
|
594
|
+
if idx is not None:
|
|
595
|
+
parsed_results[idx] = (None, error_msg)
|
|
596
|
+
continue
|
|
597
|
+
raw_text = client._parse_response(response_body)
|
|
598
|
+
|
|
599
|
+
elif provider == "anthropic":
|
|
600
|
+
custom_id = data.get("custom_id")
|
|
601
|
+
result = data.get("result", {})
|
|
602
|
+
if result.get("type") != "succeeded":
|
|
603
|
+
error_msg = str(result.get("error", "Request did not succeed"))
|
|
604
|
+
idx = custom_id_map.get(custom_id)
|
|
605
|
+
if idx is not None:
|
|
606
|
+
parsed_results[idx] = (None, error_msg)
|
|
607
|
+
continue
|
|
608
|
+
raw_text = client._parse_response(result.get("message", {}))
|
|
609
|
+
|
|
610
|
+
elif provider == "google":
|
|
611
|
+
# Output JSONL: {"key": "item-0", "response": {generateContent response}}
|
|
612
|
+
# or {"key": "item-0", "error": {"code": ..., "message": ...}}
|
|
613
|
+
custom_id = data.get("key")
|
|
614
|
+
error_val = data.get("error")
|
|
615
|
+
response_data = data.get("response")
|
|
616
|
+
if error_val or response_data is None:
|
|
617
|
+
error_msg = str(error_val) if error_val else "No response in batch output"
|
|
618
|
+
idx = custom_id_map.get(custom_id)
|
|
619
|
+
if idx is not None:
|
|
620
|
+
parsed_results[idx] = (None, error_msg)
|
|
621
|
+
continue
|
|
622
|
+
raw_text = client._parse_response(response_data)
|
|
623
|
+
|
|
624
|
+
elif provider == "mistral":
|
|
625
|
+
# Mistral batch output mirrors OpenAI: response.body holds the completion
|
|
626
|
+
custom_id = data.get("custom_id")
|
|
627
|
+
response_obj = data.get("response", {})
|
|
628
|
+
error_val = data.get("error") or response_obj.get("error")
|
|
629
|
+
if error_val:
|
|
630
|
+
idx = custom_id_map.get(custom_id)
|
|
631
|
+
if idx is not None:
|
|
632
|
+
parsed_results[idx] = (None, str(error_val))
|
|
633
|
+
continue
|
|
634
|
+
response_body = response_obj.get("body", response_obj)
|
|
635
|
+
raw_text = client._parse_response(response_body)
|
|
636
|
+
|
|
637
|
+
elif provider == "xai":
|
|
638
|
+
custom_id = data.get("custom_id")
|
|
639
|
+
response_body = data.get("response", {}).get("body")
|
|
640
|
+
error_val = data.get("response", {}).get("error")
|
|
641
|
+
if error_val or response_body is None:
|
|
642
|
+
error_msg = str(error_val) if error_val else "No response body"
|
|
643
|
+
idx = custom_id_map.get(custom_id)
|
|
644
|
+
if idx is not None:
|
|
645
|
+
parsed_results[idx] = (None, error_msg)
|
|
646
|
+
continue
|
|
647
|
+
raw_text = client._parse_response(response_body)
|
|
648
|
+
|
|
649
|
+
else:
|
|
650
|
+
continue
|
|
651
|
+
|
|
652
|
+
idx = custom_id_map.get(custom_id)
|
|
653
|
+
if idx is None:
|
|
654
|
+
continue
|
|
655
|
+
|
|
656
|
+
if parse_mode == "text":
|
|
657
|
+
parsed_results[idx] = (raw_text, None)
|
|
658
|
+
else:
|
|
659
|
+
json_str = extract_json(raw_text)
|
|
660
|
+
parsed_results[idx] = (json_str, None)
|
|
661
|
+
|
|
662
|
+
return parsed_results
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
# =============================================================================
|
|
666
|
+
# Per-model helpers (used by both single-model and ensemble batch paths)
|
|
667
|
+
# =============================================================================
|
|
668
|
+
|
|
669
|
+
def _run_one_batch_job(
    cfg: dict,
    items: list,
    prompt_params: dict,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
) -> dict:
    """
    Drive one model's classification batch job from request building to parsed results.

    Stages: build one request per item, upload a JSONL file (OpenAI / Mistral
    only), create the batch job, poll until it finishes, download the raw
    results, and parse them back into entries keyed by position in ``items``.

    Args:
        cfg: Model config; reads the "provider", "api_key" and "model" keys.
        items: Texts to classify. ``None`` is sent as an empty string; any
            other value is sent as ``str(item)``.
        prompt_params: Prompt-building options (see run_batch_classify).
            "categories_str" is required.
        batch_poll_interval: Seconds to wait between status checks.
        batch_timeout: Maximum seconds to wait for the job to finish.

    Returns:
        ``{item_index: (json_str_or_None, error_or_None)}``.
    """
    from .text_functions_ensemble import build_text_classification_prompt

    provider, api_key, model = cfg["provider"], cfg["api_key"], cfg["model"]

    # Prompt options that are identical for every item in this job.
    shared_prompt_args = {
        "categories_str": prompt_params["categories_str"],
        "survey_question_context": prompt_params.get("survey_question_context", ""),
        "examples_text": prompt_params.get("examples_text", ""),
        "chain_of_thought": prompt_params.get("chain_of_thought", False),
        "context_prompt": prompt_params.get("context_prompt", False),
        "step_back_prompt": prompt_params.get("step_back_prompt", False),
        "stepback_insights": prompt_params.get("stepback_insights", {}),
        "model_name": model,
    }
    json_schema = prompt_params.get("json_schema")
    creativity = prompt_params.get("creativity")
    thinking_budget = prompt_params.get("thinking_budget", 0)

    # These providers receive the request lines inline rather than via a file.
    sends_inline = provider in ("anthropic", "xai", "google")

    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)

    print(f"\n[batch] Building {len(items)} request(s) for {model} ({provider})...")

    # Step 1: one request line per item, tagged with a positional custom_id.
    custom_id_map = {}
    jsonl_lines = []
    for position, entry in enumerate(items):
        tag = f"item-{position}"
        custom_id_map[tag] = position

        messages = build_text_classification_prompt(
            response_text="" if entry is None else str(entry),
            multi_label=prompt_params.get("multi_label", True),
            **shared_prompt_args,
        )
        payload = client._build_payload(
            messages=messages,
            json_schema=json_schema,
            creativity=creativity,
            # Only a positive budget is forwarded; 0 / None disables it.
            thinking_budget=thinking_budget if thinking_budget and thinking_budget > 0 else None,
        )
        jsonl_lines.append(_build_jsonl_line(provider, tag, payload, model))

    requests_list = list(jsonl_lines) if sends_inline else None
    jsonl_bytes = b"\n".join(json.dumps(line).encode("utf-8") for line in jsonl_lines)

    # Step 2: file-based providers need the JSONL uploaded first.
    file_id = None
    if provider in ("openai", "mistral"):
        print(f"[batch] Uploading JSONL ({len(jsonl_bytes)/1024:.1f} KB) to {provider}...")
        file_id = _upload_jsonl(provider, api_key, jsonl_bytes)
        print(f"[batch] File uploaded: {file_id}")

    # Step 3: submit the job.
    print(f"[batch] Creating batch job for {model}...")
    job_id = _create_batch_job(
        provider=provider,
        api_key=api_key,
        model=model,
        file_id=file_id,
        requests_list=requests_list,
    )
    print(f"[batch] Job created: {job_id}")
    print(f"[batch] Polling every {batch_poll_interval}s (timeout={batch_timeout/3600:.1f}h)...")

    # Step 4: block until the provider reports completion.
    status_data = _poll_batch_job(
        provider=provider,
        api_key=api_key,
        job_id=job_id,
        interval=batch_poll_interval,
        timeout=batch_timeout,
    )
    print(f"[batch] Job complete for {model}.")

    # Step 5: fetch the raw result records.
    print(f"[batch] Downloading results for {model}...")
    raw_results = _download_batch_results(
        provider=provider,
        api_key=api_key,
        job_id=job_id,
        status_data=status_data,
    )

    # Step 6: map raw records back to item positions.
    return _parse_batch_results(
        provider=provider,
        raw_results=raw_results,
        custom_id_map=custom_id_map,
        client=client,
    )
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
def _run_one_sync_model(
    cfg: dict,
    items: list,
    prompt_params: dict,
) -> dict:
    """
    Classify all items synchronously for one model (fallback for unsupported batch providers).

    Each item is sent through ``client.complete``. The result is accepted
    either as a bare response string or as a ``(response_text, error)``
    tuple (the shape the summarization fallback in this module unpacks). A
    provider-reported error is recorded for the item instead of being fed to
    ``extract_json``. Any exception raised for an item is also recorded as
    that item's error, so one bad item does not abort the rest.

    Args:
        cfg: Model config; reads the "provider", "api_key" and "model" keys.
        items: Texts to classify. ``None`` is sent as an empty string.
        prompt_params: Prompt-building options (see run_batch_classify).
            "categories_str" is required.

    Returns:
        ``{item_index: (json_str_or_None, error_or_None)}``.
    """
    from .text_functions_ensemble import build_text_classification_prompt

    provider = cfg["provider"]
    api_key = cfg["api_key"]
    model = cfg["model"]
    json_schema = prompt_params.get("json_schema")
    creativity = prompt_params.get("creativity")
    thinking_budget = prompt_params.get("thinking_budget", 0)
    # Only a positive budget is forwarded; 0 / None disables it.
    effective_budget = thinking_budget if thinking_budget and thinking_budget > 0 else None

    # Prompt options are loop-invariant; build them once.
    prompt_kwargs = {
        "categories_str": prompt_params["categories_str"],
        "survey_question_context": prompt_params.get("survey_question_context", ""),
        "examples_text": prompt_params.get("examples_text", ""),
        "chain_of_thought": prompt_params.get("chain_of_thought", False),
        "context_prompt": prompt_params.get("context_prompt", False),
        "step_back_prompt": prompt_params.get("step_back_prompt", False),
        "stepback_insights": prompt_params.get("stepback_insights", {}),
        "model_name": model,
        "multi_label": prompt_params.get("multi_label", True),
    }

    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)
    item_results = {}

    print(f"\n[batch] Synchronous fallback for {model} ({provider}): {len(items)} item(s)...")

    for idx, item in enumerate(items):
        messages = build_text_classification_prompt(
            response_text=str(item) if item is not None else "",
            **prompt_kwargs,
        )
        try:
            result = client.complete(
                messages=messages,
                json_schema=json_schema,
                creativity=creativity,
                thinking_budget=effective_budget,
            )
            # Tolerate both return shapes: (text, error) tuple or bare text.
            if isinstance(result, tuple):
                raw, err = result
            else:
                raw, err = result, None

            if err:
                item_results[idx] = (None, str(err))
            else:
                item_results[idx] = (extract_json(raw), None)
        except Exception as e:
            # Best-effort per item: record the failure and keep going.
            item_results[idx] = (None, str(e))

    return item_results
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
# =============================================================================
|
|
835
|
+
# Main entry point (single-model)
|
|
836
|
+
# =============================================================================
|
|
837
|
+
|
|
838
|
+
def run_batch_classify(
    items: list,
    cfg: dict,
    categories: list,
    prompt_params: dict,
    filename: str = None,
    save_directory: str = None,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
    fail_strategy: str = "partial",
) -> "pd.DataFrame":
    """
    Run batch classification for a single model against a list of text items.

    This is the main entry point called from classify() when batch_mode=True.
    Returns a DataFrame in the same format as the synchronous single-model path.

    Args:
        items: List of text strings to classify
        cfg: Model config dict from prepare_model_configs() (single entry)
        categories: List of category names
        prompt_params: Dict containing prompt-building parameters:
            - categories_str (str)
            - survey_question_context (str)
            - examples_text (str)
            - chain_of_thought (bool)
            - context_prompt (bool)
            - step_back_prompt (bool)
            - stepback_insights (dict)
            - json_schema (dict or None)
            - creativity (float or None)
            - thinking_budget (int)
        filename: Optional CSV filename to save results
        save_directory: Optional directory to save results
        batch_poll_interval: Seconds between poll checks (default 30)
        batch_timeout: Max seconds to wait for job (default 86400 = 24h)
        fail_strategy: "partial" or "strict"

    Returns:
        pd.DataFrame with category_1, category_2, ... columns (same as sync path).
        Items that are None or NaN are flagged as skipped and reported with an
        empty response, matching run_batch_ensemble_classify.
    """
    import math

    from .text_functions_ensemble import aggregate_results, build_output_dataframes

    # Read the result key before submitting: a malformed config should fail
    # immediately, not after a batch job that may take hours.
    model_name = cfg["sanitized_name"]

    # =========================================================================
    # Steps 1-6: Submit, poll, download, and parse the batch job
    # =========================================================================
    item_results = _run_one_batch_job(
        cfg=cfg,
        items=items,
        prompt_params=prompt_params,
        batch_poll_interval=batch_poll_interval,
        batch_timeout=batch_timeout,
    )

    # =========================================================================
    # Step 7: Build all_results list in the format aggregate_results expects
    # =========================================================================
    all_results = []
    for idx, item in enumerate(items):
        json_str, error = item_results.get(idx, (None, "Missing from batch results"))

        model_results = {model_name: (json_str, error)}
        aggregated = aggregate_results(
            model_results=model_results,
            categories=categories,
            consensus_threshold="unanimous",
            fail_strategy=fail_strategy,
        )

        skipped = item is None or (isinstance(item, float) and math.isnan(item))
        all_results.append({
            "response": "" if skipped else str(item),
            "model_results": model_results,
            "aggregated": aggregated,
            "skipped": skipped,
        })

    # =========================================================================
    # Step 8: Build output DataFrame (reuses existing pipeline)
    # =========================================================================
    return build_output_dataframes(
        all_results=all_results,
        model_configs=[cfg],
        categories=categories,
        filename=filename,
        save_directory=save_directory,
    )
|
|
926
|
+
|
|
927
|
+
|
|
928
|
+
# =============================================================================
|
|
929
|
+
# Ensemble batch entry point
|
|
930
|
+
# =============================================================================
|
|
931
|
+
|
|
932
|
+
def run_batch_ensemble_classify(
    items: list,
    model_configs: list,
    categories: list,
    prompt_params_per_model: dict,
    consensus_threshold,
    fail_strategy: str = "partial",
    filename: str = None,
    save_directory: str = None,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
) -> "pd.DataFrame":
    """
    Classify items with several models at once and merge them into a consensus.

    Every model runs in its own worker thread. Models whose provider is in
    UNSUPPORTED_BATCH_PROVIDERS are classified item-by-item via the
    synchronous fallback; all others go through a provider batch job.
    Per-item outputs from every model are then combined with
    aggregate_results().

    Args:
        items: Texts to classify.
        model_configs: Model config dicts from prepare_model_configs().
        categories: Category names.
        prompt_params_per_model: Maps each config's "model" value to its
            prompt_params dict.
        consensus_threshold: Agreement threshold passed to aggregate_results.
        fail_strategy: "partial" or "strict".
        filename: Optional CSV filename to save results.
        save_directory: Optional directory to save results.
        batch_poll_interval: Seconds between poll checks (default 30).
        batch_timeout: Max seconds to wait for a job (default 86400 = 24h).

    Returns:
        The DataFrame built by build_output_dataframes() (consensus columns
        plus per-model columns, as in the synchronous ensemble path).
    """
    import math
    from concurrent.futures import ThreadPoolExecutor, as_completed

    from .text_functions_ensemble import aggregate_results, build_output_dataframes

    def _needs_sync(config):
        return config["provider"] in UNSUPPORTED_BATCH_PROVIDERS

    batch_cfgs = [c for c in model_configs if not _needs_sync(c)]
    sync_cfgs = [c for c in model_configs if _needs_sync(c)]

    if batch_cfgs:
        batch_names = ", ".join(c["model"] for c in batch_cfgs)
        print(f"\n[batch ensemble] {len(batch_cfgs)} model(s) will use batch API: {batch_names}")
    if sync_cfgs:
        sync_names = ", ".join(c["model"] for c in sync_cfgs)
        print(f"[batch ensemble] {len(sync_cfgs)} model(s) will use synchronous fallback: {sync_names}")

    def _classify_with(config):
        key = config["sanitized_name"]
        params = prompt_params_per_model[config["model"]]
        if _needs_sync(config):
            outcome = _run_one_sync_model(config, items, params)
        else:
            outcome = _run_one_batch_job(config, items, params, batch_poll_interval, batch_timeout)
        return key, outcome

    # Run every model concurrently; collect results as each finishes.
    results_by_model = {}
    with ThreadPoolExecutor(max_workers=len(model_configs)) as pool:
        pending = [pool.submit(_classify_with, config) for config in model_configs]
        for done in as_completed(pending):
            key, outcome = done.result()
            results_by_model[key] = outcome

    missing = (None, "Missing from batch results")
    all_results = []
    for idx, item in enumerate(items):
        model_results = {}
        for config in model_configs:
            key = config["sanitized_name"]
            model_results[key] = results_by_model[key].get(idx, missing)

        aggregated = aggregate_results(
            model_results=model_results,
            categories=categories,
            consensus_threshold=consensus_threshold,
            fail_strategy=fail_strategy,
        )
        # None and float NaN inputs are reported as skipped with empty text.
        skipped = item is None or (isinstance(item, float) and math.isnan(item))
        all_results.append({
            "response": "" if skipped else str(item),
            "model_results": model_results,
            "aggregated": aggregated,
            "skipped": skipped,
        })

    return build_output_dataframes(all_results, model_configs, categories, filename, save_directory)
|
|
1024
|
+
|
|
1025
|
+
|
|
1026
|
+
# =============================================================================
|
|
1027
|
+
# Batch summarization
|
|
1028
|
+
# =============================================================================
|
|
1029
|
+
|
|
1030
|
+
def _run_one_batch_summarize_job(
    cfg: dict,
    items: list,
    prompt_params: dict,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
) -> dict:
    """
    Drive one model's summarization batch job from request building to parsed results.

    Every request carries the JSON schema from build_summary_json_schema().
    Results are parsed with ``parse_mode="json"``, so each successful entry
    holds the JSON string extracted from the model output (not the bare
    summary text). Callers pull the summary out with
    extract_summary_from_json().

    Args:
        cfg: Model config; reads the "provider", "api_key" and "model" keys.
        items: Texts to summarize. ``None`` is sent as an empty string.
        prompt_params: Summarization prompt options (input_description,
            summary_instructions, max_length, focus, chain_of_thought,
            context_prompt, step_back_prompt, stepback_insights, creativity).
        batch_poll_interval: Seconds to wait between status checks.
        batch_timeout: Maximum seconds to wait for the job to finish.

    Returns:
        ``{item_index: (json_str_or_None, error_or_None)}``.
    """
    from .text_functions_ensemble import build_text_summarization_prompt, build_summary_json_schema

    provider, api_key, model = cfg["provider"], cfg["api_key"], cfg["model"]

    # Prompt options that are identical for every item in this job.
    shared_prompt_args = {
        "input_description": prompt_params.get("input_description", ""),
        "summary_instructions": prompt_params.get("summary_instructions", ""),
        "max_length": prompt_params.get("max_length"),
        "focus": prompt_params.get("focus"),
        "chain_of_thought": prompt_params.get("chain_of_thought", False),
        "context_prompt": prompt_params.get("context_prompt", False),
        "step_back_prompt": prompt_params.get("step_back_prompt", False),
        "stepback_insights": prompt_params.get("stepback_insights", {}),
        "model_name": model,
    }
    creativity = prompt_params.get("creativity")

    # Google gets the schema built with include_additional=False; all other
    # providers get True.
    json_schema = build_summary_json_schema(provider != "google")

    # These providers receive the request lines inline rather than via a file.
    sends_inline = provider in ("anthropic", "xai", "google")

    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)

    print(f"\n[batch] Building {len(items)} summarization request(s) for {model} ({provider})...")

    custom_id_map = {}
    jsonl_lines = []
    for position, entry in enumerate(items):
        tag = f"item-{position}"
        custom_id_map[tag] = position

        messages = build_text_summarization_prompt(
            response_text="" if entry is None else str(entry),
            **shared_prompt_args,
        )
        payload = client._build_payload(
            messages=messages,
            json_schema=json_schema,
            creativity=creativity,
        )
        jsonl_lines.append(_build_jsonl_line(provider, tag, payload, model))

    requests_list = list(jsonl_lines) if sends_inline else None
    jsonl_bytes = b"\n".join(json.dumps(line).encode("utf-8") for line in jsonl_lines)

    # File-based providers need the JSONL uploaded first.
    file_id = None
    if provider in ("openai", "mistral"):
        print(f"[batch] Uploading JSONL ({len(jsonl_bytes)/1024:.1f} KB) to {provider}...")
        file_id = _upload_jsonl(provider, api_key, jsonl_bytes)
        print(f"[batch] File uploaded: {file_id}")

    print(f"[batch] Creating batch job for {model}...")
    job_id = _create_batch_job(
        provider=provider,
        api_key=api_key,
        model=model,
        file_id=file_id,
        requests_list=requests_list,
    )
    print(f"[batch] Job created: {job_id}")
    print(f"[batch] Polling every {batch_poll_interval}s (timeout={batch_timeout/3600:.1f}h)...")

    status_data = _poll_batch_job(
        provider=provider,
        api_key=api_key,
        job_id=job_id,
        interval=batch_poll_interval,
        timeout=batch_timeout,
    )
    print(f"[batch] Job complete for {model}.")

    print(f"[batch] Downloading results for {model}...")
    raw_results = _download_batch_results(
        provider=provider,
        api_key=api_key,
        job_id=job_id,
        status_data=status_data,
    )

    # JSON mode: each output is expected to be a {"summary": ...} object, so
    # run extract_json on it rather than keeping the raw text.
    return _parse_batch_results(
        provider=provider,
        raw_results=raw_results,
        custom_id_map=custom_id_map,
        client=client,
        parse_mode="json",
    )
|
|
1147
|
+
|
|
1148
|
+
|
|
1149
|
+
def _run_one_sync_summarize_model(
    cfg: dict,
    items: list,
    prompt_params: dict,
) -> dict:
    """
    Summarize all items synchronously for one model (fallback for unsupported batch providers).

    Each item is sent through ``client.complete`` with the summary JSON
    schema. The result is accepted as a ``(response_text, error)`` tuple or a
    bare response string. A provider-reported error is recorded for the item
    rather than ignored. Previously it was discarded and ``extract_json`` was
    run on a missing response. Exceptions raised for an item are recorded as
    that item's error, so the remaining items still run.

    Args:
        cfg: Model config; reads the "provider", "api_key" and "model" keys.
        items: Texts to summarize. ``None`` is sent as an empty string.
        prompt_params: Summarization prompt options (see
            _run_one_batch_summarize_job).

    Returns:
        ``{item_index: (json_str_or_None, error_or_None)}``.
    """
    from .text_functions_ensemble import build_text_summarization_prompt, build_summary_json_schema

    provider = cfg["provider"]
    api_key = cfg["api_key"]
    model = cfg["model"]
    creativity = prompt_params.get("creativity")

    include_additional = provider != "google"
    json_schema = build_summary_json_schema(include_additional)

    # Prompt options are loop-invariant; build them once.
    prompt_kwargs = {
        "input_description": prompt_params.get("input_description", ""),
        "summary_instructions": prompt_params.get("summary_instructions", ""),
        "max_length": prompt_params.get("max_length"),
        "focus": prompt_params.get("focus"),
        "chain_of_thought": prompt_params.get("chain_of_thought", False),
        "context_prompt": prompt_params.get("context_prompt", False),
        "step_back_prompt": prompt_params.get("step_back_prompt", False),
        "stepback_insights": prompt_params.get("stepback_insights", {}),
        "model_name": model,
    }

    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)
    item_results = {}

    print(f"\n[batch] Synchronous fallback for {model} ({provider}): {len(items)} item(s)...")

    for idx, item in enumerate(items):
        messages = build_text_summarization_prompt(
            response_text=str(item) if item is not None else "",
            **prompt_kwargs,
        )
        try:
            result = client.complete(
                messages=messages,
                json_schema=json_schema,
                creativity=creativity,
            )
            # Tolerate both return shapes: (text, error) tuple or bare text.
            if isinstance(result, tuple):
                raw, err = result
            else:
                raw, err = result, None

            if err:
                item_results[idx] = (None, str(err))
            else:
                item_results[idx] = (extract_json(raw), None)
        except Exception as e:
            # Best-effort per item: record the failure and keep going.
            item_results[idx] = (None, str(e))

    return item_results
|
|
1197
|
+
|
|
1198
|
+
|
|
1199
|
+
def run_batch_summarize(
    items: list,
    cfg: dict,
    prompt_params: dict,
    filename: str = None,
    save_directory: str = None,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
    fail_strategy: str = "partial",
) -> "pd.DataFrame":
    """
    Run batch summarization for a single model.

    Args:
        items: Texts to summarize. None / NaN entries become "skipped" rows.
        cfg: Model config dict (single entry).
        prompt_params: Summarization prompt options.
        filename: Optional CSV filename to save results.
        save_directory: Optional directory for the CSV (created if missing).
        batch_poll_interval: Seconds between poll checks (default 30).
        batch_timeout: Max seconds to wait for the job (default 86400 = 24h).
        fail_strategy: Accepted for signature parity with the ensemble path.
            With a single model a failed item has nothing to fall back on, so
            it produces an "error" row under either strategy.

    Returns:
        A DataFrame with input_data, summary, processing_status columns.
        processing_status is one of "success", "error", "skipped".
    """
    import math
    import os

    import pandas as pd
    from .text_functions_ensemble import extract_summary_from_json

    item_results = _run_one_batch_summarize_job(
        cfg=cfg,
        items=items,
        prompt_params=prompt_params,
        batch_poll_interval=batch_poll_interval,
        batch_timeout=batch_timeout,
    )

    rows = []
    for idx, item in enumerate(items):
        json_str, error = item_results.get(idx, (None, "Missing from batch results"))
        text = str(item) if item is not None else ""

        summary = ""
        if item is None or (isinstance(item, float) and math.isnan(item)):
            status = "skipped"
        elif error:
            status = "error"
        else:
            is_valid, summary_text = extract_summary_from_json(json_str)
            if is_valid and summary_text:
                status, summary = "success", summary_text
            else:
                status = "error"

        rows.append({"input_data": text, "summary": summary, "processing_status": status})

    df = pd.DataFrame(rows)

    if filename:
        save_path = os.path.join(save_directory, filename) if save_directory else filename
        if save_directory:
            os.makedirs(save_directory, exist_ok=True)
        df.to_csv(save_path, index=False)
        print(f"\nResults saved to: {save_path}")

    return df
|
|
1262
|
+
|
|
1263
|
+
|
|
1264
|
+
def run_batch_ensemble_summarize(
    items: list,
    model_configs: list,
    prompt_params_per_model: dict,
    fail_strategy: str = "partial",
    filename: str = None,
    save_directory: str = None,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
    max_retries: int = 3,
) -> "pd.DataFrame":
    """
    Summarize items with several models at once, then synthesize one summary per item.

    Each model runs in its own worker thread, using a batch job or (for
    providers in UNSUPPORTED_BATCH_PROVIDERS) the synchronous fallback. For
    every item, the non-empty per-model summaries are passed to
    _synthesize_summaries() using the first model config.

    Output columns:
      - input_data
      - summary
      - summary_<model> for each model
      - failed_models (comma-separated)
      - processing_status: "success", "partial", "error" or "skipped"

    With fail_strategy="strict", any model failure blanks every per-model
    summary for that item.
    """
    import math
    import pandas as pd
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from .text_functions_ensemble import extract_summary_from_json, _synthesize_summaries

    def _needs_sync(config):
        return config["provider"] in UNSUPPORTED_BATCH_PROVIDERS

    batch_cfgs = [c for c in model_configs if not _needs_sync(c)]
    sync_cfgs = [c for c in model_configs if _needs_sync(c)]

    if batch_cfgs:
        batch_names = ", ".join(c["model"] for c in batch_cfgs)
        print(f"\n[batch ensemble] {len(batch_cfgs)} model(s) will use batch API: {batch_names}")
    if sync_cfgs:
        sync_names = ", ".join(c["model"] for c in sync_cfgs)
        print(f"[batch ensemble] {len(sync_cfgs)} model(s) will use synchronous fallback: {sync_names}")

    def _summarize_with(config):
        key = config["sanitized_name"]
        params = prompt_params_per_model[config["model"]]
        if _needs_sync(config):
            outcome = _run_one_sync_summarize_model(config, items, params)
        else:
            outcome = _run_one_batch_summarize_job(config, items, params, batch_poll_interval, batch_timeout)
        return key, outcome

    # Run every model concurrently; collect results as each finishes.
    per_model = {}
    with ThreadPoolExecutor(max_workers=len(model_configs)) as pool:
        pending = [pool.submit(_summarize_with, config) for config in model_configs]
        for done in as_completed(pending):
            key, outcome = done.result()
            per_model[key] = outcome

    model_names = [config["sanitized_name"] for config in model_configs]
    missing = (None, "Missing from batch results")

    rows = []
    for idx, item in enumerate(items):
        text = "" if item is None else str(item)

        if item is None or (isinstance(item, float) and math.isnan(item)):
            skipped_row = {"input_data": text, "summary": "", "processing_status": "skipped"}
            skipped_row.update({f"summary_{mn}": "" for mn in model_names})
            skipped_row["failed_models"] = ""
            rows.append(skipped_row)
            continue

        summaries = {}
        failed = []
        for config in model_configs:
            mn = config["sanitized_name"]
            json_str, error = per_model[mn].get(idx, missing)
            if error:
                summaries[mn] = ""
                failed.append(mn)
                continue
            is_valid, summary_text = extract_summary_from_json(json_str)
            summaries[mn] = summary_text if is_valid else ""
            if not (is_valid and summary_text):
                failed.append(mn)

        # Strict mode: a single failure blanks every model's summary.
        if fail_strategy == "strict" and failed:
            summaries = dict.fromkeys(summaries, "")

        row = {"input_data": text}
        row.update({f"summary_{mn}": summaries.get(mn, "") for mn in model_names})

        usable = {name: s for name, s in summaries.items() if s}
        if usable:
            row["summary"] = _synthesize_summaries(
                summaries=usable,
                original_text=text,
                synthesis_config=model_configs[0],
                max_retries=max_retries,
            )
        else:
            row["summary"] = ""

        row["failed_models"] = ",".join(failed)

        present = [bool(s) for s in summaries.values()]
        if not any(present):
            row["processing_status"] = "error"
        elif all(present):
            row["processing_status"] = "success"
        else:
            row["processing_status"] = "partial"

        rows.append(row)

    df = pd.DataFrame(rows)

    if filename:
        import os
        if save_directory:
            os.makedirs(save_directory, exist_ok=True)
            save_path = os.path.join(save_directory, filename)
        else:
            save_path = filename
        df.to_csv(save_path, index=False)
        print(f"\nResults saved to: {save_path}")

    return df
|