cat-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cat_stack/_batch.py ADDED
@@ -0,0 +1,1388 @@
1
+ """
2
+ Async batch inference for CatLLM.
3
+
4
+ Supports OpenAI, Anthropic, Google, Mistral, and xAI — all offer 50% cost
5
+ savings and higher rate limits compared to synchronous API calls.
6
+
7
+ All five providers follow the same conceptual pattern:
8
+ 1. Package all requests as JSONL
9
+ 2. Submit a batch job (file upload → job creation, or inline for Anthropic)
10
+ 3. Poll until the job reaches a terminal state
11
+ 4. Download and parse results
12
+ 5. Return a DataFrame identical in format to the synchronous single-model path
13
+
14
+ Not supported: HuggingFace, Perplexity, Ollama (no batch API).
15
+ Ensemble mode: supported. Each model submits its own batch job concurrently.
16
+ Providers without batch API (HuggingFace, Perplexity, Ollama) fall back to
17
+ synchronous calls and are merged in with the batch results.
18
+ Not compatible: PDF/image input (text only).
19
+ """
20
+
21
+ import io
22
+ import json
23
+ import os
24
+ import time
25
+
26
+ import requests
27
+
28
+ from ._providers import UnifiedLLMClient
29
+ from ._utils import extract_json
30
+
31
+ # =============================================================================
32
+ # Constants
33
+ # =============================================================================
34
+
35
# Base URLs shared by each provider's batch endpoints.
_OPENAI_API = "https://api.openai.com/v1"
_ANTHROPIC_BATCHES = "https://api.anthropic.com/v1/messages/batches"
_GOOGLE_HOST = "https://generativelanguage.googleapis.com"
_MISTRAL_API = "https://api.mistral.ai/v1"
_XAI_BATCHES = "https://api.x.ai/v1/batches"

# Per-provider URL templates. Placeholders ({job_id}, {file_id}, {model},
# {job_name}, {file_name}) are filled in with str.format() at call time.
BATCH_ENDPOINTS = {
    "openai": {
        "upload": _OPENAI_API + "/files",
        "create": _OPENAI_API + "/batches",
        "status": _OPENAI_API + "/batches/{job_id}",
        "results": _OPENAI_API + "/files/{file_id}/content",
    },
    "anthropic": {
        # No "upload" entry: requests are sent inline at job creation.
        "create": _ANTHROPIC_BATCHES,
        "status": _ANTHROPIC_BATCHES + "/{job_id}",
        "results": _ANTHROPIC_BATCHES + "/{job_id}/results",
    },
    "google": {
        "upload": _GOOGLE_HOST + "/upload/v1beta/files",
        "create": _GOOGLE_HOST + "/v1beta/models/{model}:batchGenerateContent",
        "status": _GOOGLE_HOST + "/v1beta/{job_name}",
        "download": _GOOGLE_HOST + "/download/v1beta/{file_name}:download",
    },
    "mistral": {
        "upload": _MISTRAL_API + "/files",
        "create": _MISTRAL_API + "/batch/jobs",
        "status": _MISTRAL_API + "/batch/jobs/{job_id}",
        "results": _MISTRAL_API + "/files/{file_id}/content",
    },
    "xai": {
        "create": _XAI_BATCHES,
        "add": _XAI_BATCHES + "/{job_id}/requests",
        "status": _XAI_BATCHES + "/{job_id}",
        "results": _XAI_BATCHES + "/{job_id}/results",
    },
}
67
+
68
# Providers with no batch API; callers fall back to synchronous requests.
UNSUPPORTED_BATCH_PROVIDERS = {"huggingface", "huggingface-together", "perplexity", "ollama"}

# Terminal states per provider: once a job reports one of these, polling stops.
# Every state a provider can end in must be listed here, otherwise a job first
# observed in an unlisted final state would be polled until the timeout.
_TERMINAL_STATES = {
    "openai": {"completed", "failed", "expired", "cancelled"},
    "anthropic": {"ended"},
    "google": {
        "BATCH_STATE_SUCCEEDED", "BATCH_STATE_FAILED", "BATCH_STATE_CANCELLED", "BATCH_STATE_EXPIRED",
        "JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED", "JOB_STATE_EXPIRED",
    },
    # CANCELLATION_REQUESTED is treated as terminal (the job will not produce
    # results); CANCELLED is the state that follows it.
    "mistral": {"SUCCESS", "FAILED", "TIMEOUT_EXCEEDED", "CANCELLATION_REQUESTED", "CANCELLED"},
    "xai": {"completed", "failed", "expired", "cancelled"},
}

# Subset of terminal states that mean results are available for download.
_SUCCESS_STATES = {
    "openai": {"completed"},
    "anthropic": {"ended"},  # must check request_counts.failed separately
    "google": {"BATCH_STATE_SUCCEEDED", "JOB_STATE_SUCCEEDED"},
    "mistral": {"SUCCESS"},
    "xai": {"completed"},
}
89
+
90
+
91
+ # =============================================================================
92
+ # Exceptions
93
+ # =============================================================================
94
+
95
class BatchJobExpiredError(RuntimeError):
    """Signal that a batch job expired before it could complete."""
98
+
99
+
100
class BatchJobFailedError(RuntimeError):
    """Signal that a batch job reached a terminal state that is a failure."""
103
+
104
+
105
+ # =============================================================================
106
+ # Auth headers
107
+ # =============================================================================
108
+
109
def _get_batch_headers(provider: str, api_key: str) -> dict:
    """
    Build the HTTP headers used to call a provider's batch API.

    Args:
        provider: Provider name ("openai", "anthropic", "google", "mistral", "xai")
        api_key: API key placed in the provider-specific auth header

    Returns:
        A new dict with a JSON Content-Type plus the provider's auth headers.
        Unrecognized providers get only the Content-Type header.
    """
    result = {"Content-Type": "application/json"}

    if provider in ("openai", "mistral", "xai"):
        # These three all authenticate with a standard Bearer token.
        result["Authorization"] = f"Bearer {api_key}"
    elif provider == "anthropic":
        result["x-api-key"] = api_key
        result["anthropic-version"] = "2023-06-01"
        result["anthropic-beta"] = "message-batches-2024-09-24"
    elif provider == "google":
        result["x-goog-api-key"] = api_key

    return result
125
+
126
+
127
+ # =============================================================================
128
+ # JSONL line builders
129
+ # =============================================================================
130
+
131
def _build_jsonl_line(provider: str, custom_id: str, payload: dict, model: str) -> dict:
    """
    Wrap one request payload in the batch envelope a provider expects.

    Args:
        provider: Provider name
        custom_id: Identifier used to match the result back to its item
        payload: Request body, embedded as-is (not copied)
        model: Model name; accepted for signature symmetry, not used here

    Returns:
        A dict that will be serialized as one JSONL line (or sent inline).

    Raises:
        ValueError: If the provider has no batch envelope defined.
    """
    # OpenAI, Mistral, and xAI share the OpenAI-compatible envelope. (xAI
    # requests are added after batch creation, but use the same shape.)
    if provider in ("openai", "mistral", "xai"):
        return {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": payload,
        }

    # Anthropic nests the request under "params"; no file upload is needed.
    if provider == "anthropic":
        return {"custom_id": custom_id, "params": payload}

    # Google carries the identifier in metadata.key next to the request.
    if provider == "google":
        return {"request": payload, "metadata": {"key": custom_id}}

    raise ValueError(f"Unsupported batch provider: {provider}")
172
+
173
+
174
+ # =============================================================================
175
+ # File upload (OpenAI, Google, Mistral)
176
+ # =============================================================================
177
+
178
def _upload_jsonl(provider: str, api_key: str, jsonl_bytes: bytes, filename: str = "batch_requests.jsonl") -> str:
    """
    Send the batch JSONL payload to a provider's files API.

    OpenAI and Mistral take a single multipart POST with purpose "batch".
    Google uses a two-step resumable upload: a "start" request that returns a
    session URL, then an "upload, finalize" request carrying the bytes.

    Args:
        provider: "openai", "mistral", or "google"
        api_key: API key for the provider
        jsonl_bytes: Serialized JSONL request file
        filename: File name reported to the provider

    Returns:
        Identifier of the uploaded file, used when creating the batch job
        (for Google: the file's "name", falling back to its "uri").

    Raises:
        ValueError: If the provider does not use file upload for batch.
        RuntimeError: If Google returns no session URL or no file name.
        requests.HTTPError: If an upload request returns an error status.
    """
    # Multipart requests set their own Content-Type (with boundary), so the
    # JSON one from the shared header builder is left out.
    auth_headers = {
        name: value
        for name, value in _get_batch_headers(provider, api_key).items()
        if name != "Content-Type"
    }

    # MIME type attached to the file part, per multipart provider.
    multipart_mime = {
        "openai": "application/jsonl",
        "mistral": "application/octet-stream",
    }

    if provider in multipart_mime:
        response = requests.post(
            BATCH_ENDPOINTS[provider]["upload"],
            headers=auth_headers,
            files={"file": (filename, io.BytesIO(jsonl_bytes), multipart_mime[provider])},
            data={"purpose": "batch"},
            timeout=120,
        )
        response.raise_for_status()
        return response.json()["id"]

    if provider != "google":
        raise ValueError(f"Provider '{provider}' does not use file upload for batch")

    # Google step 1: open a resumable upload session.
    start_response = requests.post(
        BATCH_ENDPOINTS["google"]["upload"],
        headers={
            "x-goog-api-key": api_key,
            "X-Goog-Upload-Protocol": "resumable",
            "X-Goog-Upload-Command": "start",
            "X-Goog-Upload-Header-Content-Type": "application/jsonl",
            "X-Goog-Upload-Header-Content-Length": str(len(jsonl_bytes)),
            "Content-Type": "application/json",
        },
        data=json.dumps({"file": {"display_name": filename}}),
        timeout=60,
    )
    start_response.raise_for_status()
    session_url = start_response.headers.get("X-Goog-Upload-URL")
    if not session_url:
        raise RuntimeError("Google file upload: no session URL returned")

    # Google step 2: send all bytes in one chunk and finalize the session.
    finalize_response = requests.post(
        session_url,
        headers={
            "X-Goog-Upload-Command": "upload, finalize",
            "X-Goog-Upload-Offset": "0",
            "Content-Type": "application/jsonl",
        },
        data=jsonl_bytes,
        timeout=120,
    )
    finalize_response.raise_for_status()
    file_info = finalize_response.json()

    # The metadata is normally wrapped: {"file": {"name": "files/abc", ...}};
    # tolerate an unwrapped response as well.
    file_obj = file_info.get("file", file_info)
    file_name = file_obj.get("name") or file_obj.get("uri")
    if not file_name:
        raise RuntimeError(f"Google file upload: could not extract file name. Response: {file_info}")
    return file_name
240
+
241
+
242
+ # =============================================================================
243
+ # Batch job creation
244
+ # =============================================================================
245
+
246
def _create_batch_job(
    provider: str,
    api_key: str,
    model: str,
    file_id: str = None,
    requests_list: list = None,
) -> str:
    """
    Create a batch job and return the job ID.

    Args:
        provider: Provider name
        api_key: API key
        model: Model name (used in the Google URL and the Mistral job body;
            ignored by the other providers)
        file_id: Uploaded file ID (OpenAI and Mistral)
        requests_list: Inline request envelopes (Anthropic, Google, and xAI)

    Returns:
        job_id string for polling (for Google, the job name such as
        "batches/abc123")

    Raises:
        ValueError: If the provider is not a supported batch provider.
        RuntimeError: If Google's response contains no job name or ID.
        requests.HTTPError: If a creation request returns an error status.
    """
    headers = _get_batch_headers(provider, api_key)

    if provider == "openai":
        url = BATCH_ENDPOINTS["openai"]["create"]
        body = {
            "input_file_id": file_id,
            "endpoint": "/v1/chat/completions",
            "completion_window": "24h",
        }
        resp = requests.post(url, headers=headers, json=body, timeout=60)
        resp.raise_for_status()
        return resp.json()["id"]

    elif provider == "anthropic":
        url = BATCH_ENDPOINTS["anthropic"]["create"]
        body = {"requests": requests_list}
        resp = requests.post(url, headers=headers, json=body, timeout=60)
        resp.raise_for_status()
        return resp.json()["id"]

    elif provider == "google":
        url = BATCH_ENDPOINTS["google"]["create"].format(model=model)
        # Google inline batch: requests are sent in the body (no file upload needed)
        body = {
            "batch": {
                "display_name": f"cat_stack_batch_{int(time.time())}",
                "input_config": {
                    "requests": {
                        "requests": [
                            {"request": line["request"], "metadata": line["metadata"]}
                            for line in (requests_list or [])
                        ]
                    }
                },
            }
        }
        resp = requests.post(url, headers=headers, json=body, timeout=60)
        resp.raise_for_status()
        result = resp.json()
        # Google returns the job name (e.g. "batches/abc123") — use as job_id.
        job_name = result.get("name") or result.get("id")
        if not job_name:
            # Returning None here would make the poller hit ".../v1beta/None"
            # and retry until its timeout; fail at the source instead.
            raise RuntimeError(
                f"Google batch: no job name in create response. Response: {result}"
            )
        return job_name

    elif provider == "mistral":
        url = BATCH_ENDPOINTS["mistral"]["create"]
        body = {
            "input_files": [file_id],
            "model": model,
            "endpoint": "/v1/chat/completions",
        }
        resp = requests.post(url, headers=headers, json=body, timeout=60)
        resp.raise_for_status()
        return resp.json()["id"]

    elif provider == "xai":
        # Step 1: Create empty batch
        url = BATCH_ENDPOINTS["xai"]["create"]
        body = {"completion_window": "24h"}
        resp = requests.post(url, headers=headers, json=body, timeout=60)
        resp.raise_for_status()
        job_id = resp.json()["id"]

        # Step 2: Add all requests to the batch
        add_url = BATCH_ENDPOINTS["xai"]["add"].format(job_id=job_id)
        add_resp = requests.post(add_url, headers=headers, json=requests_list, timeout=120)
        add_resp.raise_for_status()
        return job_id

    raise ValueError(f"Unsupported batch provider: {provider}")
334
+
335
+
336
+ # =============================================================================
337
+ # Polling
338
+ # =============================================================================
339
+
340
def _poll_batch_job(
    provider: str,
    api_key: str,
    job_id: str,
    interval: float = 30.0,
    timeout: float = 86400.0,
) -> dict:
    """
    Poll the batch job until it reaches a terminal state.

    Prints one-line status updates each poll cycle. Network failures, 5xx
    responses, and 408/429 responses are retried with exponential backoff
    (60s doubling per consecutive failure, capped at 300s). Any other 4xx
    response (bad API key, unknown job) is raised immediately, because
    retrying cannot fix it.

    Args:
        provider: Provider name
        api_key: API key
        job_id: Job ID (for Google, the job name) returned at job creation
        interval: Seconds to sleep between successful status checks
        timeout: Maximum seconds to keep polling

    Returns:
        The final status response dict (contains output file ID or job name for result retrieval).

    Raises:
        BatchJobExpiredError: If the job expired, timed out, or was cancelled.
        BatchJobFailedError: If the job terminated in a failed state.
        TimeoutError: If timeout is reached before the job completes.
        requests.exceptions.HTTPError: If the status endpoint returns a
            non-retryable 4xx response.
    """
    headers = _get_batch_headers(provider, api_key)
    terminal = _TERMINAL_STATES[provider]
    success = _SUCCESS_STATES[provider]

    start = time.time()
    attempt = 0  # consecutive failed status checks; drives the backoff

    if provider == "google":
        status_url = BATCH_ENDPOINTS["google"]["status"].format(job_name=job_id)
    else:
        status_url = BATCH_ENDPOINTS[provider]["status"].format(job_id=job_id)

    while True:
        elapsed = time.time() - start
        if elapsed >= timeout:
            raise TimeoutError(
                f"Batch job '{job_id}' did not complete within {timeout/3600:.1f}h. "
                f"Increase batch_timeout or switch to synchronous mode."
            )

        wait = min(60 * (2 ** min(attempt, 4)), 300)

        # Only the request itself is guarded, so connection problems are
        # retried while HTTP status handling below stays explicit.
        try:
            resp = requests.get(status_url, headers=headers, timeout=30)
        except requests.exceptions.RequestException as e:
            print(f" [batch] Network error ({e}); retrying in {wait}s...")
            time.sleep(wait)
            attempt += 1
            continue

        if resp.status_code >= 500:
            # Server error — back off and retry
            print(f" [batch] Server error {resp.status_code}; retrying in {wait}s...")
            time.sleep(wait)
            attempt += 1
            continue
        if resp.status_code in (408, 429):
            # Request timeout / rate limit — transient, so back off and retry
            print(f" [batch] Transient HTTP {resp.status_code}; retrying in {wait}s...")
            time.sleep(wait)
            attempt += 1
            continue

        # Any remaining 4xx is permanent; surface it instead of retrying
        # until the (default 24h) timeout.
        resp.raise_for_status()

        attempt = 0
        status_data = resp.json()

        # Extract state string per provider
        if provider == "openai":
            state = status_data.get("status", "")
            counts = status_data.get("request_counts", {})
            progress_str = (
                f"completed={counts.get('completed', '?')} "
                f"failed={counts.get('failed', '?')} "
                f"total={counts.get('total', '?')}"
            )
        elif provider == "anthropic":
            state = status_data.get("processing_status", "")
            counts = status_data.get("request_counts", {})
            progress_str = (
                f"processing={counts.get('processing', '?')} "
                f"succeeded={counts.get('succeeded', '?')} "
                f"errored={counts.get('errored', '?')}"
            )
        elif provider == "google":
            # State lives at metadata.state in the batchGenerateContent response;
            # "or {}" tolerates an explicit null metadata field.
            state = ((status_data.get("metadata") or {}).get("state", "")
                     or status_data.get("state", ""))
            progress_str = f"state={state}"
        elif provider == "mistral":
            state = status_data.get("status", "")
            progress_str = (
                f"succeeded={status_data.get('succeeded_requests', '?')} "
                f"failed={status_data.get('failed_requests', '?')} "
                f"total={status_data.get('total_requests', '?')}"
            )
        elif provider == "xai":
            state = status_data.get("status", "")
            counts = status_data.get("request_counts", {})
            progress_str = (
                f"completed={counts.get('completed', '?')} "
                f"failed={counts.get('failed', '?')}"
            )
        else:
            state = ""
            progress_str = ""

        print(f" [batch] {provider} | elapsed={elapsed:.0f}s | {progress_str} | state={state}")

        if state in terminal:
            if state not in success:
                expired = {"expired", "TIMEOUT_EXCEEDED", "JOB_STATE_CANCELLED", "CANCELLATION_REQUESTED"}
                lowered = state.lower()
                # "expire" also catches upper-case names such as Google's
                # BATCH_STATE_EXPIRED, which the exact-match set above misses.
                if state in expired or any(
                    marker in lowered for marker in ("cancel", "timeout", "expire")
                ):
                    raise BatchJobExpiredError(
                        f"Batch job '{job_id}' expired/was cancelled (state: {state}). "
                        f"Job ID saved above — check provider dashboard for details."
                    )
                raise BatchJobFailedError(
                    f"Batch job '{job_id}' failed (state: {state}). "
                    f"Check the provider dashboard for details."
                )
            return status_data

        time.sleep(interval)
457
+
458
+
459
+ # =============================================================================
460
+ # Result download
461
+ # =============================================================================
462
+
463
def _download_batch_results(
    provider: str,
    api_key: str,
    job_id: str,
    status_data: dict,
) -> str:
    """
    Fetch the results of a finished batch job as raw JSONL text.

    OpenAI and Mistral results are downloaded from the output file named in
    the final status; Anthropic and xAI expose a results URL keyed by job ID;
    Google results are read from the inlined responses already present in the
    final status and re-serialized as JSONL.

    Args:
        provider: Provider name
        api_key: API key
        job_id: Batch job ID
        status_data: Final status dict from polling (contains output file references)

    Returns:
        Raw JSONL string (one JSON object per line)

    Raises:
        RuntimeError: If the final status lacks the expected output reference.
        ValueError: If the provider is not a supported batch provider.
        requests.HTTPError: If the download request returns an error status.
    """
    headers = _get_batch_headers(provider, api_key)

    def _get_text(url: str, **request_kwargs) -> str:
        # Downloads are plain GETs, so the JSON Content-Type header is dropped.
        download_headers = {k: v for k, v in headers.items() if k != "Content-Type"}
        response = requests.get(url, headers=download_headers, timeout=120, **request_kwargs)
        response.raise_for_status()
        return response.text

    # Providers whose status names an output file: (status key, error text).
    file_based = {
        "openai": ("output_file_id", "OpenAI batch: no output_file_id in completed status"),
        "mistral": ("output_file", "Mistral batch: no output_file in completed status"),
    }

    if provider in file_based:
        status_key, missing_message = file_based[provider]
        output_file_id = status_data.get(status_key)
        if not output_file_id:
            raise RuntimeError(missing_message)
        return _get_text(BATCH_ENDPOINTS[provider]["results"].format(file_id=output_file_id))

    if provider == "anthropic":
        return _get_text(BATCH_ENDPOINTS["anthropic"]["results"].format(job_id=job_id), stream=True)

    if provider == "xai":
        return _get_text(BATCH_ENDPOINTS["xai"]["results"].format(job_id=job_id))

    if provider == "google":
        # Inline results sit in the operation response, usually doubly nested:
        # status_data["response"]["inlinedResponses"]["inlinedResponses"].
        operation_response = status_data.get("response", {})
        wrapper = operation_response.get("inlinedResponses", {})
        if isinstance(wrapper, dict):
            inlined = wrapper.get("inlinedResponses", [])
        elif isinstance(wrapper, list):
            inlined = wrapper
        else:
            inlined = []
        if not inlined:
            raise RuntimeError(
                f"Google batch: no inlinedResponses in completed status. "
                f"Status keys: {list(status_data.keys())}, "
                f"response keys: {list(operation_response.keys()) if isinstance(operation_response, dict) else operation_response}"
            )
        # Order is not guaranteed, so each line is keyed by the item's own
        # metadata.key (e.g. "item-42"); the position is only a fallback.
        return "\n".join(
            json.dumps({"key": item.get("metadata", {}).get("key", f"item-{i}"), **item})
            for i, item in enumerate(inlined)
        )

    raise ValueError(f"Unsupported batch provider: {provider}")
547
+
548
+
549
+ # =============================================================================
550
+ # Result parsing
551
+ # =============================================================================
552
+
553
def _extract_batch_record(provider: str, data: dict):
    """
    Split one decoded result record into (custom_id, response_payload, error_msg).

    Exactly one of response_payload / error_msg is None. Returns None when the
    provider has no known result format. Null "response"/"result" fields are
    treated as empty objects so failed records do not raise.
    """
    if provider in ("openai", "xai"):
        custom_id = data.get("custom_id")
        response_obj = data.get("response") or {}
        # Failures may be reported at the top level or inside the response.
        error_val = data.get("error") or response_obj.get("error")
        response_body = response_obj.get("body")
        if error_val or response_body is None:
            return custom_id, None, str(error_val) if error_val else "No response body"
        return custom_id, response_body, None

    if provider == "anthropic":
        custom_id = data.get("custom_id")
        result = data.get("result") or {}
        if result.get("type") != "succeeded":
            return custom_id, None, str(result.get("error", "Request did not succeed"))
        return custom_id, result.get("message", {}), None

    if provider == "google":
        # Output JSONL: {"key": "item-0", "response": {generateContent response}}
        # or {"key": "item-0", "error": {"code": ..., "message": ...}}
        custom_id = data.get("key")
        error_val = data.get("error")
        response_data = data.get("response")
        if error_val or response_data is None:
            return custom_id, None, str(error_val) if error_val else "No response in batch output"
        return custom_id, response_data, None

    if provider == "mistral":
        # Mistral batch output mirrors OpenAI: response.body holds the completion
        custom_id = data.get("custom_id")
        response_obj = data.get("response") or {}
        error_val = data.get("error") or response_obj.get("error")
        if error_val:
            return custom_id, None, str(error_val)
        return custom_id, response_obj.get("body", response_obj), None

    return None


def _parse_batch_results(
    provider: str,
    raw_results: str,
    custom_id_map: dict,
    client: "UnifiedLLMClient",
    parse_mode: str = "json",
) -> dict:
    """
    Parse the downloaded JSONL results back into per-item strings.

    Blank lines, lines that are not valid JSON objects, and records whose
    custom ID is not in custom_id_map are skipped. A record that cannot be
    parsed is reported as an error for that item only, so one malformed
    response does not discard the rest of the batch.

    Args:
        provider: Provider name
        raw_results: Raw JSONL text from the batch results download
        custom_id_map: Dict mapping custom_id string → original item index
        client: UnifiedLLMClient instance (used to call _parse_response())
        parse_mode: "json" (default) runs extract_json() on responses,
            "text" returns raw text as-is (for summarization)

    Returns:
        Dict mapping item_index → (result_str, error_or_None)
        Missing items (job dropped them) get (None, "Missing from batch results").
    """
    parsed_results = {}

    for line in raw_results.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            continue
        if not isinstance(data, dict):
            continue

        record = _extract_batch_record(provider, data)
        if record is None:
            continue
        custom_id, response_payload, error_msg = record

        idx = custom_id_map.get(custom_id)
        if idx is None:
            continue

        if error_msg is not None:
            parsed_results[idx] = (None, error_msg)
            continue

        try:
            raw_text = client._parse_response(response_payload)
            result_str = raw_text if parse_mode == "text" else extract_json(raw_text)
        except Exception as exc:
            # Per-item best effort, matching the synchronous fallback path:
            # record the failure instead of losing every other result.
            parsed_results[idx] = (None, f"Failed to parse batch response: {exc}")
            continue

        parsed_results[idx] = (result_str, None)

    # Items the provider never returned are reported explicitly.
    for idx in custom_id_map.values():
        parsed_results.setdefault(idx, (None, "Missing from batch results"))

    return parsed_results
663
+
664
+
665
+ # =============================================================================
666
+ # Per-model helpers (used by both single-model and ensemble batch paths)
667
+ # =============================================================================
668
+
669
def _run_one_batch_job(
    cfg: dict,
    items: list,
    prompt_params: dict,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
) -> dict:
    """
    Submit, poll, download, and parse a batch job for one model.

    Args:
        cfg: Model config with "provider", "api_key", and "model" keys
        items: Text items to classify; None entries are sent as empty strings
        prompt_params: Prompt-building parameters; "categories_str" is
            required, the remaining keys are optional
        batch_poll_interval: Seconds between status checks
        batch_timeout: Maximum seconds to wait for the job

    Returns:
        {item_index: (json_str_or_None, error_or_None)}. An empty ``items``
        yields ``{}`` without contacting the provider.
    """
    # An empty batch would only be rejected by the provider after an upload
    # and/or job-creation round trip, so return early. len() is used instead
    # of truthiness so array-like inputs are accepted too.
    if len(items) == 0:
        return {}

    from .text_functions_ensemble import build_text_classification_prompt

    provider = cfg["provider"]
    api_key = cfg["api_key"]
    model = cfg["model"]

    categories_str = prompt_params["categories_str"]
    survey_question_context = prompt_params.get("survey_question_context", "")
    examples_text = prompt_params.get("examples_text", "")
    chain_of_thought = prompt_params.get("chain_of_thought", False)
    context_prompt = prompt_params.get("context_prompt", False)
    step_back_prompt = prompt_params.get("step_back_prompt", False)
    stepback_insights = prompt_params.get("stepback_insights", {})
    json_schema = prompt_params.get("json_schema")
    creativity = prompt_params.get("creativity")
    thinking_budget = prompt_params.get("thinking_budget", 0)
    multi_label = prompt_params.get("multi_label", True)

    # A missing, zero, or negative budget means "no thinking budget".
    effective_budget = thinking_budget if thinking_budget and thinking_budget > 0 else None

    # Providers that take requests inline at job creation instead of a file.
    inline_providers = ("anthropic", "xai", "google")
    uses_inline_requests = provider in inline_providers

    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)

    print(f"\n[batch] Building {len(items)} request(s) for {model} ({provider})...")

    # Step 1: Build per-item request envelopes
    custom_id_map = {}
    request_lines = []

    for idx, item in enumerate(items):
        custom_id = f"item-{idx}"
        custom_id_map[custom_id] = idx

        messages = build_text_classification_prompt(
            response_text=str(item) if item is not None else "",
            categories_str=categories_str,
            survey_question_context=survey_question_context,
            examples_text=examples_text,
            chain_of_thought=chain_of_thought,
            context_prompt=context_prompt,
            step_back_prompt=step_back_prompt,
            stepback_insights=stepback_insights,
            model_name=model,
            multi_label=multi_label,
        )

        payload = client._build_payload(
            messages=messages,
            json_schema=json_schema,
            creativity=creativity,
            thinking_budget=effective_budget,
        )

        request_lines.append(_build_jsonl_line(provider, custom_id, payload, model))

    # Step 2: Upload file (OpenAI, Mistral only — Google uses inline requests).
    # The JSONL bytes are only serialized for providers that upload them.
    file_id = None
    if provider in ("openai", "mistral"):
        jsonl_bytes = b"\n".join(json.dumps(line).encode("utf-8") for line in request_lines)
        print(f"[batch] Uploading JSONL ({len(jsonl_bytes)/1024:.1f} KB) to {provider}...")
        file_id = _upload_jsonl(provider, api_key, jsonl_bytes)
        print(f"[batch] File uploaded: {file_id}")

    # Step 3: Create batch job
    print(f"[batch] Creating batch job for {model}...")
    job_id = _create_batch_job(
        provider=provider,
        api_key=api_key,
        model=model,
        file_id=file_id,
        requests_list=request_lines if uses_inline_requests else None,
    )
    print(f"[batch] Job created: {job_id}")
    print(f"[batch] Polling every {batch_poll_interval}s (timeout={batch_timeout/3600:.1f}h)...")

    # Step 4: Poll until complete
    status_data = _poll_batch_job(
        provider=provider,
        api_key=api_key,
        job_id=job_id,
        interval=batch_poll_interval,
        timeout=batch_timeout,
    )
    print(f"[batch] Job complete for {model}.")

    # Step 5: Download results
    print(f"[batch] Downloading results for {model}...")
    raw_results = _download_batch_results(
        provider=provider,
        api_key=api_key,
        job_id=job_id,
        status_data=status_data,
    )

    # Step 6: Parse results
    return _parse_batch_results(
        provider=provider,
        raw_results=raw_results,
        custom_id_map=custom_id_map,
        client=client,
    )
782
+
783
+
784
def _run_one_sync_model(
    cfg: dict,
    items: list,
    prompt_params: dict,
) -> dict:
    """
    Classify every item with one model using ordinary synchronous calls.

    Used as the fallback for providers without a batch API. A failure on one
    item is recorded for that item and does not stop the remaining items.

    Args:
        cfg: Model config with "provider", "api_key", and "model" keys
        items: Text items to classify; None entries are sent as empty strings
        prompt_params: Prompt-building parameters ("categories_str" required)

    Returns:
        {item_index: (json_str_or_None, error_or_None)}
    """
    from .text_functions_ensemble import build_text_classification_prompt

    provider = cfg["provider"]
    api_key = cfg["api_key"]
    model = cfg["model"]
    schema = prompt_params.get("json_schema")
    temperature = prompt_params.get("creativity")
    budget = prompt_params.get("thinking_budget", 0)

    llm = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)
    outcomes = {}

    print(f"\n[batch] Synchronous fallback for {model} ({provider}): {len(items)} item(s)...")

    for position, text in enumerate(items):
        prompt_messages = build_text_classification_prompt(
            response_text="" if text is None else str(text),
            categories_str=prompt_params["categories_str"],
            survey_question_context=prompt_params.get("survey_question_context", ""),
            examples_text=prompt_params.get("examples_text", ""),
            chain_of_thought=prompt_params.get("chain_of_thought", False),
            context_prompt=prompt_params.get("context_prompt", False),
            step_back_prompt=prompt_params.get("step_back_prompt", False),
            stepback_insights=prompt_params.get("stepback_insights", {}),
            model_name=model,
            multi_label=prompt_params.get("multi_label", True),
        )
        try:
            # A missing, zero, or negative budget is passed on as None.
            effective_budget = budget if budget and budget > 0 else None
            completion = llm.complete(
                messages=prompt_messages,
                json_schema=schema,
                creativity=temperature,
                thinking_budget=effective_budget,
            )
            outcomes[position] = (extract_json(completion), None)
        except Exception as err:
            # Deliberate per-item best effort: keep the error text, move on.
            outcomes[position] = (None, str(err))

    return outcomes
832
+
833
+
834
+ # =============================================================================
835
+ # Main entry point (single-model)
836
+ # =============================================================================
837
+
838
def run_batch_classify(
    items: list,
    cfg: dict,
    categories: list,
    prompt_params: dict,
    filename: str = None,
    save_directory: str = None,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
    fail_strategy: str = "partial",
) -> "pd.DataFrame":
    """
    Run batch classification for a single model against a list of text items.

    Intended as the entry point used by classify() when batch_mode=True.
    The per-item results of the batch job are wrapped in the same structure
    the ensemble pipeline uses (one model, "unanimous" consensus) and handed
    to build_output_dataframes, so the output matches the synchronous
    single-model path.

    Args:
        items: List of text strings to classify
        cfg: Model config dict from prepare_model_configs() (single entry)
        categories: List of category names
        prompt_params: Dict containing prompt-building parameters:
            - categories_str (str)
            - survey_question_context (str)
            - examples_text (str)
            - chain_of_thought (bool)
            - context_prompt (bool)
            - step_back_prompt (bool)
            - stepback_insights (dict)
            - json_schema (dict or None)
            - creativity (float or None)
            - thinking_budget (int)
        filename: Optional CSV filename to save results
        save_directory: Optional directory to save results
        batch_poll_interval: Seconds between poll checks (default 30)
        batch_timeout: Max seconds to wait for job (default 86400 = 24h)
        fail_strategy: "partial" or "strict" (forwarded to aggregate_results)

    Returns:
        pd.DataFrame with category_1, category_2, ... columns (same as sync path)
    """
    import math

    from .text_functions_ensemble import aggregate_results, build_output_dataframes

    # =========================================================================
    # Steps 1-6: Submit, poll, download, and parse the batch job
    # =========================================================================
    item_results = _run_one_batch_job(
        cfg=cfg,
        items=items,
        prompt_params=prompt_params,
        batch_poll_interval=batch_poll_interval,
        batch_timeout=batch_timeout,
    )

    # =========================================================================
    # Step 7: Build all_results list in the format aggregate_results expects
    # =========================================================================
    model_name = cfg["sanitized_name"]
    all_results = []

    for idx, item in enumerate(items):
        # An index absent from the parsed output is reported as a per-item error.
        json_str, error = item_results.get(idx, (None, "Missing from batch results"))

        model_results = {model_name: (json_str, error)}
        aggregated = aggregate_results(
            model_results=model_results,
            categories=categories,
            consensus_threshold="unanimous",
            fail_strategy=fail_strategy,
        )

        # None and float NaN mark inputs with no text to classify.
        skipped = item is None or (isinstance(item, float) and math.isnan(item))

        all_results.append({
            "response": str(item) if item is not None else "",
            "model_results": model_results,
            "aggregated": aggregated,
            "skipped": skipped,
        })

    # =========================================================================
    # Step 8: Build output DataFrame (reuses existing pipeline)
    # =========================================================================
    return build_output_dataframes(
        all_results=all_results,
        model_configs=[cfg],
        categories=categories,
        filename=filename,
        save_directory=save_directory,
    )
926
+
927
+
928
+ # =============================================================================
929
+ # Ensemble batch entry point
930
+ # =============================================================================
931
+
932
def run_batch_ensemble_classify(
    items: list,
    model_configs: list,
    categories: list,
    prompt_params_per_model: dict,
    consensus_threshold,
    fail_strategy: str = "partial",
    filename: str = None,
    save_directory: str = None,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
) -> "pd.DataFrame":
    """
    Run batch classification for multiple models concurrently, then merge results.

    Every model config is run on its own worker thread. Providers listed in
    UNSUPPORTED_BATCH_PROVIDERS use the synchronous per-item fallback; all
    others submit a batch job. The per-model results are then combined item
    by item with aggregate_results.

    Args:
        items: List of text strings to classify
        model_configs: List of model config dicts from prepare_model_configs()
        categories: List of category names
        prompt_params_per_model: Dict mapping model name → prompt_params dict
        consensus_threshold: Agreement threshold (str or float)
        fail_strategy: "partial" or "strict"
        filename: Optional CSV filename to save results
        save_directory: Optional directory to save results
        batch_poll_interval: Seconds between poll checks (default 30)
        batch_timeout: Max seconds to wait for job (default 86400 = 24h)

    Returns:
        pd.DataFrame with consensus columns and per-model columns (same as sync ensemble)

    Raises:
        ValueError: If model_configs is empty.
    """
    # Fail early with a clear message; otherwise the thread pool below rejects
    # max_workers=0 with a ValueError that does not mention model_configs.
    if not model_configs:
        raise ValueError("model_configs must contain at least one model config")

    import math
    from concurrent.futures import ThreadPoolExecutor, as_completed

    from .text_functions_ensemble import aggregate_results, build_output_dataframes

    batch_cfgs = [c for c in model_configs if c["provider"] not in UNSUPPORTED_BATCH_PROVIDERS]
    sync_cfgs = [c for c in model_configs if c["provider"] in UNSUPPORTED_BATCH_PROVIDERS]

    if batch_cfgs:
        print(
            f"\n[batch ensemble] {len(batch_cfgs)} model(s) will use batch API: "
            f"{', '.join(c['model'] for c in batch_cfgs)}"
        )
    if sync_cfgs:
        print(
            f"[batch ensemble] {len(sync_cfgs)} model(s) will use synchronous fallback: "
            f"{', '.join(c['model'] for c in sync_cfgs)}"
        )

    all_model_results = {}

    def _run_cfg(cfg):
        # Prompt params are looked up by raw model name; results are keyed by
        # the sanitized name.
        model_key = cfg["sanitized_name"]
        pp = prompt_params_per_model[cfg["model"]]
        if cfg["provider"] in UNSUPPORTED_BATCH_PROVIDERS:
            return model_key, _run_one_sync_model(cfg, items, pp)
        else:
            return model_key, _run_one_batch_job(cfg, items, pp, batch_poll_interval, batch_timeout)

    # One worker per model so every job is submitted and polled concurrently.
    with ThreadPoolExecutor(max_workers=len(model_configs)) as executor:
        futures = {executor.submit(_run_cfg, cfg): cfg for cfg in model_configs}
        for future in as_completed(futures):
            # result() re-raises any exception raised inside a model's worker.
            model_key, result = future.result()
            all_model_results[model_key] = result

    all_results = []
    for idx, item in enumerate(items):
        model_results = {
            cfg["sanitized_name"]: all_model_results[cfg["sanitized_name"]].get(
                idx, (None, "Missing from batch results")
            )
            for cfg in model_configs
        }
        aggregated = aggregate_results(
            model_results=model_results,
            categories=categories,
            consensus_threshold=consensus_threshold,
            fail_strategy=fail_strategy,
        )
        # None and float NaN mark inputs with no text to classify.
        skipped = item is None or (isinstance(item, float) and math.isnan(item))
        all_results.append({
            "response": str(item) if not skipped else "",
            "model_results": model_results,
            "aggregated": aggregated,
            "skipped": skipped,
        })

    return build_output_dataframes(all_results, model_configs, categories, filename, save_directory)
1024
+
1025
+
1026
+ # =============================================================================
1027
+ # Batch summarization
1028
+ # =============================================================================
1029
+
1030
def _run_one_batch_summarize_job(
    cfg: dict,
    items: list,
    prompt_params: dict,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
) -> dict:
    """
    Run one model's batch summarization job from request building to parsing.

    One request is built per item. OpenAI and Mistral receive the requests
    as an uploaded JSONL file; Anthropic, xAI, and Google receive them inline
    at job creation. The job is then polled, its output downloaded, and the
    output parsed in "json" mode.

    Returns the mapping produced by _parse_batch_results, keyed by item
    index: {item_index: (summary_json_or_None, error_or_None)}.
    """
    from .text_functions_ensemble import build_text_summarization_prompt, build_summary_json_schema

    provider = cfg["provider"]
    api_key = cfg["api_key"]
    model = cfg["model"]

    # Prompt settings are identical for every request in the job.
    shared_prompt_kwargs = {
        "input_description": prompt_params.get("input_description", ""),
        "summary_instructions": prompt_params.get("summary_instructions", ""),
        "max_length": prompt_params.get("max_length"),
        "focus": prompt_params.get("focus"),
        "chain_of_thought": prompt_params.get("chain_of_thought", False),
        "context_prompt": prompt_params.get("context_prompt", False),
        "step_back_prompt": prompt_params.get("step_back_prompt", False),
        "stepback_insights": prompt_params.get("stepback_insights", {}),
        "model_name": model,
    }
    creativity = prompt_params.get("creativity")

    # The schema is built with the "additional" flag off for Google only.
    json_schema = build_summary_json_schema(provider != "google")

    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)

    print(f"\n[batch] Building {len(items)} summarization request(s) for {model} ({provider})...")

    sends_inline = provider in ("anthropic", "xai", "google")
    needs_upload = provider in ("openai", "mistral")

    custom_id_map = {}
    jsonl_lines = []
    for position, entry in enumerate(items):
        request_id = f"item-{position}"
        custom_id_map[request_id] = position

        messages = build_text_summarization_prompt(
            response_text="" if entry is None else str(entry),
            **shared_prompt_kwargs,
        )
        payload = client._build_payload(
            messages=messages,
            json_schema=json_schema,
            creativity=creativity,
        )
        jsonl_lines.append(_build_jsonl_line(provider, request_id, payload, model))

    # Inline providers get the request objects themselves rather than a file.
    requests_list = list(jsonl_lines) if sends_inline else None

    jsonl_bytes = b"\n".join(json.dumps(line).encode("utf-8") for line in jsonl_lines)

    # Upload file (OpenAI, Mistral only)
    file_id = None
    if needs_upload:
        print(f"[batch] Uploading JSONL ({len(jsonl_bytes)/1024:.1f} KB) to {provider}...")
        file_id = _upload_jsonl(provider, api_key, jsonl_bytes)
        print(f"[batch] File uploaded: {file_id}")

    # Create batch job
    print(f"[batch] Creating batch job for {model}...")
    job_id = _create_batch_job(
        provider=provider,
        api_key=api_key,
        model=model,
        file_id=file_id,
        requests_list=requests_list,
    )
    print(f"[batch] Job created: {job_id}")
    print(f"[batch] Polling every {batch_poll_interval}s (timeout={batch_timeout/3600:.1f}h)...")

    # Poll until complete
    status_data = _poll_batch_job(
        provider=provider,
        api_key=api_key,
        job_id=job_id,
        interval=batch_poll_interval,
        timeout=batch_timeout,
    )
    print(f"[batch] Job complete for {model}.")

    # Download results
    print(f"[batch] Downloading results for {model}...")
    raw_results = _download_batch_results(
        provider=provider,
        api_key=api_key,
        job_id=job_id,
        status_data=status_data,
    )

    # Summaries come back wrapped as {"summary": "..."}, so results are parsed
    # in "json" mode and callers unwrap the summary text afterwards.
    return _parse_batch_results(
        provider=provider,
        raw_results=raw_results,
        custom_id_map=custom_id_map,
        client=client,
        parse_mode="json",
    )
1147
+
1148
+
1149
def _run_one_sync_summarize_model(
    cfg: dict,
    items: list,
    prompt_params: dict,
) -> dict:
    """
    Summarize all items synchronously for one model (fallback for unsupported batch providers).

    Each item is sent with its own client.complete() call. A failure on one
    item is recorded for that item only and does not stop the loop.

    Returns {item_index: (json_str_or_None, error_or_None)}.
    """
    from .text_functions_ensemble import build_text_summarization_prompt, build_summary_json_schema

    provider = cfg["provider"]
    api_key = cfg["api_key"]
    model = cfg["model"]
    creativity = prompt_params.get("creativity")

    include_additional = provider != "google"
    json_schema = build_summary_json_schema(include_additional)

    client = UnifiedLLMClient(provider=provider, api_key=api_key, model=model)
    item_results = {}

    print(f"\n[batch] Synchronous fallback for {model} ({provider}): {len(items)} item(s)...")

    # Prompt settings do not vary per item, so resolve them once.
    prompt_kwargs = {
        "input_description": prompt_params.get("input_description", ""),
        "summary_instructions": prompt_params.get("summary_instructions", ""),
        "max_length": prompt_params.get("max_length"),
        "focus": prompt_params.get("focus"),
        "chain_of_thought": prompt_params.get("chain_of_thought", False),
        "context_prompt": prompt_params.get("context_prompt", False),
        "step_back_prompt": prompt_params.get("step_back_prompt", False),
        "stepback_insights": prompt_params.get("stepback_insights", {}),
        "model_name": model,
    }

    for idx, item in enumerate(items):
        messages = build_text_summarization_prompt(
            response_text=str(item) if item is not None else "",
            **prompt_kwargs,
        )
        try:
            raw, err = client.complete(
                messages=messages,
                json_schema=json_schema,
                creativity=creativity,
            )
            # NOTE(review): assumes complete() returns (text, error) — confirm
            # against UnifiedLLMClient. When an error is reported and no text
            # came back, keep the real error instead of parsing nothing.
            if err and not raw:
                item_results[idx] = (None, str(err))
            else:
                item_results[idx] = (extract_json(raw), None)
        except Exception as e:
            # Best-effort per item: record the failure and keep going.
            item_results[idx] = (None, str(e))

    return item_results
1197
+
1198
+
1199
def run_batch_summarize(
    items: list,
    cfg: dict,
    prompt_params: dict,
    filename: str = None,
    save_directory: str = None,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
    fail_strategy: str = "partial",
) -> "pd.DataFrame":
    """
    Run batch summarization for a single model.

    Args:
        items: Inputs to summarize; None and float NaN entries are reported
            as "skipped".
        cfg: Model config dict (single entry).
        prompt_params: Prompt-building parameters forwarded to the batch job.
        filename: Optional CSV filename to save results.
        save_directory: Optional directory for the CSV (created if missing).
        batch_poll_interval: Seconds between poll checks.
        batch_timeout: Max seconds to wait for the job.
        fail_strategy: Accepted for parity with the other entry points. With
            a single model, any per-item error yields an "error" row whatever
            the value, so it has no effect here.

    Returns:
        A DataFrame with input_data, summary, processing_status columns, where
        processing_status is "success", "error", or "skipped".
    """
    import math
    import pandas as pd
    from .text_functions_ensemble import extract_summary_from_json

    item_results = _run_one_batch_summarize_job(
        cfg=cfg,
        items=items,
        prompt_params=prompt_params,
        batch_poll_interval=batch_poll_interval,
        batch_timeout=batch_timeout,
    )

    rows = []
    for idx, item in enumerate(items):
        # An index absent from the parsed output is treated as an error.
        json_str, error = item_results.get(idx, (None, "Missing from batch results"))

        text = str(item) if item is not None else ""
        is_skipped = item is None or (isinstance(item, float) and math.isnan(item))

        summary = ""
        if is_skipped:
            status = "skipped"
        elif error:
            status = "error"
        else:
            is_valid, summary_text = extract_summary_from_json(json_str)
            if is_valid and summary_text:
                status = "success"
                summary = summary_text
            else:
                # The response parsed, but held no usable summary.
                status = "error"

        rows.append({"input_data": text, "summary": summary, "processing_status": status})

    df = pd.DataFrame(rows)

    if filename:
        save_path = os.path.join(save_directory, filename) if save_directory else filename
        if save_directory:
            os.makedirs(save_directory, exist_ok=True)
        df.to_csv(save_path, index=False)
        print(f"\nResults saved to: {save_path}")

    return df
1262
+
1263
+
1264
def run_batch_ensemble_summarize(
    items: list,
    model_configs: list,
    prompt_params_per_model: dict,
    fail_strategy: str = "partial",
    filename: str = None,
    save_directory: str = None,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
    max_retries: int = 3,
) -> "pd.DataFrame":
    """
    Run batch summarization for multiple models concurrently, then synthesize.

    Every model config is run on its own worker thread. Providers listed in
    UNSUPPORTED_BATCH_PROVIDERS use the synchronous fallback; all others
    submit a batch job. For each item, the non-empty per-model summaries are
    passed to _synthesize_summaries() with the first model config as the
    synthesis config.

    Args:
        items: Inputs to summarize; None and float NaN entries are reported
            as "skipped".
        model_configs: List of model config dicts (must not be empty).
        prompt_params_per_model: Dict mapping model name → prompt_params dict.
        fail_strategy: With "strict", every per-model summary for an item is
            blanked as soon as any model failed on that item.
        filename: Optional CSV filename to save results.
        save_directory: Optional directory for the CSV (created if missing).
        batch_poll_interval: Seconds between poll checks.
        batch_timeout: Max seconds to wait for each job.
        max_retries: Forwarded to _synthesize_summaries().

    Returns:
        A DataFrame with input_data, one summary_<model> column per model,
        summary, failed_models (comma-joined model names), and
        processing_status ("success", "partial", "error", or "skipped").

    Raises:
        ValueError: If model_configs is empty.
    """
    # Fail early with a clear message; otherwise the thread pool below rejects
    # max_workers=0 with a ValueError that does not mention model_configs.
    if not model_configs:
        raise ValueError("model_configs must contain at least one model config")

    import math
    import pandas as pd
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from .text_functions_ensemble import extract_summary_from_json, _synthesize_summaries

    batch_cfgs = [c for c in model_configs if c["provider"] not in UNSUPPORTED_BATCH_PROVIDERS]
    sync_cfgs = [c for c in model_configs if c["provider"] in UNSUPPORTED_BATCH_PROVIDERS]

    if batch_cfgs:
        print(
            f"\n[batch ensemble] {len(batch_cfgs)} model(s) will use batch API: "
            f"{', '.join(c['model'] for c in batch_cfgs)}"
        )
    if sync_cfgs:
        print(
            f"[batch ensemble] {len(sync_cfgs)} model(s) will use synchronous fallback: "
            f"{', '.join(c['model'] for c in sync_cfgs)}"
        )

    all_model_results = {}

    def _run_cfg(cfg):
        # Prompt params are looked up by raw model name; results are keyed by
        # the sanitized name.
        model_key = cfg["sanitized_name"]
        pp = prompt_params_per_model[cfg["model"]]
        if cfg["provider"] in UNSUPPORTED_BATCH_PROVIDERS:
            return model_key, _run_one_sync_summarize_model(cfg, items, pp)
        else:
            return model_key, _run_one_batch_summarize_job(cfg, items, pp, batch_poll_interval, batch_timeout)

    # One worker per model so every job is submitted and polled concurrently.
    with ThreadPoolExecutor(max_workers=len(model_configs)) as executor:
        futures = {executor.submit(_run_cfg, cfg): cfg for cfg in model_configs}
        for future in as_completed(futures):
            # result() re-raises any exception raised inside a model's worker.
            model_key, result = future.result()
            all_model_results[model_key] = result

    model_names = [cfg["sanitized_name"] for cfg in model_configs]
    # The first configured model performs every synthesis call.
    synthesis_cfg = model_configs[0]

    rows = []
    for idx, item in enumerate(items):
        text = str(item) if item is not None else ""
        is_skipped = item is None or (isinstance(item, float) and math.isnan(item))

        if is_skipped:
            row = {"input_data": text, "summary": "", "processing_status": "skipped"}
            for mn in model_names:
                row[f"summary_{mn}"] = ""
            row["failed_models"] = ""
            rows.append(row)
            continue

        summaries = {}
        errors = []
        for mn in model_names:
            json_str, error = all_model_results[mn].get(idx, (None, "Missing from batch results"))
            if error:
                summaries[mn] = ""
                errors.append(mn)
                continue

            is_valid, summary_text = extract_summary_from_json(json_str)
            if is_valid and summary_text:
                summaries[mn] = summary_text
            else:
                # Always store "" for an unusable summary so the output cell
                # never holds None.
                summaries[mn] = ""
                errors.append(mn)

        # fail_strategy="strict": blank everything if any model failed
        if fail_strategy == "strict" and errors:
            summaries = {k: "" for k in summaries}

        row = {"input_data": text}
        for mn in model_names:
            row[f"summary_{mn}"] = summaries.get(mn, "")

        # Synthesize consensus from the summaries that survived.
        valid_summaries = {k: v for k, v in summaries.items() if v}
        if valid_summaries:
            row["summary"] = _synthesize_summaries(
                summaries=valid_summaries,
                original_text=text,
                synthesis_config=synthesis_cfg,
                max_retries=max_retries,
            )
        else:
            row["summary"] = ""

        row["failed_models"] = ",".join(errors) if errors else ""

        if all(not s for s in summaries.values()):
            row["processing_status"] = "error"
        elif any(not s for s in summaries.values()):
            row["processing_status"] = "partial"
        else:
            row["processing_status"] = "success"

        rows.append(row)

    df = pd.DataFrame(rows)

    if filename:
        save_path = os.path.join(save_directory, filename) if save_directory else filename
        if save_directory:
            os.makedirs(save_directory, exist_ok=True)
        df.to_csv(save_path, index=False)
        print(f"\nResults saved to: {save_path}")

    return df