cat-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3644 @@
1
+ """
2
+ Ensemble text classification functions for CatLLM.
3
+
4
+ This module provides multi-model ensemble classification using parallel execution.
5
+ Multiple LLM models are called simultaneously and results are combined using
6
+ majority voting for more robust classification.
7
+
8
+ MODULE STRUCTURE:
9
+ =================
10
+ Helper Functions:
11
+ - sanitize_model_name(): Convert model names to column-safe suffixes
12
+ - prepare_model_configs(): Validate model configurations, check Ollama
13
+
14
+ Shared Utilities (reusable for image/pdf classification):
15
+ - normalize_model_input(): Convert various input formats to list of tuples
16
+ - gather_stepback_insights(): Get step-back insights from each model
17
+ - prepare_json_schemas(): Build JSON schemas per provider
18
+ - aggregate_results(): Majority voting consensus
19
+ - build_output_dataframes(): Build output DataFrames
20
+
21
+ Text-Specific (swap for image classification):
22
+ - build_text_classification_prompt(): Build prompt for text classification
23
+
24
+ Main Function:
25
+ - multi_class_ensemble(): Main entry point
26
+ - Supports single model (returns DataFrame) or multiple models (returns dict)
27
+ - Supports categories="auto" for auto-detection
28
+
29
+ Chain of Verification (CoVe):
30
+ ============================
31
+ CoVe is a 4-step prompting strategy to improve classification accuracy:
32
+ Step 1: Initial classification (existing classify_single)
33
+ Step 2: Generate verification questions about the classification
34
+ Step 3: Answer each verification question (up to 5 questions)
35
+ Step 4: Final corrected classification based on Q&A
36
+
37
+ Usage: Set chain_of_verification=True in multi_class_ensemble()
38
+ Note: CoVe adds up to 7 extra API calls per response (1 to generate questions, up to 5 to answer them, 1 for the final classification). Not recommended for ensemble mode
39
+ due to cost, but supported for single-model usage.
40
+ """
41
+
42
+ __all__ = ["classify_ensemble", "multi_class_ensemble", "summarize_ensemble"]
43
+
44
+ import json
45
+ import os
46
+ import re
47
+ import time
48
+ import pandas as pd
49
+ from pathlib import Path
50
+ from tqdm import tqdm
51
+ from concurrent.futures import ThreadPoolExecutor, as_completed
52
+ from typing import Optional, Callable, Union
53
+
54
+ from .text_functions import (
55
+ UnifiedLLMClient,
56
+ detect_provider,
57
+ build_json_schema,
58
+ extract_json,
59
+ validate_classification_json,
60
+ ollama_two_step_classify,
61
+ check_ollama_running,
62
+ check_ollama_model,
63
+ pull_ollama_model,
64
+ check_claude_cli_available,
65
+ _get_stepback_insight,
66
+ )
67
+
68
+ # PDF utility imports
69
+ from .pdf_functions import (
70
+ _load_pdf_files,
71
+ _get_pdf_pages,
72
+ _extract_page_as_image_bytes,
73
+ _extract_page_as_pdf_bytes,
74
+ _extract_page_text,
75
+ _encode_bytes_to_base64,
76
+ )
77
+
78
+ # Image utility imports
79
+ from .image_functions import (
80
+ _load_image_files,
81
+ _encode_image,
82
+ )
83
+
84
+
85
+ # =============================================================================
86
+ # Consensus Threshold Helper
87
+ # =============================================================================
88
+
89
+ def _resolve_consensus_threshold(threshold: Union[str, float, int]) -> float:
90
+ """
91
+ Convert consensus threshold to a numeric value.
92
+
93
+ Accepts both string aliases and numeric values for flexibility:
94
+ - String values: "majority" (0.5), "two-thirds" (0.67), "unanimous" (1.0)
95
+ - Numeric values: Any float between 0 and 1
96
+
97
+ Args:
98
+ threshold: Either a string alias or numeric value (0-1)
99
+
100
+ Returns:
101
+ float: The resolved threshold value
102
+
103
+ Examples:
104
+ >>> _resolve_consensus_threshold("majority")
105
+ 0.5
106
+ >>> _resolve_consensus_threshold("two-thirds")
107
+ 0.67
108
+ >>> _resolve_consensus_threshold(0.75)
109
+ 0.75
110
+ """
111
+ if isinstance(threshold, str):
112
+ mapping = {
113
+ "majority": 0.5,
114
+ "two-thirds": 0.67,
115
+ "two_thirds": 0.67,
116
+ "twothirds": 0.67,
117
+ "unanimous": 1.0,
118
+ }
119
+ resolved = mapping.get(threshold.lower().strip())
120
+ if resolved is None:
121
+ valid_options = ", ".join(f'"{k}"' for k in ["majority", "two-thirds", "unanimous"])
122
+ raise ValueError(
123
+ f"Invalid consensus_threshold string: '{threshold}'. "
124
+ f"Valid options: {valid_options}, or a numeric value between 0 and 1."
125
+ )
126
+ return resolved
127
+ else:
128
+ value = float(threshold)
129
+ if not 0 <= value <= 1:
130
+ raise ValueError(f"consensus_threshold must be between 0 and 1, got {value}")
131
+ return value
132
+
133
+
134
+ # =============================================================================
135
+ # Test Utilities (for debugging batch retry logic)
136
+ # =============================================================================
137
+
138
# Global flag and state for retry testing - set _TEST_FORCE_FAILURE = True to test.
# _test_attempted_pairs remembers which (response prefix, model name) pairs have
# already been forced to fail, so that each pair is only failed once.
_TEST_FORCE_FAILURE = False
_test_attempted_pairs = set()
141
+
142
+
143
def _test_should_force_failure(response_text: str, model_name: str) -> bool:
    """
    Test helper: decide whether this (response, model) pair must fail now.

    A pair is identified by the first 50 characters of the response plus the
    model name. The first time a pair is seen it is recorded and True is
    returned; any later call for the same pair (i.e. a retry) returns False.

    Always returns False unless the module-level _TEST_FORCE_FAILURE flag is
    truthy.
    """
    if _TEST_FORCE_FAILURE:
        key = (response_text[:50], model_name)
        already_forced = key in _test_attempted_pairs
        if not already_forced:
            _test_attempted_pairs.add(key)
            print(f" [TEST] Forcing error for: {model_name} on '{response_text[:30]}...'")
            return True
    return False
160
+
161
+
162
def _test_reset():
    """
    Reset test state between test runs.

    Rebinds the module-level _test_attempted_pairs to a fresh empty set so
    every (response, model) pair is treated as unseen again.
    """
    global _test_attempted_pairs
    _test_attempted_pairs = set()
166
+
167
+
168
+ # =============================================================================
169
+ # Input Type Detection
170
+ # =============================================================================
171
+
172
+ # Supported image extensions (from image_functions.py)
173
+ _IMAGE_EXTENSIONS = {
174
+ '.png', '.jpg', '.jpeg', '.gif', '.webp', '.svg', '.svgz',
175
+ '.avif', '.apng', '.tif', '.tiff', '.bmp', '.heif', '.heic',
176
+ '.ico', '.psd', '.jfif', '.pjpeg', '.pjp', '.jpe'
177
+ }
178
+
179
+
180
+ def _detect_input_type(input_data) -> str:
181
+ """
182
+ Detect if input is text strings, PDF files, or image files.
183
+
184
+ Auto-detection logic:
185
+ - If input ends in .pdf → PDF mode
186
+ - If input ends in image extension (.png, .jpg, etc.) → Image mode
187
+ - If input is a directory → Check first file to determine PDF or Image mode
188
+ - Otherwise → Text mode
189
+
190
+ Args:
191
+ input_data: Text strings, PDF paths, image paths, or directory path
192
+
193
+ Returns:
194
+ 'text', 'pdf', or 'image'
195
+ """
196
+ # Handle single string input
197
+ if isinstance(input_data, (str, Path)):
198
+ survey_str = str(input_data)
199
+ ext = os.path.splitext(survey_str)[1].lower()
200
+
201
+ # Check for PDF
202
+ if ext == '.pdf':
203
+ return 'pdf'
204
+
205
+ # Check for image
206
+ if ext in _IMAGE_EXTENSIONS:
207
+ return 'image'
208
+
209
+ # Check if it's a directory (could contain PDFs or images)
210
+ if os.path.isdir(survey_str):
211
+ # Check first file to determine type
212
+ try:
213
+ for f in sorted(os.listdir(survey_str)):
214
+ f_ext = os.path.splitext(f)[1].lower()
215
+ if f_ext == '.pdf':
216
+ return 'pdf'
217
+ if f_ext in _IMAGE_EXTENSIONS:
218
+ return 'image'
219
+ except OSError:
220
+ pass
221
+ # Default to PDF for directories (backward compatibility)
222
+ return 'pdf'
223
+
224
+ return 'text'
225
+
226
+ # Handle list/series input
227
+ if hasattr(input_data, '__iter__'):
228
+ for item in input_data:
229
+ if item is not None and not pd.isna(item):
230
+ item_str = str(item)
231
+ ext = os.path.splitext(item_str)[1].lower()
232
+ if ext == '.pdf':
233
+ return 'pdf'
234
+ if ext in _IMAGE_EXTENSIONS:
235
+ return 'image'
236
+ # First non-null item is text
237
+ return 'text'
238
+
239
+ return 'text'
240
+
241
+
242
+ # =============================================================================
243
+ # Helper Functions
244
+ # =============================================================================
245
+
246
def sanitize_model_name(model: str) -> str:
    """
    Turn a model name into a string usable as a column suffix.

    Every run of characters outside [a-zA-Z0-9] becomes a single underscore,
    leading/trailing underscores are removed, the result is lower-cased and
    finally cut to at most 40 characters.

    Examples:
        gpt-4o -> gpt_4o
        claude-sonnet-4-5-20250929 -> claude_sonnet_4_5_20250929
        llama3.2:latest -> llama3_2_latest

    Args:
        model: The model name string

    Returns:
        Sanitized string suitable for use in column names
    """
    max_length = 40  # Truncate to reasonable length
    collapsed = re.sub(r'[^a-zA-Z0-9]+', '_', model)
    return collapsed.strip('_').lower()[:max_length]
265
+
266
+
267
+ def _format_creativity_suffix(creativity) -> str:
268
+ """Format creativity value as a column suffix, e.g. 0.25 -> '_t25', 1.0 -> '_t100'."""
269
+ if creativity is None:
270
+ return "_tauto"
271
+ # Multiply by 100 and format as integer: 0 -> 0, 0.25 -> 25, 1.0 -> 100
272
+ return f"_t{int(round(creativity * 100))}"
273
+
274
+
275
def prepare_model_configs(models: list, auto_download: bool = False) -> list:
    """
    Validate and prepare model configurations.

    Args:
        models: List of tuples. Each tuple can be:
            - (model, provider, api_key) — 3 elements
            - (model, provider, api_key, options) — 4 elements, where options is a
              dict with per-model overrides (e.g. {"creativity": 0.5})
        auto_download: If True, automatically download missing Ollama models

    Returns:
        List of config dicts, one per entry and in the same order, with keys
        "model", "provider" (the detected provider), "api_key",
        "use_two_step" (True only for Ollama), "sanitized_name" (unique
        across the returned list) and "creativity" (per-model override or None).

    Raises:
        ValueError: If an entry does not have 3 or 4 elements, or an API key
            is missing for a provider other than Ollama / claude-code
        ConnectionError: If Ollama is not running, or the claude CLI is not
            available, when needed
        RuntimeError: If an Ollama model is not available and could not be pulled
    """
    rule = "=" * 60
    configs = []
    ollama_checked = False
    is_ensemble = len(models) > 1

    for entry in models:
        if len(entry) == 4:
            model, provider, api_key, options = entry
        elif len(entry) == 3:
            model, provider, api_key = entry
            options = {}
        else:
            raise ValueError(
                f"Each model entry must be a 3-tuple (model, provider, api_key) "
                f"or 4-tuple (model, provider, api_key, options), got {len(entry)} elements"
            )

        detected_provider = detect_provider(model, provider)

        if detected_provider == "ollama":
            # Check Ollama running (once)
            if not ollama_checked:
                if not check_ollama_running():
                    # NOTE: the banner is built from f-strings only. Mixing
                    # adjacent string literals with ``"=" * 60`` repeats the
                    # title line, because literal concatenation happens
                    # before the ``*`` operator is applied.
                    raise ConnectionError(
                        f"\n{rule}\n"
                        " OLLAMA NOT RUNNING\n"
                        f"{rule}\n\n"
                        "Ollama must be running to use local models.\n\n"
                        "To start Ollama:\n"
                        " macOS: Open the Ollama app, or run 'ollama serve'\n"
                        " Linux: Run 'ollama serve' in terminal\n"
                        " Windows: Open the Ollama app\n\n"
                        f"{rule}"
                    )
                ollama_checked = True

            # Check model availability
            if not check_ollama_model(model):
                if not pull_ollama_model(model, auto_confirm=auto_download):
                    raise RuntimeError(
                        f"Ollama model '{model}' not available. "
                        f"Run: ollama pull {model}"
                    )
        elif detected_provider == "claude-code":
            if not check_claude_cli_available():
                raise ConnectionError(
                    f"\n{rule}\n"
                    " CLAUDE CLI NOT FOUND\n"
                    f"{rule}\n\n"
                    "The claude CLI must be installed to use claude-code as a provider.\n"
                    "Install: https://docs.anthropic.com/en/docs/claude-code\n"
                    f"{rule}"
                )
        else:
            # Validate API key exists for cloud providers
            if not api_key:
                raise ValueError(
                    f"API key required for provider '{detected_provider}' (model: {model})"
                )

        # Per-model creativity override (None means use global)
        per_model_creativity = options.get("creativity", None) if options else None

        # Build sanitized column name; ensembles also encode the creativity
        base_name = sanitize_model_name(model)
        if is_ensemble:
            base_name += _format_creativity_suffix(per_model_creativity)

        configs.append({
            "model": model,
            "provider": detected_provider,
            "api_key": api_key,
            "use_two_step": (detected_provider == "ollama"),
            "sanitized_name": base_name,
            "creativity": per_model_creativity,
        })

    # Make sanitized names unique. The first occurrence keeps its name; later
    # duplicates get "_1", "_2", ... appended. Candidates are checked against
    # every name already handed out so a generated suffix can never collide
    # with another model's own name (e.g. "a", "a_1", "a").
    taken = set()
    suffix_counters = {}
    for cfg in configs:
        base = cfg["sanitized_name"]
        candidate = base
        while candidate in taken:
            suffix_counters[base] = suffix_counters.get(base, 0) + 1
            candidate = f"{base}_{suffix_counters[base]}"
        taken.add(candidate)
        cfg["sanitized_name"] = candidate

    return configs
386
+
387
+
388
def aggregate_results(
    model_results: dict,
    categories: list,
    consensus_threshold: Union[str, float],
    fail_strategy: str,
    formatter_state: dict = None,
) -> dict:
    """
    Aggregate results from multiple models using majority voting.

    A model counts as successful when its output parses to a JSON object
    containing at least one expected category key ("1".."N") whose value is
    "0" or "1" (after ``str(...).strip()``). Keys outside the expected range
    and keys with other values are dropped; categories a model did not
    answer are counted as a "0" vote.

    Args:
        model_results: Dict mapping model name to (json_str, error)
        categories: List of category names
        consensus_threshold: Threshold for majority vote. Can be:
            - "majority": 50% agreement (default)
            - "two-thirds": 67% agreement
            - "unanimous": 100% agreement
            - float: Custom threshold between 0 and 1
        fail_strategy: How to handle failures. With "strict", any failed model
            makes the whole aggregation an error; any other value aggregates
            over the models that succeeded.
        formatter_state: Not read by this function; accepted so existing
            callers that pass it keep working.

    Returns:
        Dict with keys:
            - "per_model": model name -> cleaned {category key: "0"/"1"}
            - "consensus": category key -> "1" if the share of positive votes
              is >= the threshold, else "0"
            - "agreement": category key -> fraction of successful models whose
              vote equals the consensus (rounded to 3 decimals)
            - "failed_models": list of model names that failed
            - "missing_keys": model name -> number of categories absent from
              its cleaned response (only models with at least one)
            - "error": None on success, otherwise a message
    """
    # Resolve string thresholds to numeric values
    threshold = _resolve_consensus_threshold(consensus_threshold)

    category_keys = [str(i) for i in range(1, len(categories) + 1)]
    expected_keys = set(category_keys)

    successful = {}
    failed_models = []

    for model_name, (json_str, error) in model_results.items():
        if error:
            failed_models.append(model_name)
            continue
        try:
            parsed = json.loads(json_str)
        except (json.JSONDecodeError, TypeError):
            # TypeError: payload is not str/bytes (e.g. None)
            failed_models.append(model_name)
            continue
        if not isinstance(parsed, dict):
            # Valid JSON that is not an object (list, number, ...) cannot
            # carry category votes; treat it as a failed model.
            failed_models.append(model_name)
            continue

        # Keep only expected keys with a clean 0/1 value. Models may return
        # just the present categories (e.g. {"3": "1"}); everything missing
        # defaults to 0 below.
        cleaned = {}
        for k, v in parsed.items():
            value = str(v).strip()
            if k in expected_keys and value in ("0", "1"):
                cleaned[k] = value

        if cleaned:
            successful[model_name] = cleaned
        else:
            failed_models.append(model_name)

    # Handle failure strategies
    if fail_strategy == "strict" and failed_models:
        return {
            "per_model": {},
            "consensus": {},
            "agreement": {},
            "failed_models": failed_models,
            "missing_keys": {},
            "error": f"Models failed (strict mode): {failed_models}",
        }

    if not successful:
        return {
            "per_model": {},
            "consensus": {},
            "agreement": {},
            "missing_keys": {},
            "failed_models": failed_models,
            "error": "All models failed",
        }

    # Calculate consensus via majority vote
    consensus = {}
    agreement_scores = {}
    num_successful = len(successful)
    # Track missing keys per model (keys in range but absent from response)
    missing_keys_count = {}

    for key in category_keys:
        votes = []
        for model_name, cleaned in successful.items():
            if key not in cleaned:
                missing_keys_count[model_name] = missing_keys_count.get(model_name, 0) + 1
            # Cleaned values are always "0" or "1", so int() cannot fail
            votes.append(int(cleaned.get(key, "0")))

        positive_rate = sum(votes) / num_successful
        consensus_val = "1" if positive_rate >= threshold else "0"
        consensus[key] = consensus_val

        # Agreement = fraction of models that match the consensus decision
        consensus_int = int(consensus_val)
        matching = sum(1 for v in votes if v == consensus_int)
        agreement_scores[key] = round(matching / num_successful, 3)

    return {
        "per_model": successful,
        "consensus": consensus,
        "agreement": agreement_scores,
        "failed_models": failed_models,
        "missing_keys": missing_keys_count,
        "error": None,
    }
502
+
503
+
504
+ # =============================================================================
505
+ # Shared Utility Functions
506
+ # =============================================================================
507
+
508
def normalize_model_input(
    model: str = None,
    api_key: str = None,
    provider: str = "auto",
    models: list = None,
) -> list:
    """
    Normalize the different ways of specifying models.

    Accepted input forms:
    - Single model: model="gpt-4o", api_key="sk-...", provider="auto"
    - Single tuple: models=("gpt-4o", "openai", "sk-...")
    - List of tuples: models=[("gpt-4o", "openai", "sk-..."), ...]

    When ``models`` is given it takes precedence over ``model``.

    Args:
        model: Single model name
        api_key: API key for single model
        provider: Provider for single model (default "auto")
        models: List of tuples or single tuple

    Returns:
        For the single-model and single-tuple forms, a one-element list
        holding the tuple. Otherwise ``models`` is returned unchanged.

    Raises:
        ValueError: If neither ``model`` nor ``models`` is given
    """
    if models is not None:
        # A bare 3- or 4-tuple starting with a string is one model entry,
        # not a collection of entries.
        is_single_entry = (
            isinstance(models, tuple)
            and len(models) in (3, 4)
            and isinstance(models[0], str)
        )
        return [models] if is_single_entry else models

    if model is not None:
        # Single model mode: model="gpt-4o", api_key="sk-...", provider="auto"
        return [(model, provider, api_key)]

    raise ValueError(
        "No model specified. Use either:\n"
        " - Single model: model='gpt-4o', api_key='sk-...'\n"
        " - Multiple models: models=[('gpt-4o', 'openai', 'sk-...'), ...]"
    )
548
+
549
+
550
def gather_stepback_insights(
    model_configs: list,
    survey_question: str,
    creativity: float = None,
) -> dict:
    """
    Collect a step-back insight from each model.

    Step-back prompting asks a model about the underlying factors behind the
    survey question before it is asked to classify. Models whose provider is
    "ollama" are skipped. A model's own "creativity" setting, when not None,
    takes precedence over the ``creativity`` argument.

    Args:
        model_configs: List of model configuration dicts
        survey_question: The survey question being analyzed
        creativity: Temperature setting used when a model has no override

    Returns:
        Dict mapping model name to (stepback_question, insight) tuples; only
        models for which an insight was obtained are included.

    Raises:
        TypeError: If ``survey_question`` is empty or None
    """
    if not survey_question:
        raise TypeError(
            "survey_question is required when using step_back_prompt. "
            "Please provide the survey question you are analyzing."
        )

    stepback_question = f'What are the underlying factors or dimensions that explain how people typically answer "{survey_question}"?'

    print("Getting step-back insights for each model...")
    collected = {}

    for cfg in model_configs:
        if cfg["provider"] == "ollama":
            continue

        # Per-model creativity wins over the global value
        model_creativity = cfg.get("creativity")
        if model_creativity is None:
            model_creativity = creativity

        insight, added = _get_stepback_insight(
            cfg["provider"],
            stepback_question,
            cfg["api_key"],
            cfg["model"],
            model_creativity,
        )
        if added:
            collected[cfg["model"]] = (stepback_question, insight)

    return collected
593
+
594
+
595
def prepare_json_schemas(
    model_configs: list,
    categories: list,
    use_json_schema: bool = True,
) -> dict:
    """
    Build the JSON schema to use for each configured model.

    Args:
        model_configs: List of model configuration dicts
        categories: List of category names
        use_json_schema: Whether to use strict JSON schema

    Returns:
        Dict mapping model name to JSON schema, or to None for every model
        when ``use_json_schema`` is false.
    """
    if not use_json_schema:
        return {cfg["model"]: None for cfg in model_configs}

    # Google doesn't support additionalProperties, so it is left out there
    return {
        cfg["model"]: build_json_schema(categories, cfg["provider"] != "google")
        for cfg in model_configs
    }
622
+
623
+
624
+ # =============================================================================
625
+ # Chain of Verification (CoVe) Functions
626
+ # =============================================================================
627
+
628
def build_cove_prompts(original_task: str, response_text: str) -> tuple:
    """
    Build the prompt templates for Chain of Verification steps 2-4.

    The returned strings still contain placeholders that the caller fills in
    later: <<INITIAL_REPLY>> (steps 2 and 4), <<QUESTION>> (step 3) and
    <<VERIFICATION_QA>> (step 4).

    Args:
        original_task: The original classification prompt/task
        response_text: The text response being classified

    Returns:
        Tuple of (step2_prompt, step3_prompt, step4_prompt)
    """
    # Step 2: ask the model for verification questions
    step2_prompt = "\n".join([
        "You provided this initial categorization:",
        "<<INITIAL_REPLY>>",
        "",
        f"Original task: {original_task}",
        "",
        "Generate a focused list of 3-5 verification questions to fact-check your categorization. Each question should:",
        "- Be concise and specific (one sentence)",
        "- Address a distinct aspect of the categorization",
        "- Be answerable independently",
        "",
        "Focus on verifying:",
        "- Whether each category assignment is accurate",
        "- Whether the categories match the criteria in the original task",
        "- Whether there are any logical inconsistencies",
        "",
        "Provide only the verification questions as a numbered list.",
    ])

    # Step 3: answer one verification question against the response text
    step3_prompt = "\n".join([
        "Answer the following verification question based on the text response provided.",
        "",
        f"Text response: {response_text}",
        "",
        "Verification question: <<QUESTION>>",
        "",
        "Provide a brief, direct answer (1-2 sentences maximum).",
        "",
        "Answer:",
    ])

    # Step 4: produce the final, corrected categorization
    step4_prompt = "\n".join([
        f"Original task: {original_task}",
        "Initial categorization:",
        "<<INITIAL_REPLY>>",
        "Verification questions and answers:",
        "<<VERIFICATION_QA>>",
        'If no categories are present, assign "0" to all categories.',
        "Provide the final corrected categorization in the same JSON format:",
    ])

    return step2_prompt, step3_prompt, step4_prompt
675
+
676
+
677
+ def _remove_numbering(line: str) -> str:
678
+ """
679
+ Remove numbering/bullets from a line for CoVe question parsing.
680
+
681
+ Handles formats like:
682
+ - "1. Question"
683
+ - "1) Question"
684
+ - "- Question"
685
+ - "• Question"
686
+ """
687
+ line = line.strip()
688
+ if line.startswith('- '):
689
+ return line[2:].strip()
690
+ if line.startswith('• '):
691
+ return line[2:].strip()
692
+ if line and line[0].isdigit():
693
+ i = 0
694
+ while i < len(line) and line[i].isdigit():
695
+ i += 1
696
+ if i < len(line) and line[i] in '.)':
697
+ return line[i+1:].strip()
698
+ return line
699
+
700
+
701
def run_chain_of_verification(
    client,
    initial_reply: str,
    step2_prompt: str,
    step3_prompt: str,
    step4_prompt: str,
    json_schema: dict,
    creativity: float = None,
    max_retries: int = 5,
) -> str:
    """
    Run the Chain of Verification process.

    This is a 4-step process:
    1. Initial classification (already done, passed as initial_reply)
    2. Generate verification questions
    3. Answer each verification question (at most 5 are asked)
    4. Final corrected classification

    The function never raises for a failed step: whenever step 2 or step 4
    reports an error, or the final reply cannot be parsed and validated,
    ``initial_reply`` is returned unchanged. A failed step-3 call only drops
    that question/answer pair.

    Args:
        client: UnifiedLLMClient instance; its ``complete(...)`` method is
            expected to return a ``(reply, error)`` pair
        initial_reply: The initial JSON classification result
        step2_prompt: Prompt template for generating questions
            (contains <<INITIAL_REPLY>>)
        step3_prompt: Prompt template for answering questions
            (contains <<QUESTION>>)
        step4_prompt: Prompt template for final classification
            (contains <<INITIAL_REPLY>> and <<VERIFICATION_QA>>)
        json_schema: JSON schema for the final classification
        creativity: Temperature setting
        max_retries: Maximum retry attempts for each API call

    Returns:
        Final corrected JSON classification string, or ``initial_reply`` as
        a fallback
    """
    # Step 2: Generate verification questions (text response, not JSON)
    step2_filled = step2_prompt.replace("<<INITIAL_REPLY>>", initial_reply)
    questions_reply, err = client.complete(
        messages=[{"role": "user", "content": step2_filled}],
        creativity=creativity,
        force_json=False,  # Text response
        max_retries=max_retries,
    )
    if err:
        return initial_reply  # Fall back to initial reply on error

    # Parse questions: one per non-blank line, list markers removed. Lines
    # that consist of a marker only (e.g. "1.") are dropped so no empty
    # question is sent to the model.
    questions = []
    for line in questions_reply.strip().split('\n'):
        if not line.strip():
            continue
        question = _remove_numbering(line)
        if question:
            questions.append(question)

    # Step 3: Answer each verification question (text responses)
    qa_pairs = []
    for question in questions[:5]:  # Limit to 5 questions
        step3_filled = step3_prompt.replace("<<QUESTION>>", question)
        answer_reply, err = client.complete(
            messages=[{"role": "user", "content": step3_filled}],
            creativity=creativity,
            force_json=False,  # Text response
            max_retries=max_retries,
        )
        if not err:
            qa_pairs.append(f"Q: {question}\nA: {answer_reply.strip()}")

    verification_qa = "\n\n".join(qa_pairs)

    # Step 4: Final corrected categorization (JSON response)
    step4_filled = step4_prompt.replace(
        "<<INITIAL_REPLY>>", initial_reply
    ).replace(
        "<<VERIFICATION_QA>>", verification_qa
    )
    final_reply, err = client.complete(
        messages=[{"role": "user", "content": step4_filled}],
        json_schema=json_schema,
        creativity=creativity,
        max_retries=max_retries,
    )
    if err:
        return initial_reply

    # Extract and validate JSON from the response (critical for providers
    # like HuggingFace that use json_object mode instead of strict json_schema)
    extracted = extract_json(final_reply)
    try:
        parsed = json.loads(extracted)
    except (json.JSONDecodeError, TypeError):
        # TypeError: extraction produced something that is not str/bytes
        return initial_reply

    # Accept the reply if it is an object with at least one 0/1 value.
    # Values are normalised with str().strip(), matching aggregate_results,
    # so integer 0/1 answers are not thrown away.
    if isinstance(parsed, dict) and any(
        str(v).strip() in ("0", "1") for v in parsed.values()
    ):
        return extracted

    # Fall back to initial reply if validation fails
    return initial_reply
795
+
796
+
797
+ # =============================================================================
798
+ # Text-Specific Functions (swap these for image classification)
799
+ # =============================================================================
800
+
801
def build_text_classification_prompt(
    response_text: str,
    categories_str: str,
    survey_question_context: str = "",
    examples_text: str = "",
    chain_of_thought: bool = False,
    context_prompt: bool = False,
    step_back_prompt: bool = False,
    stepback_insights: dict = None,
    model_name: str = None,
    multi_label: bool = True,
) -> list:
    """
    Assemble the chat messages used to classify one text response.

    This is the text-specific prompt builder. For image classification,
    a different function would be used.

    Args:
        response_text: The text to classify
        categories_str: Formatted string of categories (numbered list)
        survey_question_context: Context about the survey question
        examples_text: Few-shot examples text
        chain_of_thought: Whether to use step-by-step reasoning
        context_prompt: Whether to add expert context prefix
        step_back_prompt: Whether step-back prompting is enabled
        stepback_insights: Dict of step-back insights per model
        model_name: Current model name (for step-back lookup)
        multi_label: If True, ask for every category that applies; if False,
            ask for the single best matching category

    Returns:
        List of message dicts for the LLM. The last message is the
        classification request; when a step-back insight exists for
        ``model_name`` it is preceded by that question/answer exchange.
    """
    # Wording that depends on multi- vs single-label mode
    if multi_label:
        label_type = "multi-label"
        categorize_instruction = 'into the following categories that apply'
        json_instruction = 'Provide your answer in JSON format where the category number is the key and "1" if present, "0" if not.'
        cot_step3 = 'assign 1 to matching categories and 0 to non-matching categories'
    else:
        label_type = "single-label"
        categorize_instruction = 'into the single most appropriate category'
        json_instruction = 'Provide your answer in JSON format where the category number is the key. Assign "1" to the single best matching category and "0" to all others.'
        cot_step3 = 'assign 1 to the single best matching category and 0 to all others'

    task_line = f'Categorize this text response "{response_text}" {categorize_instruction}:'

    if chain_of_thought:
        prompt_lines = [
            f"{survey_question_context}",
            "",
            task_line,
            f"{categories_str}",
            "",
            "Let's think step by step:",
            "1. First, identify the main themes mentioned in the response",
            "2. Then, match each theme to the relevant categories",
            f"3. Finally, {cot_step3}",
            "",
            f"{examples_text}",
            "",
            json_instruction,
        ]
    else:
        prompt_lines = [
            f"{survey_question_context}",
            task_line,
            f"{categories_str}",
            f"{examples_text}",
            json_instruction,
        ]
    user_prompt = "\n".join(prompt_lines)

    # Add context prompt prefix if enabled
    if context_prompt:
        expert_prefix = (
            "You are an expert researcher in text data categorization.\n"
            f"Apply {label_type} classification and base decisions on explicit and implicit meanings.\n"
            "When uncertain, prioritize precision over recall.\n"
            "\n"
        )
        user_prompt = expert_prefix + user_prompt

    # Step-back exchange (if available for this model) goes first
    messages = []
    if step_back_prompt and stepback_insights and model_name in stepback_insights:
        sb_question, sb_insight = stepback_insights[model_name]
        messages.extend([
            {"role": "user", "content": sb_question},
            {"role": "assistant", "content": sb_insight},
        ])

    messages.append({"role": "user", "content": user_prompt})
    return messages
885
+
886
+
887
+ # =============================================================================
888
+ # Summarization Functions
889
+ # =============================================================================
890
+
891
def build_summary_json_schema(include_additional_properties: bool = True) -> dict:
    """
    Build the JSON schema describing a summary result.

    The schema is an object with one required string property, "summary".

    Args:
        include_additional_properties: Whether to include additionalProperties: false.
            Should be False for Google (not supported).

    Returns:
        JSON schema dict for structured summary output
    """
    summary_property = {
        "type": "string",
        "description": "A concise summary of the input text",
    }
    schema = {
        "type": "object",
        "properties": {"summary": summary_property},
        "required": ["summary"],
    }
    if include_additional_properties:
        schema["additionalProperties"] = False
    return schema
915
+
916
+
917
def build_text_summarization_prompt(
    response_text: str,
    input_description: str = "",
    summary_instructions: str = "",
    max_length: int = None,
    focus: str = None,
    chain_of_thought: bool = False,
    context_prompt: bool = False,
    step_back_prompt: bool = False,
    stepback_insights: dict = None,
    model_name: str = None,
) -> list:
    """
    Assemble the chat messages asking a model to summarize one text.

    Args:
        response_text: The text to summarize (embedded in double quotes).
        input_description: Optional description of what the text is; when
            given it is placed before the task sentence.
        summary_instructions: Optional free-form instructions appended as
            "Additional instructions: ...".
        max_length: Optional word limit; a falsy value adds no limit.
        focus: Optional aspect to focus on, appended to the task sentence.
        chain_of_thought: Insert a three-step "think step by step" section.
        context_prompt: Prefix the prompt with an expert-role preamble.
        step_back_prompt: Enable replay of a step-back exchange.
        stepback_insights: Mapping of model name to a
            (question, insight) pair.
        model_name: Key used to look up this model's step-back pair.

    Returns:
        List of message dicts. The last one is always the user prompt; it
        is preceded by a user/assistant step-back pair when step-back is
        enabled and an insight exists for ``model_name``.
    """
    # Optional fragments collapse to "" when the corresponding input is falsy.
    focus_clause = f", focusing on {focus}" if focus else ""
    preamble = f"The following text is: {input_description}\n\n" if input_description else ""
    length_clause = f"\n\nKeep the summary under {max_length} words." if max_length else ""
    extra_clause = (
        f"\n\nAdditional instructions: {summary_instructions}" if summary_instructions else ""
    )

    task = f'{preamble}Summarize the following text{focus_clause}:\n\n"{response_text}"'
    reasoning = ""
    if chain_of_thought:
        reasoning = (
            "\n\nLet's think step by step:\n"
            "1. First, identify the main topic or theme\n"
            "2. Then, extract the key points\n"
            "3. Finally, synthesize into a concise summary"
        )
    answer_format = '\n\nProvide your answer in JSON format: {"summary": "your summary here"}'

    user_prompt = task + reasoning + length_clause + extra_clause + answer_format

    if context_prompt:
        expert_role = (
            "You are an expert at synthesizing key insights from text.\n"
            "Focus on accuracy, clarity, and identifying the most important themes.\n"
            "Provide concise summaries that capture essential information.\n"
            "\n"
        )
        user_prompt = expert_role + user_prompt

    conversation = []
    if step_back_prompt and stepback_insights and model_name in stepback_insights:
        # Replay the earlier step-back exchange before the actual request.
        question, insight = stepback_insights[model_name]
        conversation.append({"role": "user", "content": question})
        conversation.append({"role": "assistant", "content": insight})
    conversation.append({"role": "user", "content": user_prompt})
    return conversation
1006
+
1007
+
1008
def extract_summary_from_json(json_str: str) -> tuple:
    """
    Pull the "summary" string out of a JSON-encoded model response.

    Args:
        json_str: JSON text expected to decode to {"summary": "..."}.

    Returns:
        (True, stripped_summary) when the input decodes to a dict whose
        "summary" value is a non-blank string; (False, None) otherwise,
        including when the input is not valid JSON or not a str/bytes.
    """
    failure = (False, None)

    try:
        payload = json.loads(json_str)
    except (json.JSONDecodeError, TypeError):
        return failure

    if not isinstance(payload, dict):
        return failure

    candidate = payload.get("summary")
    if not isinstance(candidate, str):
        return failure

    cleaned = candidate.strip()
    if not cleaned:
        # Whitespace-only summaries count as missing.
        return failure
    return True, cleaned
1027
+
1028
+
1029
def build_pdf_summarization_prompt(
    page_data: dict,
    input_description: str = "",
    summary_instructions: str = "",
    max_length: int = None,
    focus: str = None,
    provider: str = "openai",
    pdf_mode: str = "image",
    chain_of_thought: bool = False,
    context_prompt: bool = False,
    step_back_prompt: bool = False,
    stepback_insights: dict = None,
    model_name: str = None,
) -> list:
    """
    Assemble the chat messages asking a model to summarize one PDF page.

    Summarization counterpart of build_pdf_classification_prompt().

    Args:
        page_data: Dict describing the page. Keys read here:
            - text: Extracted text (used in "text" and "both" modes)
            - pdf_bytes: PDF bytes (used for native PDF providers)
            - image_bytes: PNG bytes (used for the image_url format)
        input_description: Optional description of the document contents.
        summary_instructions: Optional free-form instructions.
        max_length: Optional word limit; a falsy value adds no limit.
        focus: Optional aspect the summary should focus on.
        provider: Provider name; selects the attachment format.
        pdf_mode: "image", "text", or "both".
        chain_of_thought: Insert a three-step analysis section.
        context_prompt: Prefix the prompt with an expert-role preamble.
        step_back_prompt: Enable replay of a step-back exchange.
        stepback_insights: Mapping of model name to (question, insight).
        model_name: Key used to look up this model's step-back pair.

    Returns:
        List of message dicts. The final user message's content is a plain
        string (text mode, or when no visual data is available) or a list
        of content parts whose shape depends on the provider.
    """
    if pdf_mode == "text":
        examine = "Examine the following text extracted from a PDF page"
    else:
        examine = (
            "Examine the attached PDF page AND the extracted text below"
            if pdf_mode == "both"
            else "Examine the attached PDF page"
        )

    focus_clause = f", focusing on {focus}" if focus else ""
    length_clause = f"\n\nKeep the summary under {max_length} words." if max_length else ""
    extra_clause = (
        f"\n\nAdditional instructions: {summary_instructions}" if summary_instructions else ""
    )
    context_line = f"Document context: {input_description}" if input_description else ""

    header = (
        "You are a document summarization assistant.\n"
        f"Task: {examine} and provide a concise summary{focus_clause}.\n"
        "\n"
        f"{context_line}"
    )
    reasoning = ""
    if chain_of_thought:
        reasoning = (
            "\n\nLet's analyze step by step:\n"
            "1. First, identify the main topic or theme of this page\n"
            "2. Then, extract the key points and important information\n"
            "3. Finally, synthesize into a concise summary"
        )
    answer_format = '\n\nProvide your answer in JSON format: {"summary": "your summary here"}'
    prompt_text = header + reasoning + length_clause + extra_clause + answer_format

    # Extracted text is appended after the instructions in text/both modes.
    if page_data.get("text") and pdf_mode in ("text", "both"):
        prompt_text += (
            f"\n\n--- EXTRACTED TEXT FROM PAGE ---\n{page_data['text']}\n--- END OF EXTRACTED TEXT ---"
        )

    if context_prompt:
        expert_role = (
            "You are an expert at synthesizing key insights from documents.\n"
            "Focus on accuracy, clarity, and identifying the most important themes.\n"
            "Provide concise summaries that capture essential information.\n"
            "\n"
        )
        prompt_text = expert_role + prompt_text

    conversation = []
    if step_back_prompt and stepback_insights and model_name in stepback_insights:
        question, insight = stepback_insights[model_name]
        conversation.append({"role": "user", "content": question})
        conversation.append({"role": "assistant", "content": insight})

    # Text-only mode never attaches visual content.
    if pdf_mode == "text":
        conversation.append({"role": "user", "content": prompt_text})
        return conversation

    text_part = {"type": "text", "text": prompt_text}

    if provider in _NATIVE_PDF_PROVIDERS and page_data.get("pdf_bytes"):
        pdf_b64 = _encode_bytes_to_base64(page_data["pdf_bytes"])
        if provider == "anthropic":
            attachment = {
                "type": "document",
                "source": {
                    "type": "base64",
                    "media_type": "application/pdf",
                    "data": pdf_b64,
                },
            }
        elif provider == "google":
            attachment = {
                "type": "inline_data",
                "mime_type": "application/pdf",
                "data": pdf_b64,
            }
        else:
            # A native-PDF provider with no known format gets no user message.
            return conversation
        user_content = [text_part, attachment]
    elif page_data.get("image_bytes"):
        # Page image sent as a base64 data URL in the image_url format.
        image_b64 = _encode_bytes_to_base64(page_data["image_bytes"])
        data_url = f"data:image/png;base64,{image_b64}"
        user_content = [
            text_part,
            {"type": "image_url", "image_url": {"url": data_url, "detail": "high"}},
        ]
    else:
        # No visual data at all: fall back to the bare text prompt.
        user_content = prompt_text

    conversation.append({"role": "user", "content": user_content})
    return conversation
1191
+
1192
+
1193
+ # =============================================================================
1194
+ # PDF-Specific Functions
1195
+ # =============================================================================
1196
+
1197
# How each provider receives PDF pages.

# Native PDF bytes are sent as-is.
_NATIVE_PDF_PROVIDERS = {
    "anthropic",
    "google",
}

# Pages are converted to an image first.
_IMAGE_PROVIDERS = {
    "openai",
    "mistral",
    "xai",
    "perplexity",
    "huggingface",
}

# Text extraction only.
_TEXT_ONLY_PROVIDERS = {
    "ollama",
}
1201
+
1202
+
1203
def build_pdf_classification_prompt(
    page_data: dict,
    categories_str: str,
    input_description: str = "",
    provider: str = "openai",
    pdf_mode: str = "image",
    chain_of_thought: bool = False,
    context_prompt: bool = False,
    step_back_prompt: bool = False,
    stepback_insights: dict = None,
    model_name: str = None,
    example_json: str = None,
    multi_label: bool = True,
) -> list:
    """
    Assemble the chat messages asking a model to classify one PDF page.

    PDF counterpart of build_text_classification_prompt().

    Args:
        page_data: Dict describing the page. Keys read here:
            - text: Extracted text (used in "text" and "both" modes)
            - pdf_bytes: PDF bytes (used for native PDF providers)
            - image_bytes: PNG bytes (used for the image_url format)
        categories_str: Formatted (numbered) list of categories.
        input_description: Optional description of the expected content.
        provider: Provider name; selects the attachment format.
        pdf_mode: "image", "text", or "both".
        chain_of_thought: Insert a three-step analysis section.
        context_prompt: Prefix the prompt with an expert-role preamble.
        step_back_prompt: Enable replay of a step-back exchange.
        stepback_insights: Mapping of model name to (question, insight).
        model_name: Key used to look up this model's step-back pair.
        example_json: Optional example of the expected JSON output.
        multi_label: True asks for a 0/1 decision per category; False asks
            for exactly one best-matching category.

    Returns:
        List of message dicts. The final user message's content is a plain
        string (text mode, or when no visual data is available) or a list
        of content parts whose shape depends on the provider.
    """
    if pdf_mode == "text":
        examine = "Examine the following text extracted from a PDF page"
    else:
        examine = (
            "Examine the attached PDF page AND the extracted text below"
            if pdf_mode == "both"
            else "Examine the attached PDF page"
        )

    if multi_label:
        task = f'{examine} and decide, **for each category below**, whether it is PRESENT (1) or NOT PRESENT (0).'
        final_step = 'assign 1 to matching categories and 0 to non-matching categories'
    else:
        task = f'{examine} and decide which **single category** below best describes this document page. Assign PRESENT (1) to only the best match and NOT PRESENT (0) to all others.'
        final_step = 'assign 1 to the single best matching category and 0 to all others'

    expectation = (
        f"Document page is expected to contain: {input_description}" if input_description else ""
    )
    example_line = f"Example JSON format: {example_json}" if example_json else ""

    # Sections are separated by one blank line; empty optional sections
    # still occupy their slot, exactly as in the single-template form.
    sections = [
        f"You are a document-tagging assistant.\nTask: {task}",
        expectation,
        f"Categories:\n{categories_str}",
    ]
    if chain_of_thought:
        sections.append(
            "Let's analyze step by step:\n"
            "1. First, identify the key content elements in the document page\n"
            "2. Then, match each element to the relevant categories\n"
            f"3. Finally, {final_step}"
        )
    sections.append(
        "Output format: Respond with **only** a JSON object whose keys are the quoted category numbers ('1', '2', ...) and whose values are 1 or 0. No additional keys, comments, or text."
    )
    sections.append(example_line)
    prompt_text = "\n\n".join(sections)

    # Extracted text is appended after the instructions in text/both modes.
    if page_data.get("text") and pdf_mode in ("text", "both"):
        prompt_text += (
            f"\n\n--- EXTRACTED TEXT FROM PAGE ---\n{page_data['text']}\n--- END OF EXTRACTED TEXT ---"
        )

    if context_prompt:
        label_type = "multi-label" if multi_label else "single-label"
        expert_role = (
            "You are an expert document analyst specializing in page categorization.\n"
            f"Apply {label_type} classification based on explicit and implicit content cues.\n"
            "When uncertain, prioritize precision over recall.\n"
            "\n"
        )
        prompt_text = expert_role + prompt_text

    conversation = []
    if step_back_prompt and stepback_insights and model_name in stepback_insights:
        question, insight = stepback_insights[model_name]
        conversation.append({"role": "user", "content": question})
        conversation.append({"role": "assistant", "content": insight})

    # Text-only mode never attaches visual content.
    if pdf_mode == "text":
        conversation.append({"role": "user", "content": prompt_text})
        return conversation

    text_part = {"type": "text", "text": prompt_text}

    if provider in _NATIVE_PDF_PROVIDERS and page_data.get("pdf_bytes"):
        pdf_b64 = _encode_bytes_to_base64(page_data["pdf_bytes"])
        if provider == "anthropic":
            attachment = {
                "type": "document",
                "source": {
                    "type": "base64",
                    "media_type": "application/pdf",
                    "data": pdf_b64,
                },
            }
        elif provider == "google":
            attachment = {
                "type": "inline_data",
                "mime_type": "application/pdf",
                "data": pdf_b64,
            }
        else:
            # A native-PDF provider with no known format gets no user message.
            return conversation
        user_content = [text_part, attachment]
    elif page_data.get("image_bytes"):
        # Page image sent as a base64 data URL in the image_url format.
        image_b64 = _encode_bytes_to_base64(page_data["image_bytes"])
        data_url = f"data:image/png;base64,{image_b64}"
        user_content = [
            text_part,
            {"type": "image_url", "image_url": {"url": data_url, "detail": "high"}},
        ]
    else:
        # No visual data at all: fall back to the bare text prompt.
        user_content = prompt_text

    conversation.append({"role": "user", "content": user_content})
    return conversation
1366
+
1367
+
1368
def _prepare_page_data(
    pdf_path: str,
    page_index: int,
    page_label: str,
    pdf_mode: str,
    provider: str,
    pdf_dpi: int = 150,
) -> dict:
    """
    Collect the text and/or visual payload for one PDF page.

    Args:
        pdf_path: Path to the PDF file
        page_index: Page number (0-indexed)
        page_label: Label for the page (e.g., "document_p1")
        pdf_mode: "image", "text", or "both"
        provider: Provider name; native-PDF providers get PDF bytes,
            all others get an image
        pdf_dpi: DPI passed to the image extraction helper

    Returns:
        Dict with keys pdf_path, page_index, page_label, text,
        image_bytes, pdf_bytes and error. Fields that were not extracted
        stay None; "error" holds a message when a required extraction
        failed.
    """
    prepared = {
        "pdf_path": pdf_path,
        "page_index": page_index,
        "page_label": page_label,
        "text": None,
        "image_bytes": None,
        "pdf_bytes": None,
        "error": None,
    }

    if pdf_mode in ("text", "both"):
        extracted, text_ok, reason = _extract_page_text(pdf_path, page_index)
        if text_ok:
            prepared["text"] = extracted
        elif pdf_mode == "text":
            # Pure text mode has nothing else to offer, so stop here.
            prepared["error"] = f"Failed to extract text: {reason}"
            return prepared
        # In "both" mode a text failure is tolerated; visuals may still work.

    if pdf_mode not in ("image", "both"):
        return prepared

    if provider in _NATIVE_PDF_PROVIDERS:
        native_bytes, native_ok = _extract_page_as_pdf_bytes(pdf_path, page_index)
        if native_ok:
            prepared["pdf_bytes"] = native_bytes
            return prepared
        # Native extraction failed: fall through to the image rendering.
        failure_message = "Failed to extract page as PDF or image"
    else:
        failure_message = "Failed to extract page as image"

    rendered, image_ok = _extract_page_as_image_bytes(pdf_path, page_index, dpi=pdf_dpi)
    if image_ok:
        prepared["image_bytes"] = rendered
    else:
        prepared["error"] = failure_message

    return prepared
1433
+
1434
+
1435
+ # =============================================================================
1436
+ # Image-Specific Functions
1437
+ # =============================================================================
1438
+
1439
def build_image_classification_prompt(
    image_data: dict,
    categories_str: str,
    input_description: str = "",
    provider: str = "openai",
    chain_of_thought: bool = False,
    context_prompt: bool = False,
    step_back_prompt: bool = False,
    stepback_insights: dict = None,
    model_name: str = None,
    example_json: str = None,
    multi_label: bool = True,
) -> list:
    """
    Build the classification prompt for an image.

    This is the image-specific prompt builder, parallel to build_pdf_classification_prompt().

    Args:
        image_data: Dict containing:
            - image_path: Path to source image
            - image_label: Label for the image (filename without extension)
            - encoded_image: Base64 encoded image
            - extension: Image file extension (with or without dot)
        categories_str: Formatted string of categories (numbered list)
        input_description: Description of what the images contain
        provider: Provider name for format-specific handling
        chain_of_thought: Whether to use step-by-step reasoning
        context_prompt: Whether to add expert context prefix
        step_back_prompt: Whether step-back prompting is enabled
        stepback_insights: Dict of step-back insights per model
        model_name: Current model name (for step-back lookup)
        example_json: Example JSON output format
        multi_label: True asks for a 0/1 decision per category; False asks
            for exactly one best-matching category

    Returns:
        List of message dicts; the final user message's content is a list
        of parts (text + image) whose image part format varies by provider
    """
    # Build the base text prompt
    if multi_label:
        task_instruction = 'Examine the attached image and decide, **for each category below**, whether it is PRESENT (1) or NOT PRESENT (0).'
        cot_step3 = 'assign 1 to matching categories and 0 to non-matching categories'
    else:
        task_instruction = 'Examine the attached image and decide which **single category** below best describes this image. Assign PRESENT (1) to only the best match and NOT PRESENT (0) to all others.'
        cot_step3 = 'assign 1 to the single best matching category and 0 to all others'

    if chain_of_thought:
        base_text = f"""You are an image-tagging assistant.
Task: {task_instruction}

{f'The image is expected to contain: {input_description}' if input_description else ''}

Categories:
{categories_str}

Let's analyze step by step:
1. First, identify the key visual elements in the image
2. Then, match each element to the relevant categories
3. Finally, {cot_step3}

Output format: Respond with **only** a JSON object whose keys are the quoted category numbers ('1', '2', ...) and whose values are 1 or 0. No additional keys, comments, or text.

{f'Example JSON format: {example_json}' if example_json else ''}"""
    else:
        base_text = f"""You are an image-tagging assistant.
Task: {task_instruction}

{f'The image is expected to contain: {input_description}' if input_description else ''}

Categories:
{categories_str}

Output format: Respond with **only** a JSON object whose keys are the quoted category numbers ('1', '2', ...) and whose values are 1 or 0. No additional keys, comments, or text.

{f'Example JSON format: {example_json}' if example_json else ''}"""

    # Add context prompt prefix if enabled
    if context_prompt:
        label_type = "multi-label" if multi_label else "single-label"
        context = f"""You are an expert visual analyst specializing in image categorization.
Apply {label_type} classification based on explicit and implicit visual cues.
When uncertain, prioritize precision over recall.

"""
        base_text = context + base_text

    # Build messages based on provider
    messages = []

    # Add step-back insight if available
    if step_back_prompt and stepback_insights and model_name in stepback_insights:
        sb_question, sb_insight = stepback_insights[model_name]
        messages.append({"role": "user", "content": sb_question})
        messages.append({"role": "assistant", "content": sb_insight})

    # Get encoded image and extension. `or` is used instead of a .get()
    # default because the keys may be present with a None value (that is
    # how _prepare_image_data initialises them), which previously leaked
    # "None" into the MIME type and the data URL.
    encoded = image_data.get("encoded_image") or ""
    ext = str(image_data.get("extension") or "png").lower().lstrip(".") or "png"
    # "image/jpg" is not a registered media type; normalise JPEG aliases
    # the same way _prepare_image_data does.
    if ext in ('jpg', 'jpe', 'jfif', 'pjpeg', 'pjp'):
        ext = 'jpeg'

    # Format depends on provider
    if provider == "anthropic":
        # Anthropic uses explicit media_type
        image_part = {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": f"image/{ext}",
                "data": encoded
            }
        }
    elif provider == "google":
        # Google uses inline_data format
        image_part = {
            "type": "inline_data",
            "mime_type": f"image/{ext}",
            "data": encoded
        }
    else:
        # OpenAI, Mistral, xAI, etc. use image_url format
        encoded_url = f"data:image/{ext};base64,{encoded}"
        image_part = {"type": "image_url", "image_url": {"url": encoded_url, "detail": "high"}}

    messages.append({
        "role": "user",
        "content": [{"type": "text", "text": base_text}, image_part],
    })

    return messages
1575
+
1576
+
1577
def _prepare_image_data(
    image_path: str,
    image_label: str,
) -> dict:
    """
    Encode one image file and package it for classification.

    Args:
        image_path: Path to the image file
        image_label: Label for the image (typically filename without extension)

    Returns:
        Dict with image_path, image_label, encoded_image, extension and
        error. On failure encoded_image and extension stay None and error
        holds a message; any exception raised while encoding is captured
        there instead of propagating.
    """
    jpeg_aliases = ('jpg', 'jpe', 'jfif', 'pjpeg', 'pjp')
    prepared = {
        "image_path": image_path,
        "image_label": image_label,
        "encoded_image": None,
        "extension": None,
        "error": None,
    }

    try:
        # _encode_image returns (encoded_data, extension, is_valid)
        payload, reported_ext, encode_ok = _encode_image(image_path)
        if encode_ok:
            if reported_ext:
                suffix = reported_ext.lower()
            else:
                # No extension reported: derive it from the file name.
                suffix = os.path.splitext(image_path)[1].lower().replace('.', '')
            if suffix in jpeg_aliases:
                suffix = 'jpeg'
            prepared["encoded_image"] = payload
            prepared["extension"] = suffix
        else:
            prepared["error"] = "Failed to encode image"
    except Exception as exc:
        prepared["error"] = str(exc)

    return prepared
1617
+
1618
+
1619
def _save_partial_results(
    all_results: list,
    model_configs: list,
    categories: list,
    filename: str,
    save_directory: str,
) -> None:
    """
    Write the results gathered so far to a CSV file.

    A lighter variant of build_output_dataframes: only the combined table
    is produced, not the per-model DataFrames.

    Args:
        all_results: Result dicts; each provides "response" and an
            "aggregated" dict (error, failed_models, per_model, consensus,
            agreement) and may carry "skipped", "pdf_path", "page_index"
            and "image_path".
        model_configs: Model config dicts; only "sanitized_name" is read.
        categories: Category list; only its length is used.
        filename: Name of the CSV file to write.
        save_directory: Optional directory (created if missing) in which
            the file is placed.
    """

    def _as_int(raw, fallback):
        # Values that cannot be read as an integer become ``fallback``.
        try:
            return int(raw)
        except (ValueError, TypeError):
            return fallback

    sanitized_names = [config["sanitized_name"] for config in model_configs]
    category_keys = [str(number) for number in range(1, len(categories) + 1)]

    # Metadata columns are added for every row as soon as one row has them.
    include_pdf_columns = any(item.get("pdf_path") is not None for item in all_results)
    include_image_column = any(item.get("image_path") is not None for item in all_results)

    records = []
    for entry in all_results:
        if entry.get("skipped"):
            status = "skipped"
        else:
            outcome = entry["aggregated"]
            if outcome["error"]:
                status = "error"
            else:
                status = "partial" if outcome["failed_models"] else "success"

        source_text = entry["response"]
        aggregated = entry["aggregated"]
        failed = aggregated["failed_models"]

        record = {
            "input_data": source_text,
            "processing_status": status,
            "failed_models": ",".join(failed) if failed else "",
        }

        if include_pdf_columns:
            record["pdf_path"] = entry.get("pdf_path", "")
            record["page_index"] = entry.get("page_index", "")
        if include_image_column:
            record["image_path"] = entry.get("image_path", "")

        failed_lookup = set(failed)
        was_skipped = entry.get("skipped", False)
        for name in sanitized_names:
            if was_skipped or name in failed_lookup:
                # Skipped (NaN input) or model failed: leave the cells empty.
                cells = {f"category_{key}_{name}": None for key in category_keys}
            else:
                votes = aggregated["per_model"].get(name, {})
                cells = {
                    f"category_{key}_{name}": _as_int(votes.get(key, "0"), 0)
                    for key in category_keys
                }
            record.update(cells)

        for key in category_keys:
            agreed = aggregated["consensus"].get(key, None)
            share = aggregated["agreement"].get(key, None)
            record[f"category_{key}_consensus"] = (
                None if agreed is None else _as_int(agreed, None)
            )
            record[f"category_{key}_agreement"] = share

        records.append(record)

    frame = pd.DataFrame(records)

    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
        destination = os.path.join(save_directory, filename)
    else:
        destination = filename

    frame.to_csv(destination, index=False)
1712
+
1713
+
1714
+ def classify_ensemble(
1715
+ input_data,
1716
+ categories,
1717
+ # Single model mode (like original multi_class)
1718
+ model: str = None,
1719
+ api_key: str = None,
1720
+ provider: str = "auto",
1721
+ # Multi-model mode
1722
+ models: list = None,
1723
+ # Common parameters
1724
+ survey_question: str = "",
1725
+ example1: str = None,
1726
+ example2: str = None,
1727
+ example3: str = None,
1728
+ example4: str = None,
1729
+ example5: str = None,
1730
+ example6: str = None,
1731
+ creativity: float = None,
1732
+ chain_of_thought: bool = False,
1733
+ chain_of_verification: bool = False,
1734
+ step_back_prompt: bool = False,
1735
+ context_prompt: bool = False,
1736
+ thinking_budget: int = 0,
1737
+ use_json_schema: bool = True,
1738
+ max_workers: int = None,
1739
+ parallel: bool = None,
1740
+ consensus_threshold: Union[str, float] = "unanimous",
1741
+ fail_strategy: str = "partial",
1742
+ safety: bool = False,
1743
+ max_retries: int = 5,
1744
+ batch_retries: int = 2,
1745
+ retry_delay: float = 1.0,
1746
+ row_delay: float = 0.0,
1747
+ filename: str = None,
1748
+ save_directory: str = None,
1749
+ progress_callback: Callable = None,
1750
+ # Auto-category detection parameters
1751
+ max_categories: int = 12,
1752
+ categories_per_chunk: int = 10,
1753
+ divisions: int = 10,
1754
+ research_question: str = None,
1755
+ # PDF-specific parameters (only used when input_data contains PDFs)
1756
+ pdf_mode: str = "image",
1757
+ pdf_dpi: int = 150,
1758
+ input_description: str = "",
1759
+ # Ollama parameters
1760
+ auto_download: bool = False,
1761
+ # JSON formatter fallback
1762
+ formatter_state: dict = None,
1763
+ # Label mode
1764
+ multi_label: bool = True,
1765
+ # Chunked classification
1766
+ categories_per_call: int = None,
1767
+ # Embedding tiebreaker
1768
+ embedding_tiebreaker_state: dict = None,
1769
+ ):
1770
+ """
1771
+ Multi-class classification with support for text, PDF, and image inputs, using a single LLM model or an ensemble of models.
1772
+
1773
+ This unified function auto-detects whether the input is text, PDF, or image and processes it accordingly.
1774
+
1775
+ Input type detection:
1776
+ - If input_data is a directory path -> PDF mode
1777
+ - If input_data contains .pdf file paths -> PDF mode
1778
+ - Otherwise -> Text mode
1779
+
1780
+ This function can work in multiple modes:
1781
+ 1. Single model mode: Like the original multi_class function
1782
+ 2. Ensemble mode: Call multiple models in parallel with majority voting
1783
+ 3. PDF mode: Classify PDF pages instead of text responses
1784
+
1785
+ Args:
1786
+ input_data: Text responses OR PDF paths (auto-detected)
1787
+ - Text mode: List or Series of text strings to classify
1788
+ - PDF mode: Directory path, single PDF path, or list of PDF paths
1789
+
1790
+ categories: List of category names, or "auto" to auto-detect categories
1791
+
1792
+ # Single model mode (use these for simple single-model classification):
1793
+ model: Model name (e.g., "gpt-4o", "claude-sonnet-4-5-20250929")
1794
+ api_key: API key for the provider
1795
+ provider: Provider name or "auto" to detect from model name
1796
+
1797
+ # Multi-model mode (use this for ensemble classification):
1798
+ models: List of tuples (model_name, provider, api_key), or a single tuple
1799
+ Example: [("gpt-4o", "openai", "sk-..."), ("claude-sonnet-4-5-20250929", "anthropic", "sk-ant-...")]
1800
+
1801
+ # Classification parameters:
1802
+ survey_question: Context about what question was asked (required for categories="auto")
1803
+ example1-6: Optional few-shot examples for classification
1804
+ creativity: Temperature setting (None for provider default)
1805
+ chain_of_thought: If True, uses step-by-step reasoning in prompt
1806
+ chain_of_verification: If True, uses 4-step verification to improve accuracy
1807
+ (Note: ~4x API calls per response - expensive for ensemble mode)
1808
+ step_back_prompt: If True, first asks about underlying factors before classifying
1809
+ context_prompt: If True, adds expert context prefix to prompts
1810
+ thinking_budget: Token budget for extended thinking (0 to disable). Forwarded to Google, OpenAI, Anthropic, and HuggingFace providers; other providers receive None.
1811
+ use_json_schema: Whether to use strict JSON schema (vs just json_object mode)
1812
+
1813
+ # Ensemble parameters:
1814
+ max_workers: Maximum parallel workers (default: min(len(models), 8))
1815
+ consensus_threshold: Threshold for consensus vote. Can be:
1816
+ - "unanimous": 100% agreement (default — best accuracy in empirical testing)
1817
+ - "majority": 50% agreement
1818
+ - "two-thirds": 67% agreement
1819
+ - float: Custom threshold between 0 and 1 (e.g., 0.75 for 75%)
1820
+ fail_strategy: How to handle model failures:
1821
+ - "partial": Continue with successful models
1822
+ - "strict": Fail row if any model fails
1823
+ safety: If True, saves results incrementally during processing to prevent
1824
+ data loss. Requires filename to be set.
1825
+ max_retries: Maximum retry attempts for each API call (handles rate limits,
1826
+ server errors, timeouts). Default 5.
1827
+ batch_retries: Maximum retry passes for failed (row, model) pairs after
1828
+ the batch completes. Default 2 means up to 3 total attempts. Set to 0
1829
+ to disable batch-level retries.
1830
+ retry_delay: Seconds to wait between batch retry passes.
1831
+
1832
+ # Output parameters:
1833
+ filename: Optional CSV filename to save combined results (required if safety=True)
1834
+ save_directory: Optional directory for saved files
1835
+ progress_callback: Optional callback(response_idx, model_name, success, total, completed)
1836
+
1837
+ # Auto-category detection parameters (used when categories="auto"):
1838
+ max_categories: Maximum number of categories to discover (default 12)
1839
+ categories_per_chunk: Categories to extract per data chunk (default 10)
1840
+ divisions: Number of chunks to divide data into (default 10)
1841
+ research_question: Optional research context for category discovery
1842
+
1843
+ # PDF-specific parameters (only used when input_data contains PDFs):
1844
+ pdf_mode: How to process PDF pages. Options:
1845
+ - "image": Render pages as images (best for visual elements)
1846
+ - "text": Extract text only (faster/cheaper for text-heavy docs)
1847
+ - "both": Send both text and image (most comprehensive)
1848
+ pdf_dpi: Resolution for PDF to image conversion (default 150)
1849
+ input_description: Description of what the PDF documents contain
1850
+
1851
+ # Ollama parameters:
1852
+ auto_download: If True, automatically download missing Ollama models
1853
+
1854
+ Returns:
1855
+ - Single model: Returns DataFrame directly (backward compatible with multi_class)
1856
+ - Multiple models: Returns dict containing:
1857
+ - "combined": DataFrame with all per-model and consensus columns
1858
+ - "consensus": DataFrame with only consensus results
1859
+ - "<model_name>": Individual DataFrame for each model
1860
+
1861
+ DataFrame columns:
1862
+ - input_data: Text string OR page label (e.g., "document_p1")
1863
+ - category_N_<model>: Per-model results (0/1)
1864
+ - category_N_consensus: Majority vote result
1865
+ - category_N_agreement: Model agreement score
1866
+ - processing_status: "success", "partial", "error", "skipped"
1867
+ - failed_models: List of failed models
1868
+
1869
+ Additional columns (PDF mode only):
1870
+ - pdf_path: Source PDF file path
1871
+ - page_index: Page number (0-indexed)
1872
+
1873
+ Examples:
1874
+ # TEXT MODE - Single model (returns DataFrame directly):
1875
+ df = multi_class_ensemble(
1876
+ input_data=["I moved for a new job"],
1877
+ categories=["Employment", "Family", "Housing"],
1878
+ model="gpt-4o",
1879
+ api_key="sk-...",
1880
+ survey_question="Why did you move?",
1881
+ )
1882
+
1883
+ # TEXT MODE - Ensemble with multiple models (returns dict):
1884
+ results = multi_class_ensemble(
1885
+ input_data=["I moved for a new job"],
1886
+ categories=["Employment", "Family", "Housing"],
1887
+ models=[
1888
+ ("gpt-4o", "openai", "sk-..."),
1889
+ ("claude-sonnet-4-5-20250929", "anthropic", "sk-ant-..."),
1890
+ ],
1891
+ survey_question="Why did you move?",
1892
+ )
1893
+ combined_df = results["combined"]
1894
+
1895
+ # PDF MODE - Single model (auto-detected from .pdf paths):
1896
+ df = multi_class_ensemble(
1897
+ input_data="reports/", # Directory of PDFs
1898
+ categories=["Has Chart", "Has Table", "Financial Summary"],
1899
+ model="gpt-4o",
1900
+ api_key="sk-...",
1901
+ pdf_mode="image",
1902
+ input_description="Financial reports with charts and tables",
1903
+ )
1904
+
1905
+ # PDF MODE - Ensemble with native PDF support:
1906
+ results = multi_class_ensemble(
1907
+ input_data=["doc1.pdf", "doc2.pdf"],
1908
+ categories=["Diagnosis", "Treatment Plan"],
1909
+ models=[
1910
+ ("gpt-4o", "openai", "sk-..."),
1911
+ ("claude-sonnet-4-5-20250929", "anthropic", "sk-ant-..."), # Native PDF
1912
+ ],
1913
+ pdf_mode="both",
1914
+ consensus_threshold=0.5,
1915
+ )
1916
+ """
1917
+ # Normalize model input to list of tuples
1918
+ models = normalize_model_input(model, api_key, provider, models)
1919
+
1920
+ # Validate safety parameter
1921
+ if safety and filename is None:
1922
+ raise TypeError(
1923
+ "filename is required when using safety=True. "
1924
+ "Please provide a filename to save incremental results to."
1925
+ )
1926
+
1927
+ # Handle categories="auto" - auto-detect categories from the data
1928
+ if categories == "auto":
1929
+ from .main import extract
1930
+
1931
+ # Detect input type to choose the right extraction path
1932
+ detected_type = _detect_input_type(input_data)
1933
+
1934
+ if detected_type == "text" and survey_question == "":
1935
+ raise TypeError(
1936
+ "survey_question is required when using categories='auto' with text input. "
1937
+ "Please provide the survey question you are analyzing."
1938
+ )
1939
+
1940
+ # Use first model for category discovery
1941
+ first_entry = models[0]
1942
+ first_model, first_provider, first_api_key = first_entry[0], first_entry[1], first_entry[2]
1943
+ detected_provider = detect_provider(first_model, first_provider)
1944
+
1945
+ print(f"Auto-detecting categories using {first_model} (input type: {detected_type})...")
1946
+ auto_result = extract(
1947
+ input_data=input_data,
1948
+ api_key=first_api_key,
1949
+ input_type=detected_type,
1950
+ description=survey_question or input_description,
1951
+ max_categories=max_categories,
1952
+ categories_per_chunk=categories_per_chunk,
1953
+ divisions=divisions,
1954
+ user_model=first_model,
1955
+ model_source=detected_provider,
1956
+ research_question=research_question,
1957
+ mode=pdf_mode,
1958
+ )
1959
+ categories = auto_result["top_categories"]
1960
+ print(f"Discovered {len(categories)} categories: {categories}")
1961
+
1962
+ if not isinstance(categories, list) or len(categories) == 0:
1963
+ raise ValueError("categories must be a non-empty list of category names, or 'auto'")
1964
+
1965
+ # Prepare model configurations
1966
+ print(f"Validating {len(models)} model configuration(s)...")
1967
+ model_configs = prepare_model_configs(models, auto_download=auto_download)
1968
+
1969
+ # Print model info
1970
+ print(f"\nModels to use:")
1971
+ for cfg in model_configs:
1972
+ print(f" - {cfg['model']} ({cfg['provider']}) -> column suffix: {cfg['sanitized_name']}")
1973
+
1974
+ # =============================================================================
1975
+ # DETECT INPUT TYPE: Text vs PDF vs Image
1976
+ # =============================================================================
1977
+ input_type = _detect_input_type(input_data)
1978
+ print(f"\nInput type detected: {input_type.upper()}")
1979
+
1980
+ # Initialize processing variables
1981
+ items_to_process = []
1982
+ is_pdf_mode = (input_type == 'pdf')
1983
+ is_image_mode = (input_type == 'image')
1984
+
1985
+ # Build example JSON for visual modes (PDF/image)
1986
+ category_dict = {str(i+1): "0" for i in range(len(categories))}
1987
+ example_json = json.dumps(category_dict, indent=2)
1988
+
1989
+ if is_image_mode:
1990
+ # =================================================================
1991
+ # IMAGE MODE: Load images
1992
+ # =================================================================
1993
+ print(f"Loading images...")
1994
+
1995
+ # Load image files
1996
+ image_files = _load_image_files(input_data)
1997
+
1998
+ if not image_files:
1999
+ raise ValueError("No images found in the provided input.")
2000
+
2001
+ print(f"Total images to process: {len(image_files)}")
2002
+
2003
+ # items_to_process is list of (image_path, image_label) tuples
2004
+ items_to_process = [
2005
+ (img_path, os.path.splitext(os.path.basename(img_path))[0])
2006
+ for img_path in image_files
2007
+ ]
2008
+
2009
+ elif is_pdf_mode:
2010
+ # =================================================================
2011
+ # PDF MODE: Load PDFs and extract all pages
2012
+ # =================================================================
2013
+ # Validate pdf_mode parameter
2014
+ pdf_mode = pdf_mode.lower()
2015
+ if pdf_mode not in {"image", "text", "both"}:
2016
+ raise ValueError(f"pdf_mode must be 'image', 'text', or 'both', got: {pdf_mode}")
2017
+
2018
+ print(f"PDF processing mode: {pdf_mode}")
2019
+
2020
+ # Load PDF files
2021
+ pdf_files = _load_pdf_files(input_data)
2022
+
2023
+ # Extract all pages from all PDFs
2024
+ all_pages = []
2025
+ for pdf_path in pdf_files:
2026
+ pages = _get_pdf_pages(pdf_path)
2027
+ all_pages.extend(pages)
2028
+
2029
+ if not all_pages:
2030
+ raise ValueError("No pages found in the provided PDF files.")
2031
+
2032
+ print(f"Total pages to process: {len(all_pages)}")
2033
+
2034
+ # items_to_process is list of (pdf_path, page_index, page_label) tuples
2035
+ items_to_process = all_pages
2036
+
2037
+ else:
2038
+ # =================================================================
2039
+ # TEXT MODE: input_data is the items to process
2040
+ # =================================================================
2041
+ items_to_process = input_data
2042
+
2043
+ # Auto-resolve parallel mode: sequential for all-local (Ollama), parallel otherwise
2044
+ if parallel is None:
2045
+ all_local = all(cfg["provider"] == "ollama" for cfg in model_configs)
2046
+ parallel = not all_local
2047
+
2048
+ # Set max workers
2049
+ effective_workers = max_workers or min(len(models), 8)
2050
+ if parallel:
2051
+ print(f"\nParallel workers: {effective_workers}")
2052
+ else:
2053
+ print(f"\nSequential mode (models run one at a time per row)")
2054
+
2055
+ # Warn about CoVe cost with ensemble
2056
+ if chain_of_verification:
2057
+ print("\n[Chain of Verification enabled]")
2058
+ print(" - ~4x API calls per response per model")
2059
+ if len(models) > 1:
2060
+ print(" - WARNING: CoVe with ensemble is expensive. Consider single-model mode.")
2061
+
2062
+ # Build shared prompt components
2063
+ categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
2064
+
2065
+ examples = [example1, example2, example3, example4, example5, example6]
2066
+ examples_text = "\n".join(
2067
+ f"Example {i}: {ex}" for i, ex in enumerate(examples, 1) if ex is not None
2068
+ )
2069
+
2070
+ survey_question_context = f"Context: {survey_question}." if survey_question else ""
2071
+
2072
+ # Print categories
2073
+ print(f"\nCategories to classify ({len(categories)} total):")
2074
+ for i, cat in enumerate(categories, 1):
2075
+ print(f" {i}. {cat}")
2076
+ print()
2077
+
2078
+ # Get step-back insights per model (if enabled)
2079
+ stepback_insights = {}
2080
+ if step_back_prompt:
2081
+ stepback_insights = gather_stepback_insights(model_configs, survey_question, creativity)
2082
+
2083
+ # Build JSON schemas per provider
2084
+ json_schemas = prepare_json_schemas(model_configs, categories, use_json_schema)
2085
+
2086
+ # Build original task prompt for CoVe (if enabled)
2087
+ cove_original_task = ""
2088
+ if chain_of_verification:
2089
+ if multi_label:
2090
+ cove_categorize = "into the following categories"
2091
+ cove_json = 'Provide your answer in JSON format where the category number is the key and "1" if present, "0" if not.'
2092
+ else:
2093
+ cove_categorize = "into the single most appropriate category"
2094
+ cove_json = 'Provide your answer in JSON format where the category number is the key. Assign "1" to the single best matching category and "0" to all others.'
2095
+ cove_original_task = f"""{survey_question_context}
2096
+ Categorize text responses {cove_categorize}:
2097
+ {categories_str}
2098
+ {cove_json}"""
2099
+
2100
+ # Formatter fallback helper (only active when formatter_state is provided)
2101
def _try_formatter_fallback(json_result, raw_reply, chunk_categories=None):
    """Repair an invalid classification JSON with the local formatter model.

    The fallback is a no-op unless ``formatter_state`` (from the enclosing
    scope) is set. When it is set, ``json_result`` is validated against the
    expected number of categories; only if that validation fails is the raw
    model reply handed to ``run_formatter`` and the repaired output validated
    in turn.

    Args:
        json_result: JSON already extracted from the model reply.
        raw_reply: The unprocessed reply text passed to the formatter.
        chunk_categories: Category list of the current chunk when called from
            chunked classification. When ``None``, the enclosing function's
            full ``categories`` list is used, so the formatter sees the
            correct category names and count.

    Returns:
        The repaired JSON when the formatter yields a valid result, otherwise
        the original ``json_result`` unchanged.
    """
    # Formatter not configured: the fallback is disabled.
    if not formatter_state:
        return json_result

    if chunk_categories is None:
        active_categories = categories
    else:
        active_categories = chunk_categories
    expected_count = len(active_categories)

    # Nothing to repair when the extracted JSON already validates.
    already_valid, _ = validate_classification_json(json_result, expected_count)
    if already_valid:
        return json_result

    # Imported lazily so the formatter is only loaded when actually needed.
    from ._formatter import run_formatter

    repaired_text = run_formatter(
        raw_reply,
        active_categories,
        formatter_state["model"],
        formatter_state["tokenizer"],
        formatter_state["device"],
    )
    repaired_json = extract_json(repaired_text)
    repaired_valid, _ = validate_classification_json(repaired_json, expected_count)

    # Keep the original output if the formatter could not produce valid JSON.
    return repaired_json if repaired_valid else json_result
2128
+
2129
+ # When chunking is active, extend categories with unified "Other" so
2130
+ # aggregate_results and build_output_dataframes create the column.
2131
+ # Save original list for the actual chunked LLM calls (which add their
2132
+ # own per-chunk "Other" internally).
2133
+ _original_categories = categories
2134
+ _add_unified_other = False
2135
+ if categories_per_call is not None:
2136
+ from ._category_analysis import has_other_category as _has_other
2137
+ if not _has_other(categories):
2138
+ categories = list(categories) + ["Other"]
2139
+ _add_unified_other = True
2140
+
2141
+ # Classification function for single model + single item (text, PDF page, or image)
2142
def classify_single(cfg: dict, item) -> tuple:
    """
    Run one model over one item and report the outcome.

    The item is routed to one of four paths: chunked classification (when
    ``categories_per_call`` is set), PDF page, image, or plain text. Any
    exception raised while classifying is converted into an error result
    instead of propagating.

    Args:
        cfg: Model configuration dict. Keys read here: "provider", "api_key",
            "model", "sanitized_name", "creativity", "use_two_step".
        item: One of:
            - Text string (text mode)
            - Tuple (pdf_path, page_index, page_label) (PDF mode)
            - Tuple (image_path, image_label) (image mode)

    Returns:
        tuple: (sanitized model name, json_result, error). On failure
        json_result is the sentinel '{"1":"e"}' and error holds the message.

    NOTE(review): indentation was lost in the rendering this block was
    reconstructed from. The Chain-of-Verification step is placed inside the
    non-two-step text branch, consistent with the "CoVe not supported for
    Ollama two-step" remark -- confirm against the packaged file.
    """
    error_json = '{"1":"e"}'
    thinking_providers = (
        "google", "openai", "anthropic", "huggingface", "huggingface-together",
    )

    # A per-model creativity setting overrides the ensemble-wide value.
    model_creativity = cfg["creativity"]
    effective_creativity = creativity if model_creativity is None else model_creativity

    item_is_tuple = isinstance(item, tuple)

    # Identifier used only by the forced-failure test hook below.
    if is_image_mode and item_is_tuple and len(item) == 2:
        item_identifier = item[1]  # image_label
    elif is_pdf_mode and item_is_tuple and len(item) == 3:
        item_identifier = item[2]  # page_label
    else:
        item_identifier = str(item) if item else ""

    def _failure(message):
        """Build the error result tuple for this model."""
        return (cfg["sanitized_name"], error_json, message)

    def _new_client():
        """Create the LLM client for this model configuration."""
        return UnifiedLLMClient(
            provider=cfg["provider"],
            api_key=cfg["api_key"],
            model=cfg["model"]
        )

    def _complete(llm, prompt_messages):
        """Call the unified client; thinking budget only for listed providers."""
        return llm.complete(
            messages=prompt_messages,
            json_schema=json_schemas[cfg["model"]],
            creativity=effective_creativity,
            thinking_budget=thinking_budget if cfg["provider"] in thinking_providers else None,
            max_retries=max_retries,
        )

    def _complete_multimodal(llm, prompt_messages):
        """Call the model for PDF/image content; Google uses its own request path."""
        if cfg["provider"] != "google":
            return _complete(llm, prompt_messages)
        return _call_google_multimodal(
            client=llm,
            messages=prompt_messages,
            json_schema=json_schemas[cfg["model"]],
            creativity=effective_creativity,
            thinking_budget=thinking_budget,
            max_retries=max_retries,
        )

    def _reply_to_json(model_reply, call_error):
        """Turn a (reply, error) pair into classification JSON."""
        if call_error:
            return error_json
        return _try_formatter_fallback(extract_json(model_reply), model_reply)

    # Test hook for debugging batch retries (only active when _TEST_FORCE_FAILURE = True)
    if _test_should_force_failure(item_identifier, cfg["sanitized_name"]):
        return _failure("TEST: Forced first-attempt failure")

    # =================================================================
    # CHUNKED PATH: split categories into smaller per-call chunks
    # =================================================================
    if categories_per_call is not None:
        from ._chunked import run_chunked_classification

        try:
            chunked_json, chunked_error = run_chunked_classification(
                client=_new_client(),
                cfg=cfg,
                item=item,
                # The chunk runner receives the list without the unified "Other".
                categories=_original_categories,
                categories_str=categories_str,
                example_json=example_json,
                json_schema=json_schemas[cfg["model"]],
                cove_original_task=cove_original_task,
                effective_creativity=effective_creativity,
                use_json_schema=use_json_schema,
                survey_question=survey_question,
                survey_question_context=survey_question_context,
                examples_text=examples_text,
                chain_of_thought=chain_of_thought,
                context_prompt=context_prompt,
                step_back_prompt=step_back_prompt,
                stepback_insights=stepback_insights,
                chain_of_verification=chain_of_verification,
                thinking_budget=thinking_budget,
                max_retries=max_retries,
                multi_label=multi_label,
                categories_per_call=categories_per_call,
                add_unified_other=_add_unified_other,
                formatter_fallback_fn=_try_formatter_fallback,
                is_pdf_mode=is_pdf_mode,
                is_image_mode=is_image_mode,
                pdf_mode=pdf_mode,
                pdf_dpi=pdf_dpi,
                input_description=input_description,
                build_text_prompt_fn=build_text_classification_prompt,
                build_pdf_prompt_fn=build_pdf_classification_prompt,
                build_image_prompt_fn=build_image_classification_prompt,
                google_multimodal_fn=_call_google_multimodal,
                prepare_page_data_fn=_prepare_page_data,
                prepare_image_data_fn=_prepare_image_data,
                build_cove_prompts_fn=build_cove_prompts,
                run_cove_fn=run_chain_of_verification,
            )
        except Exception as exc:
            return _failure(str(exc))
        return (cfg["sanitized_name"], chunked_json, chunked_error)

    try:
        client = _new_client()

        if is_pdf_mode and item_is_tuple:
            # =============================================================
            # PDF MODE: classify a single page
            # =============================================================
            pdf_path, page_index, page_label = item

            # Page payload depends on both the PDF mode and the provider.
            page_data = _prepare_page_data(
                pdf_path=pdf_path,
                page_index=page_index,
                page_label=page_label,
                pdf_mode=pdf_mode,
                provider=cfg["provider"],
                pdf_dpi=pdf_dpi,
            )
            if page_data.get("error"):
                return _failure(page_data["error"])

            messages = build_pdf_classification_prompt(
                page_data=page_data,
                categories_str=categories_str,
                input_description=input_description,
                provider=cfg["provider"],
                pdf_mode=pdf_mode,
                chain_of_thought=chain_of_thought,
                context_prompt=context_prompt,
                step_back_prompt=step_back_prompt,
                stepback_insights=stepback_insights,
                model_name=cfg["model"],
                example_json=example_json,
                multi_label=multi_label,
            )
            reply, error = _complete_multimodal(client, messages)
            json_result = _reply_to_json(reply, error)

            # Note: CoVe for PDF mode is not yet implemented
            # (would require re-attaching PDF/image to verification prompts)

        elif is_image_mode and item_is_tuple:
            # =============================================================
            # IMAGE MODE: classify a single image
            # =============================================================
            image_path, image_label = item

            image_data = _prepare_image_data(image_path, image_label)
            if image_data.get("error"):
                return _failure(image_data["error"])

            messages = build_image_classification_prompt(
                image_data=image_data,
                categories_str=categories_str,
                input_description=input_description,
                provider=cfg["provider"],
                chain_of_thought=chain_of_thought,
                context_prompt=context_prompt,
                step_back_prompt=step_back_prompt,
                stepback_insights=stepback_insights,
                model_name=cfg["model"],
                example_json=example_json,
                multi_label=multi_label,
            )
            reply, error = _complete_multimodal(client, messages)
            json_result = _reply_to_json(reply, error)

            # Note: CoVe for image mode is not yet implemented

        else:
            # =============================================================
            # TEXT MODE: classify a text response
            # =============================================================
            response_text = item

            if cfg["use_two_step"]:  # Ollama
                json_result, error = ollama_two_step_classify(
                    client=client,
                    response_text=response_text,
                    categories=categories,
                    categories_str=categories_str,
                    survey_question=survey_question,
                    creativity=effective_creativity,
                    max_retries=max_retries,
                )
                if not error:
                    # The two-step output doubles as the formatter's raw input.
                    json_result = _try_formatter_fallback(json_result, json_result)
                # CoVe not supported for Ollama two-step (already has verification)
            else:
                messages = build_text_classification_prompt(
                    response_text=response_text,
                    categories_str=categories_str,
                    survey_question_context=survey_question_context,
                    examples_text=examples_text,
                    chain_of_thought=chain_of_thought,
                    context_prompt=context_prompt,
                    step_back_prompt=step_back_prompt,
                    stepback_insights=stepback_insights,
                    model_name=cfg["model"],
                    multi_label=multi_label,
                )
                reply, error = _complete(client, messages)
                json_result = _reply_to_json(reply, error)

                # Chain of Verification refines a successful initial answer.
                if chain_of_verification and not error:
                    step2, step3, step4 = build_cove_prompts(
                        cove_original_task, response_text
                    )
                    json_result = run_chain_of_verification(
                        client=client,
                        initial_reply=json_result,
                        step2_prompt=step2,
                        step3_prompt=step3,
                        step4_prompt=step4,
                        json_schema=json_schemas[cfg["model"]],
                        creativity=effective_creativity,
                        max_retries=max_retries,
                    )
                    json_result = _try_formatter_fallback(json_result, json_result)

        return (cfg["sanitized_name"], json_result, error)

    except Exception as exc:
        # Any failure for this (item, model) pair becomes an error result.
        return _failure(str(exc))
2424
+
2425
+ # Helper function for Google multimodal API calls
2426
def _call_google_multimodal(client, messages, json_schema, creativity, thinking_budget, max_retries):
    """
    Send PDF/image content to Google's generateContent endpoint.

    Only the last entry of ``messages`` is sent. Its content parts are
    converted to Google's ``parts`` format:
    - "text" parts are passed through as text.
    - "inline_data" parts are passed through with their mime type and data.
    - "image_url" parts holding a base64 ``data:`` URL are converted to
      inline data using the mime type declared in the URL. Image URLs that
      are not base64 data URLs are skipped, since they cannot be inlined.

    Args:
        client: Object exposing ``model`` and ``api_key`` attributes.
        messages: Chat messages; the last one is treated as the user message.
            Its "content" may be a list of part dicts or a plain string.
        json_schema: Accepted for signature parity with ``client.complete``;
            not sent to the API (only a JSON response mime type is requested).
        creativity: Temperature; omitted from the request when None.
        thinking_budget: Thinking token budget; omitted when falsy.
        max_retries: Maximum number of request attempts.

    Returns:
        tuple: (reply_text, None) on success, or (None, error_message).
    """
    import time

    import requests

    retryable_statuses = (429, 500, 502, 503, 504)
    base64_marker = ";base64,"

    # Extract the content from messages
    user_msg = messages[-1]  # Last message should be the user message
    content = user_msg.get("content", [])
    if isinstance(content, str):
        # A plain-string message is a single text part.
        content = [{"type": "text", "text": content}]

    # Build Google-format parts
    parts = []
    for part in content:
        part_type = part.get("type")
        if part_type == "text":
            parts.append({"text": part["text"]})
        elif part_type == "inline_data":
            parts.append({
                "inline_data": {
                    "mime_type": part["mime_type"],
                    "data": part["data"]
                }
            })
        elif part_type == "image_url":
            # Convert any base64 data URL (PNG, JPEG, WebP, ...) to inline_data.
            image_url = part["image_url"]["url"]
            header, marker, encoded = image_url.partition(base64_marker)
            if marker and header.startswith("data:"):
                # Drop optional parameters such as ";charset=..." from the header.
                mime_type = header[len("data:"):].split(";", 1)[0] or "image/png"
                parts.append({
                    "inline_data": {
                        "mime_type": mime_type,
                        "data": encoded
                    }
                })

    endpoint = (
        "https://generativelanguage.googleapis.com/v1beta/models/"
        f"{client.model}:generateContent"
    )
    headers = {
        "x-goog-api-key": client.api_key,
        "Content-Type": "application/json"
    }

    generation_config = {"responseMimeType": "application/json"}
    if creativity is not None:
        generation_config["temperature"] = creativity
    if thinking_budget:
        generation_config["thinkingConfig"] = {"thinkingBudget": thinking_budget}

    payload = {
        "contents": [{"parts": parts}],
        "generationConfig": generation_config,
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt >= max_retries - 1
        try:
            response = requests.post(endpoint, headers=headers, json=payload, timeout=120)
            response.raise_for_status()
            result = response.json()

            if "candidates" in result and result["candidates"]:
                reply = result["candidates"][0]["content"]["parts"][0]["text"]
                return reply, None
            return None, "No response generated"

        except requests.exceptions.HTTPError as e:
            # The response is normally attached; guard against it being absent.
            status_code = e.response.status_code if e.response is not None else None

            if status_code in retryable_statuses and not is_last_attempt:
                # Rate limits get a longer exponential backoff than server errors.
                base_delay = 10 if status_code == 429 else 2
                time.sleep(base_delay * (2 ** attempt))
            else:
                return None, f"HTTP error {status_code}: {str(e)}"

        except Exception as e:
            # Network errors and malformed responses are retried with backoff.
            if not is_last_attempt:
                time.sleep(2 * (2 ** attempt))
            else:
                return None, str(e)

    return None, "Max retries exceeded"
2509
+
2510
+ # Process all items (text responses, PDF pages, or images)
2511
+ all_results = []
2512
+ completed_calls = [0] # Mutable for closure
2513
+ total_calls = len(items_to_process) * len(model_configs)
2514
+
2515
+ # Set progress description based on mode
2516
+ if is_image_mode:
2517
+ progress_desc = "Classifying images"
2518
+ elif is_pdf_mode:
2519
+ progress_desc = "Classifying PDF pages"
2520
+ else:
2521
+ progress_desc = "Classifying responses"
2522
+
2523
+ # Disable tqdm when progress_callback is provided (e.g., for Streamlit/GUI apps)
2524
+ use_tqdm = progress_callback is None
2525
+ for idx, item in enumerate(tqdm(items_to_process, desc=progress_desc, disable=not use_tqdm)):
2526
+ skipped_nan = False
2527
+
2528
+ # Determine the display identifier and metadata for this item
2529
+ if is_image_mode and isinstance(item, tuple) and len(item) == 2:
2530
+ image_path, image_label = item
2531
+ display_id = image_label
2532
+ pdf_metadata = None
2533
+ image_metadata = {"image_path": image_path}
2534
+ elif is_pdf_mode and isinstance(item, tuple) and len(item) == 3:
2535
+ pdf_path, page_index, page_label = item
2536
+ display_id = page_label
2537
+ pdf_metadata = {"pdf_path": pdf_path, "page_index": page_index}
2538
+ image_metadata = None
2539
+ else:
2540
+ display_id = item
2541
+ pdf_metadata = None
2542
+ image_metadata = None
2543
+
2544
+ # Check for NaN (text mode only)
2545
+ if not is_pdf_mode and not is_image_mode and pd.isna(item):
2546
+ # Handle NaN - mark as skipped, bypass classification entirely
2547
+ skipped_nan = True
2548
+ model_results = {}
2549
+ # Build a clean aggregated result so downstream code doesn't
2550
+ # list every model as failed for a legitimately skipped row.
2551
+ skipped_aggregated = {
2552
+ "per_model": {},
2553
+ "consensus": {},
2554
+ "agreement": {},
2555
+ "failed_models": [],
2556
+ "missing_keys": {},
2557
+ "error": None,
2558
+ }
2559
+ else:
2560
+ # Classification across models
2561
+ model_results = {}
2562
+ if parallel:
2563
+ with ThreadPoolExecutor(max_workers=effective_workers) as executor:
2564
+ futures = {
2565
+ executor.submit(classify_single, cfg, item): cfg["sanitized_name"]
2566
+ for cfg in model_configs
2567
+ }
2568
+
2569
+ for future in as_completed(futures):
2570
+ model_name, json_result, error = future.result()
2571
+ model_results[model_name] = (json_result, error)
2572
+
2573
+ # Update progress (for multi-model detailed callbacks only)
2574
+ completed_calls[0] += 1
2575
+ else:
2576
+ for cfg in model_configs:
2577
+ model_name, json_result, error = classify_single(cfg, item)
2578
+ model_results[model_name] = (json_result, error)
2579
+ completed_calls[0] += 1
2580
+
2581
+ # Aggregate results with majority voting (skip for NaN rows)
2582
+ if skipped_nan:
2583
+ aggregated = skipped_aggregated
2584
+ else:
2585
+ aggregated = aggregate_results(
2586
+ model_results,
2587
+ categories,
2588
+ consensus_threshold,
2589
+ fail_strategy
2590
+ )
2591
+
2592
+ # Build result entry
2593
+ result_entry = {
2594
+ "response": display_id, # Page label for PDF, text for text mode
2595
+ "model_results": model_results,
2596
+ "aggregated": aggregated,
2597
+ "skipped": skipped_nan,
2598
+ }
2599
+
2600
+ # Add PDF metadata if in PDF mode
2601
+ if pdf_metadata:
2602
+ result_entry["pdf_path"] = pdf_metadata["pdf_path"]
2603
+ result_entry["page_index"] = pdf_metadata["page_index"]
2604
+
2605
+ # Add image metadata if in image mode
2606
+ if image_metadata:
2607
+ result_entry["image_path"] = image_metadata["image_path"]
2608
+
2609
+ # Store the original item for retry logic
2610
+ result_entry["_original_item"] = item
2611
+
2612
+ all_results.append(result_entry)
2613
+
2614
+ # Call progress callback with simple item-level signature
2615
+ # This is for UI integrations (like Streamlit) that want per-item updates
2616
+ if progress_callback:
2617
+ try:
2618
+ # Try simple signature: (current_idx, total, item_label)
2619
+ progress_callback(idx, len(items_to_process), display_id)
2620
+ except TypeError:
2621
+ # Fallback: callback might expect keyword args (old signature)
2622
+ pass
2623
+
2624
+ # Safety incremental save
2625
+ if safety:
2626
+ _save_partial_results(
2627
+ all_results,
2628
+ model_configs,
2629
+ categories,
2630
+ filename,
2631
+ save_directory,
2632
+ )
2633
+
2634
+ # Per-row delay to avoid rate limits when models share a provider
2635
+ if row_delay > 0 and idx < len(items_to_process) - 1:
2636
+ time.sleep(row_delay)
2637
+
2638
+ # Retry logic for failed (row, model) pairs
2639
+ if batch_retries > 0:
2640
+ num_cats = len(categories)
2641
+ expected_keys = {str(i) for i in range(1, num_cats + 1)}
2642
+
2643
+ for retry_num in range(1, batch_retries + 1):
2644
+ # Find all failed (row_idx, model_config) pairs
2645
+ failed_pairs = []
2646
+ for row_idx, result in enumerate(all_results):
2647
+ # Skip rows that were NaN inputs
2648
+ if result.get("skipped"):
2649
+ continue
2650
+ # Check each model for this row
2651
+ for cfg in model_configs:
2652
+ model_name = cfg["sanitized_name"]
2653
+ json_str, error = result["model_results"].get(model_name, (None, "Missing"))
2654
+ if error is not None:
2655
+ failed_pairs.append((row_idx, cfg))
2656
+ else:
2657
+ # Check JSON parsing AND schema validation
2658
+ try:
2659
+ parsed = json.loads(json_str)
2660
+ # At least one valid numbered key with 0/1 value
2661
+ valid_count = sum(
2662
+ 1 for k, v in parsed.items()
2663
+ if k in expected_keys and str(v).strip() in ("0", "1")
2664
+ )
2665
+ if valid_count == 0:
2666
+ failed_pairs.append((row_idx, cfg))
2667
+ except (json.JSONDecodeError, TypeError):
2668
+ failed_pairs.append((row_idx, cfg))
2669
+
2670
+ if not failed_pairs:
2671
+ break # All successful, no retries needed
2672
+
2673
+ print(f"\n[Batch retry {retry_num}/{batch_retries}] Retrying {len(failed_pairs)} failed (row, model) pairs...")
2674
+
2675
+ # Wait before retrying
2676
+ if retry_delay > 0:
2677
+ time.sleep(retry_delay)
2678
+
2679
+ # Retry failed pairs
2680
+ successes_this_round = 0
2681
+ if parallel:
2682
+ with ThreadPoolExecutor(max_workers=effective_workers) as executor:
2683
+ futures = {
2684
+ executor.submit(classify_single, cfg, all_results[row_idx]["_original_item"]): (row_idx, cfg)
2685
+ for row_idx, cfg in failed_pairs
2686
+ }
2687
+
2688
+ for future in as_completed(futures):
2689
+ row_idx, cfg = futures[future]
2690
+ model_name, json_result, error = future.result()
2691
+
2692
+ # Update the result in place
2693
+ all_results[row_idx]["model_results"][model_name] = (json_result, error)
2694
+
2695
+ if error is None:
2696
+ # Verify JSON is valid and has correct schema
2697
+ try:
2698
+ parsed = json.loads(json_result)
2699
+ valid_count = sum(
2700
+ 1 for k, v in parsed.items()
2701
+ if k in expected_keys and str(v).strip() in ("0", "1")
2702
+ )
2703
+ if valid_count > 0:
2704
+ successes_this_round += 1
2705
+ except (json.JSONDecodeError, TypeError):
2706
+ pass
2707
+ else:
2708
+ for row_idx, cfg in failed_pairs:
2709
+ model_name, json_result, error = classify_single(cfg, all_results[row_idx]["_original_item"])
2710
+ all_results[row_idx]["model_results"][model_name] = (json_result, error)
2711
+ if error is None:
2712
+ try:
2713
+ parsed = json.loads(json_result)
2714
+ valid_count = sum(
2715
+ 1 for k, v in parsed.items()
2716
+ if k in expected_keys and str(v).strip() in ("0", "1")
2717
+ )
2718
+ if valid_count > 0:
2719
+ successes_this_round += 1
2720
+ except (json.JSONDecodeError, TypeError):
2721
+ pass
2722
+
2723
+ # Recalculate aggregation for all affected rows
2724
+ affected_rows = set(row_idx for row_idx, _ in failed_pairs)
2725
+ for row_idx in affected_rows:
2726
+ all_results[row_idx]["aggregated"] = aggregate_results(
2727
+ all_results[row_idx]["model_results"],
2728
+ categories,
2729
+ consensus_threshold,
2730
+ fail_strategy
2731
+ )
2732
+
2733
+ print(f" -> {successes_this_round}/{len(failed_pairs)} pairs succeeded on retry")
2734
+
2735
+ # Safety save after retry
2736
+ if safety:
2737
+ _save_partial_results(
2738
+ all_results,
2739
+ model_configs,
2740
+ categories,
2741
+ filename,
2742
+ save_directory,
2743
+ )
2744
+
2745
+ # Early exit if ALL retries failed (likely server down or out of credits)
2746
+ if successes_this_round == 0:
2747
+ print(f" -> All retries failed. Stopping retry loop (possible server issue).")
2748
+ break
2749
+
2750
+ # Print summary of filled/failed values
2751
+ total_filled = 0
2752
+ filled_by_model = {}
2753
+ total_schema_failures = 0
2754
+ failures_by_model = {}
2755
+ for result in all_results:
2756
+ if result.get("skipped"):
2757
+ continue
2758
+ agg = result["aggregated"]
2759
+ # Count missing keys that were filled with 0
2760
+ for model_name, count in agg.get("missing_keys", {}).items():
2761
+ total_filled += count
2762
+ filled_by_model[model_name] = filled_by_model.get(model_name, 0) + count
2763
+ # Count models that still failed after retries
2764
+ for model_name in agg.get("failed_models", []):
2765
+ total_schema_failures += 1
2766
+ failures_by_model[model_name] = failures_by_model.get(model_name, 0) + 1
2767
+
2768
+ if total_filled > 0 or total_schema_failures > 0:
2769
+ print(f"\n--- Classification Quality Summary ---")
2770
+ if total_filled > 0:
2771
+ print(f" Missing JSON keys filled with 0: {total_filled} values")
2772
+ for model_name, count in sorted(filled_by_model.items()):
2773
+ print(f" {model_name}: {count} filled")
2774
+ if total_schema_failures > 0:
2775
+ print(f" Schema failures (after retries): {total_schema_failures}")
2776
+ for model_name, count in sorted(failures_by_model.items()):
2777
+ print(f" {model_name}: {count} failed rows")
2778
+ print()
2779
+
2780
+ # Embedding tiebreaker: resolve true ties using centroids
2781
+ if embedding_tiebreaker_state is not None and len(model_configs) > 1:
2782
+ from ._tiebreaker import resolve_ties_with_centroids
2783
+
2784
+ resolve_ties_with_centroids(
2785
+ all_results,
2786
+ categories,
2787
+ embedding_tiebreaker_state["model"],
2788
+ embedding_tiebreaker_state["threshold"],
2789
+ embedding_tiebreaker_state.get("min_centroid_size", 3),
2790
+ )
2791
+
2792
+ # Build output DataFrames
2793
+ print("Building output DataFrames...")
2794
+ return build_output_dataframes(
2795
+ all_results,
2796
+ model_configs,
2797
+ categories,
2798
+ filename,
2799
+ save_directory,
2800
+ )
2801
+
2802
+
2803
def build_output_dataframes(
    all_results: list,
    model_configs: list,
    categories: list,
    filename: str,
    save_directory: str,
) -> pd.DataFrame:
    """
    Build the output DataFrame from classification results.

    Args:
        all_results: One dict per input item. Each entry provides "response",
            "aggregated" (a dict with "per_model", "consensus", "agreement",
            "failed_models", "error" and optionally "tiebreaker_resolved"),
            an optional "skipped" flag, and optional "pdf_path" /
            "page_index" / "image_path" metadata.
        model_configs: Model config dicts; only "sanitized_name" is read.
        categories: Category list; only its length is used, columns are
            numbered 1..len(categories).
        filename: CSV filename to write; a falsy value skips saving.
        save_directory: Optional directory for the CSV (created if missing).

    Returns:
        A single DataFrame (never a dict):
        - One model: "input_data", "processing_status", optional PDF/image
          metadata columns, and "category_<i>" columns. Consensus,
          agreement, resolved-by and "failed_models" columns are dropped
          because they are redundant for a single model.
        - Several models: the combined frame with "failed_models",
          "category_<i>_<model>" columns for every model, and
          "category_<i>_consensus" / "category_<i>_agreement" (plus
          "category_<i>_resolved_by" when tiebreaker data exists).
        Per-model and consensus columns use the nullable "Int64" dtype.
    """
    model_names = [cfg["sanitized_name"] for cfg in model_configs]
    category_numbers = range(1, len(categories) + 1)

    def _to_int(value, fallback):
        """Return int(value), or fallback when the value is not convertible."""
        try:
            return int(value)
        except (ValueError, TypeError):
            return fallback

    # Optional column groups are only emitted when at least one row needs them.
    has_pdf_metadata = any(result.get("pdf_path") is not None for result in all_results)
    has_image_metadata = any(result.get("image_path") is not None for result in all_results)
    has_tiebreaker = any(
        "tiebreaker_resolved" in result.get("aggregated", {})
        for result in all_results
        if not result.get("skipped")
    )

    # Exact column names, computed once. Selecting columns by exact name
    # (rather than by substring tests such as `f"_{model_name}" in col`)
    # keeps one model's name from matching another model's columns and keeps
    # the single-model rename from mangling names like "category_1_1".
    model_columns = {
        model_name: [f"category_{i}_{model_name}" for i in category_numbers]
        for model_name in model_names
    }
    consensus_columns = [f"category_{i}_consensus" for i in category_numbers]
    agreement_columns = [f"category_{i}_agreement" for i in category_numbers]
    resolved_columns = (
        [f"category_{i}_resolved_by" for i in category_numbers] if has_tiebreaker else []
    )

    # Insertion order of this dict defines the column order of the output.
    combined_data = {
        "input_data": [],
        "processing_status": [],
        "failed_models": [],
    }
    if has_pdf_metadata:
        combined_data["pdf_path"] = []
        combined_data["page_index"] = []
    if has_image_metadata:
        combined_data["image_path"] = []
    for model_name in model_names:
        for col in model_columns[model_name]:
            combined_data[col] = []
    for consensus_col, agreement_col in zip(consensus_columns, agreement_columns):
        combined_data[consensus_col] = []
        combined_data[agreement_col] = []
    for col in resolved_columns:
        combined_data[col] = []

    # Populate one row per result.
    for result in all_results:
        aggregated = result["aggregated"]
        combined_data["input_data"].append(result["response"])

        if has_pdf_metadata:
            combined_data["pdf_path"].append(result.get("pdf_path", ""))
            combined_data["page_index"].append(result.get("page_index", ""))
        if has_image_metadata:
            combined_data["image_path"].append(result.get("image_path", ""))

        # Status precedence: skipped > error > partial > success.
        if result.get("skipped"):
            status = "skipped"
        elif aggregated["error"]:
            status = "error"
        elif aggregated["failed_models"]:
            status = "partial"
        else:
            status = "success"
        combined_data["processing_status"].append(status)
        combined_data["failed_models"].append(
            ",".join(aggregated["failed_models"]) if aggregated["failed_models"] else ""
        )

        # Per-model values: a model that failed validation gets NA for every
        # category; otherwise missing or unparsable values default to 0.
        # NOTE(review): rows flagged "skipped" whose "per_model" is empty and
        # whose "failed_models" is empty therefore get 0 (not NA) in the
        # per-model columns — confirm this is the intended encoding.
        failed_set = set(aggregated.get("failed_models", []))
        for model_name in model_names:
            columns = model_columns[model_name]
            if model_name in failed_set:
                for col in columns:
                    combined_data[col].append(None)
                continue
            parsed = aggregated["per_model"].get(model_name, {})
            for i, col in zip(category_numbers, columns):
                combined_data[col].append(_to_int(parsed.get(str(i), "0"), 0))

        # Consensus and agreement: a missing or unparsable consensus becomes NA.
        for i, consensus_col, agreement_col in zip(
            category_numbers, consensus_columns, agreement_columns
        ):
            key = str(i)
            combined_data[consensus_col].append(
                _to_int(aggregated["consensus"].get(key, None), None)
            )
            combined_data[agreement_col].append(aggregated["agreement"].get(key, None))

        # Resolved-by metadata (tiebreaker)
        if has_tiebreaker:
            tiebreaker_data = aggregated.get("tiebreaker_resolved", {})
            for i, col in zip(category_numbers, resolved_columns):
                combined_data[col].append(tiebreaker_data.get(str(i), ""))

    combined_df = pd.DataFrame(combined_data)

    # Convert per-model and consensus columns to Int64 (nullable integer) so
    # NA values do not force the columns to float/object.
    integer_columns = [col for name in model_names for col in model_columns[name]]
    integer_columns += consensus_columns
    for col in integer_columns:
        combined_df[col] = pd.to_numeric(combined_df[col], errors='coerce').astype('Int64')

    if len(model_names) == 1:
        # Single model: simplify (backward compatible with multi_class) by
        # dropping ensemble-only columns and removing the model suffix:
        # category_1_<model_name> -> category_1
        model_name = model_names[0]
        redundant = ["failed_models"] + consensus_columns + agreement_columns + resolved_columns
        rename_map = {
            f"category_{i}_{model_name}": f"category_{i}" for i in category_numbers
        }
        result_df = combined_df.drop(columns=redundant).rename(columns=rename_map)
    else:
        # Multiple models: the combined frame holds all model results + consensus.
        result_df = combined_df

    if filename:
        save_path = filename
        if save_directory:
            os.makedirs(save_directory, exist_ok=True)
            save_path = os.path.join(save_directory, filename)
        result_df.to_csv(save_path, index=False)
        print(f"\nCombined results saved to {save_path}")

    return result_df
3008
+
3009
+
3010
+ # Backward compatibility alias
3011
+ multi_class_ensemble = classify_ensemble
3012
+
3013
+
3014
+ # =============================================================================
3015
+ # Summarization helpers
3016
+ # =============================================================================
3017
+
3018
def _save_partial_summarize_results(all_results, model_configs, model_names, is_pdf_mode, filename, save_directory):
    """Write the summarization results gathered so far to a CSV file.

    Used for safety/incremental saves. Each entry of ``all_results`` becomes
    one CSV row holding the input (or the page label plus ``pdf_path`` /
    ``page_index`` for PDF page tuples), one summary column per model
    (``summary_<model>`` when more than one model is configured, otherwise
    ``summary``), and a ``processing_status`` column.

    ``model_names`` is accepted for signature compatibility and is not used.
    """
    records = []
    for entry in all_results:
        source = entry["input_data"]

        # PDF pages arrive as (pdf_path, page_index, page_label) tuples.
        if is_pdf_mode and isinstance(source, tuple) and len(source) == 3:
            pdf_path, page_index, page_label = source
            record = {
                "input_data": page_label,
                "pdf_path": pdf_path,
                "page_index": page_index,
            }
        else:
            record = {"input_data": source}

        # One summary column per model; invalid JSON yields an empty summary.
        for model_name, raw_json in entry["model_results"].items():
            is_valid, summary_text = extract_summary_from_json(raw_json)
            column = f"summary_{model_name}" if len(model_configs) > 1 else "summary"
            record[column] = summary_text if is_valid else ""

        # No recorded errors -> success. With errors, the row is "partial"
        # when at least one model still produced summary text, else "error".
        if not entry["errors"]:
            status = "success"
        elif any(
            extract_summary_from_json(raw_json)[1]
            for raw_json in entry["model_results"].values()
        ):
            status = "partial"
        else:
            status = "error"
        record["processing_status"] = status

        records.append(record)

    frame = pd.DataFrame(records)
    if save_directory:
        target = os.path.join(save_directory, filename)
        os.makedirs(save_directory, exist_ok=True)
    else:
        target = filename
    frame.to_csv(target, index=False)
3054
+
3055
+
3056
+ # =============================================================================
3057
+ # Summarization Ensemble Function
3058
+ # =============================================================================
3059
+
3060
+ def summarize_ensemble(
3061
+ input_data,
3062
+ api_key: str = None,
3063
+ input_description: str = "",
3064
+ summary_instructions: str = "",
3065
+ max_length: int = None,
3066
+ focus: str = None,
3067
+ user_model: str = "gpt-4o",
3068
+ model_source: str = "auto",
3069
+ pdf_mode: str = "image",
3070
+ pdf_dpi: int = 150,
3071
+ creativity: float = None,
3072
+ chain_of_thought: bool = False,
3073
+ context_prompt: bool = False,
3074
+ step_back_prompt: bool = False,
3075
+ max_retries: int = 5,
3076
+ batch_retries: int = 2,
3077
+ retry_delay: float = 1.0,
3078
+ row_delay: float = 0.0,
3079
+ fail_strategy: str = "partial",
3080
+ safety: bool = False,
3081
+ filename: str = None,
3082
+ save_directory: str = None,
3083
+ thinking_budget: int = 0,
3084
+ progress_callback: Optional[Callable] = None,
3085
+ # Multi-model parameters
3086
+ models: list = None,
3087
+ max_workers: int = None,
3088
+ parallel: bool = None,
3089
+ auto_download: bool = False,
3090
+ ) -> pd.DataFrame:
3091
+ """
3092
+ Summarize text or PDF inputs using LLMs with optional multi-model ensemble.
3093
+
3094
+ Supports single-model and multi-model modes. In multi-model mode,
3095
+ summaries from all models are collected and synthesized into a consensus
3096
+ summary using an LLM. Input type is auto-detected from the data.
3097
+
3098
+ Args:
3099
+ input_data: Data to summarize. Can be:
3100
+ - Text: list of strings, pandas Series, or single string
3101
+ - PDF: directory path, single PDF path, or list of PDF paths
3102
+ api_key: API key for single-model mode
3103
+ input_description: Description of what the content contains (provides context)
3104
+ summary_instructions: Specific summarization instructions (e.g., "bullet points")
3105
+ max_length: Maximum summary length in words
3106
+ focus: What to focus on (e.g., "main arguments", "emotional content")
3107
+ user_model: Model to use (default "gpt-4o")
3108
+ model_source: Provider - "auto", "openai", "anthropic", "google", etc.
3109
+ pdf_mode: PDF processing mode (only used for PDF input):
3110
+ - "image" (default): Render pages as images
3111
+ - "text": Extract text only
3112
+ - "both": Send both image and extracted text
3113
+ pdf_dpi: DPI for PDF page rendering (default 150)
3114
+ creativity: Temperature setting (None uses provider default)
3115
+ chain_of_thought: Enable step-by-step reasoning (default True)
3116
+ context_prompt: Add expert context prefix
3117
+ step_back_prompt: Enable step-back prompting
3118
+ max_retries: Max retries per API call
3119
+ batch_retries: Number of batch retry passes for failed items
3120
+ retry_delay: Delay between retries in seconds
3121
+ row_delay: Delay in seconds between processing each row (default 0.0)
3122
+ fail_strategy: How to handle failures - "partial" (default) or "strict"
3123
+ thinking_budget: Token budget for extended thinking/reasoning (default 0)
3124
+ safety: Save progress after each item (requires filename)
3125
+ filename: Output CSV filename
3126
+ save_directory: Directory to save results
3127
+ progress_callback: Optional callback for progress updates
3128
+ models: For multi-model mode, list of (model, provider, api_key) tuples
3129
+ max_workers: Maximum parallel workers (default: min(len(models), 8))
3130
+ parallel: Controls concurrent vs sequential execution (None=auto-detect)
3131
+ auto_download: Auto-download missing Ollama models (default False)
3132
+
3133
+ Returns:
3134
+ DataFrame with columns:
3135
+ - input_data: Original text or page label (for PDFs)
3136
+ - summary: Generated summary (or consensus summary for multi-model)
3137
+ - summary_<model>: Per-model summaries (multi-model only)
3138
+ - processing_status: "success", "error", "skipped"
3139
+ - failed_models: Comma-separated list (multi-model only)
3140
+ - pdf_path: Path to source PDF (PDF mode only)
3141
+ - page_index: Page number, 0-indexed (PDF mode only)
3142
+
3143
+ Examples:
3144
+ >>> import cat_stack as cat
3145
+ >>>
3146
+ >>> # Single model text summarization
3147
+ >>> results = cat.summarize(
3148
+ ... input_data=df['responses'],
3149
+ ... description="Customer feedback",
3150
+ ... api_key=api_key
3151
+ ... )
3152
+ >>>
3153
+ >>> # PDF summarization (auto-detected)
3154
+ >>> results = cat.summarize(
3155
+ ... input_data="/path/to/pdfs/",
3156
+ ... description="Research papers",
3157
+ ... mode="image",
3158
+ ... api_key=api_key
3159
+ ... )
3160
+ >>>
3161
+ >>> # Multi-model with synthesis
3162
+ >>> results = cat.summarize(
3163
+ ... input_data=df['responses'],
3164
+ ... models=[
3165
+ ... ("gpt-4o", "openai", "sk-..."),
3166
+ ... ("claude-sonnet-4-5-20250929", "anthropic", "sk-ant-..."),
3167
+ ... ],
3168
+ ... )
3169
+ """
3170
+ # Safety validation
3171
+ if safety and filename is None:
3172
+ raise TypeError("filename is required when using safety=True.")
3173
+
3174
+ # Detect input type: Text vs PDF
3175
+ input_type = _detect_input_type(input_data)
3176
+ is_pdf_mode = (input_type == 'pdf')
3177
+
3178
+ if is_pdf_mode:
3179
+ # Validate pdf_mode parameter
3180
+ pdf_mode = pdf_mode.lower()
3181
+ if pdf_mode not in {"image", "text", "both"}:
3182
+ raise ValueError(f"pdf_mode must be 'image', 'text', or 'both', got: {pdf_mode}")
3183
+
3184
+ print(f"\nInput type detected: PDF")
3185
+ print(f"PDF processing mode: {pdf_mode}")
3186
+
3187
+ # Load PDF files
3188
+ pdf_files = _load_pdf_files(input_data)
3189
+
3190
+ # Extract all pages from all PDFs
3191
+ all_pages = []
3192
+ for pdf_path in pdf_files:
3193
+ pages = _get_pdf_pages(pdf_path)
3194
+ all_pages.extend(pages)
3195
+
3196
+ if not all_pages:
3197
+ raise ValueError("No pages found in the provided PDF files.")
3198
+
3199
+ items_to_process = all_pages
3200
+ print(f"Total PDF pages to summarize: {len(items_to_process)}")
3201
+ else:
3202
+ # TEXT MODE: Normalize input to list
3203
+ print(f"\nInput type detected: TEXT")
3204
+ if isinstance(input_data, str):
3205
+ input_data = [input_data]
3206
+ elif hasattr(input_data, 'tolist'):
3207
+ input_data = input_data.tolist()
3208
+ else:
3209
+ input_data = list(input_data)
3210
+
3211
+ items_to_process = input_data
3212
+ print(f"Total texts to summarize: {len(items_to_process)}")
3213
+
3214
+ # Normalize model input to list of tuples
3215
+ models = normalize_model_input(user_model, api_key, model_source, models)
3216
+
3217
+ # Validate and prepare model configs
3218
+ print(f"Validating {len(models)} model configuration(s)...")
3219
+ model_configs = prepare_model_configs(models, auto_download=auto_download)
3220
+
3221
+ if not model_configs:
3222
+ raise ValueError("No valid model configurations found.")
3223
+
3224
+ model_names = [cfg["sanitized_name"] for cfg in model_configs]
3225
+ print(f"\nModels to use:")
3226
+ for cfg in model_configs:
3227
+ print(f" - {cfg['model']} ({cfg['provider']}) -> column suffix: {cfg['sanitized_name']}")
3228
+
3229
+ # Build JSON schemas per provider (for summary output)
3230
+ json_schemas = {}
3231
+ for cfg in model_configs:
3232
+ provider = cfg["provider"]
3233
+ include_additional = provider != "google"
3234
+ json_schemas[cfg["sanitized_name"]] = build_summary_json_schema(include_additional)
3235
+
3236
+ # Example JSON for prompt
3237
+ example_json = '{"summary": "Your summary here"}'
3238
+
3239
+ # Gather step-back insights if enabled
3240
+ stepback_insights = {}
3241
+ if step_back_prompt:
3242
+ print("\nGathering step-back insights...")
3243
+ stepback_insights = gather_stepback_insights(
3244
+ model_configs=model_configs,
3245
+ context=input_description or "text summarization",
3246
+ question=f"What are the key factors to consider when summarizing text{f' with a focus on {focus}' if focus else ''}?"
3247
+ )
3248
+
3249
+ # Initialize results storage
3250
+ all_results = [] # List of dicts, one per input item
3251
+ failed_pairs = [] # List of (idx, model_name) pairs that failed
3252
+
3253
+ # Define the summarization function for a single item
3254
+ def summarize_single_item(item, idx, cfg):
3255
+ """Summarize a single text item or PDF page with a single model."""
3256
+ model_name = cfg["sanitized_name"]
3257
+
3258
+ # Determine if this is a PDF page or text
3259
+ if is_pdf_mode and isinstance(item, tuple) and len(item) == 3:
3260
+ # PDF mode: item is (pdf_path, page_index, page_label)
3261
+ pdf_path, page_index, page_label = item
3262
+
3263
+ try:
3264
+ # Prepare page data based on mode and provider
3265
+ page_data = _prepare_page_data(
3266
+ pdf_path=pdf_path,
3267
+ page_index=page_index,
3268
+ page_label=page_label,
3269
+ pdf_mode=pdf_mode,
3270
+ provider=cfg["provider"],
3271
+ pdf_dpi=pdf_dpi,
3272
+ )
3273
+
3274
+ # Check for extraction errors
3275
+ if page_data.get("error"):
3276
+ return (model_name, '{"summary": ""}', page_data["error"])
3277
+
3278
+ # Build PDF summarization prompt
3279
+ messages = build_pdf_summarization_prompt(
3280
+ page_data=page_data,
3281
+ input_description=input_description,
3282
+ summary_instructions=summary_instructions,
3283
+ max_length=max_length,
3284
+ focus=focus,
3285
+ provider=cfg["provider"],
3286
+ pdf_mode=pdf_mode,
3287
+ chain_of_thought=chain_of_thought,
3288
+ context_prompt=context_prompt,
3289
+ step_back_prompt=step_back_prompt,
3290
+ stepback_insights=stepback_insights,
3291
+ model_name=model_name,
3292
+ )
3293
+
3294
+ # Create client and make API call
3295
+ client = UnifiedLLMClient(
3296
+ provider=cfg["provider"],
3297
+ api_key=cfg["api_key"],
3298
+ model=cfg["model"],
3299
+ )
3300
+
3301
+ json_schema = json_schemas[model_name]
3302
+
3303
+ # Resolve thinking_budget for this provider
3304
+ effective_thinking = thinking_budget if cfg["provider"] in ("google", "openai", "anthropic", "huggingface", "huggingface-together") else None
3305
+
3306
+ # Handle Google multimodal differently
3307
+ if cfg["provider"] == "google" and pdf_mode != "text":
3308
+ response = _call_google_multimodal(
3309
+ client=client,
3310
+ messages=messages,
3311
+ json_schema=json_schema,
3312
+ creativity=creativity,
3313
+ thinking_budget=effective_thinking or 0,
3314
+ max_retries=max_retries,
3315
+ )
3316
+ else:
3317
+ response, _err = client.complete(
3318
+ messages=messages,
3319
+ json_schema=json_schema,
3320
+ creativity=creativity,
3321
+ thinking_budget=effective_thinking,
3322
+ max_retries=max_retries,
3323
+ )
3324
+
3325
+ # Extract JSON from response
3326
+ json_str = extract_json(response)
3327
+
3328
+ return (model_name, json_str, None)
3329
+
3330
+ except Exception as e:
3331
+ error_msg = str(e)
3332
+ return (model_name, '{"summary": ""}', error_msg)
3333
+
3334
+ else:
3335
+ # TEXT MODE: Original text handling
3336
+ # Skip empty/null items
3337
+ if item is None or (isinstance(item, str) and not item.strip()) or pd.isna(item):
3338
+ return (model_name, '{"summary": ""}', "skipped")
3339
+
3340
+ try:
3341
+ # Build the prompt
3342
+ messages = build_text_summarization_prompt(
3343
+ response_text=str(item),
3344
+ input_description=input_description,
3345
+ summary_instructions=summary_instructions,
3346
+ max_length=max_length,
3347
+ focus=focus,
3348
+ chain_of_thought=chain_of_thought,
3349
+ context_prompt=context_prompt,
3350
+ step_back_prompt=step_back_prompt,
3351
+ stepback_insights=stepback_insights,
3352
+ model_name=model_name,
3353
+ )
3354
+
3355
+ # Create client and make API call
3356
+ client = UnifiedLLMClient(
3357
+ provider=cfg["provider"],
3358
+ api_key=cfg["api_key"],
3359
+ model=cfg["model"],
3360
+ )
3361
+
3362
+ json_schema = json_schemas[model_name]
3363
+
3364
+ # Resolve thinking_budget for this provider
3365
+ effective_thinking = thinking_budget if cfg["provider"] in ("google", "openai", "anthropic", "huggingface", "huggingface-together") else None
3366
+
3367
+ response, _err = client.complete(
3368
+ messages=messages,
3369
+ json_schema=json_schema,
3370
+ creativity=creativity,
3371
+ thinking_budget=effective_thinking,
3372
+ max_retries=max_retries,
3373
+ )
3374
+
3375
+ # Extract JSON from response
3376
+ json_str = extract_json(response)
3377
+
3378
+ return (model_name, json_str, None)
3379
+
3380
+ except Exception as e:
3381
+ error_msg = str(e)
3382
+ return (model_name, '{"summary": ""}', error_msg)
3383
+
3384
+ # Process all items
3385
+ progress_desc = "Summarizing PDF pages" if is_pdf_mode else "Summarizing texts"
3386
+ print(f"\n{progress_desc}...")
3387
+
3388
+ # Auto-resolve parallel mode: sequential for all-local (Ollama), parallel otherwise
3389
+ if parallel is None:
3390
+ all_local = all(cfg["provider"] == "ollama" for cfg in model_configs)
3391
+ parallel = not all_local
3392
+
3393
+ # Determine number of workers
3394
+ effective_workers = max_workers or min(len(model_configs), 8)
3395
+
3396
+ # Progress tracking
3397
+ total_items = len(items_to_process)
3398
+
3399
+ for idx, item in enumerate(tqdm(items_to_process, desc=progress_desc)):
3400
+ item_results = {}
3401
+ item_errors = {}
3402
+
3403
+ # Process with each model (in parallel if enabled, sequential otherwise)
3404
+ if parallel and len(model_configs) > 1:
3405
+ with ThreadPoolExecutor(max_workers=effective_workers) as executor:
3406
+ futures = {
3407
+ executor.submit(summarize_single_item, item, idx, cfg): cfg["sanitized_name"]
3408
+ for cfg in model_configs
3409
+ }
3410
+
3411
+ for future in as_completed(futures):
3412
+ model_name, json_result, error = future.result()
3413
+ item_results[model_name] = json_result
3414
+ if error and error != "skipped":
3415
+ item_errors[model_name] = error
3416
+ failed_pairs.append((idx, model_name))
3417
+ else:
3418
+ for cfg in model_configs:
3419
+ model_name, json_result, error = summarize_single_item(item, idx, cfg)
3420
+ item_results[model_name] = json_result
3421
+ if error and error != "skipped":
3422
+ item_errors[model_name] = error
3423
+ failed_pairs.append((idx, model_name))
3424
+
3425
+ # Store results for this item
3426
+ result_entry = {
3427
+ "idx": idx,
3428
+ "input_data": item,
3429
+ "model_results": item_results,
3430
+ "errors": item_errors,
3431
+ }
3432
+ all_results.append(result_entry)
3433
+
3434
+ # Safety: incremental save after each item
3435
+ if safety:
3436
+ _save_partial_summarize_results(all_results, model_configs, model_names, is_pdf_mode, filename, save_directory)
3437
+
3438
+ # Progress callback
3439
+ if progress_callback:
3440
+ progress_callback(idx + 1, total_items)
3441
+
3442
+ # Row delay
3443
+ if row_delay > 0 and idx < len(items_to_process) - 1:
3444
+ time.sleep(row_delay)
3445
+
3446
+ # Batch retries for failed pairs
3447
+ for retry_pass in range(batch_retries):
3448
+ if not failed_pairs:
3449
+ break
3450
+
3451
+ print(f"\n[Batch retry {retry_pass + 1}/{batch_retries}] Retrying {len(failed_pairs)} failed (row, model) pairs...")
3452
+ time.sleep(retry_delay)
3453
+
3454
+ retry_success = 0
3455
+ still_failed = []
3456
+
3457
+ for idx, model_name in failed_pairs:
3458
+ cfg = next((c for c in model_configs if c["sanitized_name"] == model_name), None)
3459
+ if not cfg:
3460
+ continue
3461
+
3462
+ item = items_to_process[idx]
3463
+ model_name_result, json_result, error = summarize_single_item(item, idx, cfg)
3464
+
3465
+ if error and error != "skipped":
3466
+ still_failed.append((idx, model_name))
3467
+ else:
3468
+ # Update the stored result
3469
+ all_results[idx]["model_results"][model_name] = json_result
3470
+ if model_name in all_results[idx]["errors"]:
3471
+ del all_results[idx]["errors"][model_name]
3472
+ retry_success += 1
3473
+
3474
+ failed_pairs = still_failed
3475
+ print(f" -> {retry_success}/{len(failed_pairs) + retry_success} pairs succeeded on retry")
3476
+
3477
+ # Safety: save after each batch retry pass
3478
+ if safety:
3479
+ _save_partial_summarize_results(all_results, model_configs, model_names, is_pdf_mode, filename, save_directory)
3480
+
3481
+ # Build output DataFrame
3482
+ print("\nBuilding output DataFrame...")
3483
+
3484
+ rows = []
3485
+ for entry in all_results:
3486
+ # Handle PDF mode: extract metadata from tuple
3487
+ item = entry["input_data"]
3488
+ if is_pdf_mode and isinstance(item, tuple) and len(item) == 3:
3489
+ pdf_path, page_index, page_label = item
3490
+ row = {
3491
+ "input_data": page_label,
3492
+ "pdf_path": pdf_path,
3493
+ "page_index": page_index,
3494
+ }
3495
+ original_text_for_synthesis = page_label # Use page label for synthesis context
3496
+ else:
3497
+ row = {"input_data": item}
3498
+ original_text_for_synthesis = item
3499
+
3500
+ # Extract summaries from each model
3501
+ summaries = {}
3502
+ has_errors = bool(entry["errors"])
3503
+ for model_name, json_str in entry["model_results"].items():
3504
+ is_valid, summary_text = extract_summary_from_json(json_str)
3505
+ if is_valid:
3506
+ summaries[model_name] = summary_text
3507
+ else:
3508
+ summaries[model_name] = ""
3509
+
3510
+ # fail_strategy="strict": blank out results if any model failed
3511
+ if fail_strategy == "strict" and has_errors:
3512
+ summaries = {k: "" for k in summaries}
3513
+
3514
+ # For multi-model: synthesize consensus
3515
+ if len(model_configs) > 1:
3516
+ # Add individual model summaries
3517
+ for model_name in model_names:
3518
+ row[f"summary_{model_name}"] = summaries.get(model_name, "")
3519
+
3520
+ # Synthesize consensus summary using the first successful model
3521
+ valid_summaries = {k: v for k, v in summaries.items() if v}
3522
+ if valid_summaries:
3523
+ # Use the first model config for synthesis
3524
+ synthesis_cfg = model_configs[0]
3525
+ consensus = _synthesize_summaries(
3526
+ summaries=valid_summaries,
3527
+ original_text=str(original_text_for_synthesis),
3528
+ synthesis_config=synthesis_cfg,
3529
+ max_retries=max_retries,
3530
+ )
3531
+ row["summary"] = consensus
3532
+ else:
3533
+ row["summary"] = ""
3534
+
3535
+ # Track failed models
3536
+ row["failed_models"] = ",".join(entry["errors"].keys()) if entry["errors"] else ""
3537
+
3538
+ else:
3539
+ # Single model: just use the summary directly
3540
+ model_name = model_names[0]
3541
+ row["summary"] = summaries.get(model_name, "")
3542
+
3543
+ # Processing status
3544
+ if all(not s for s in summaries.values()):
3545
+ # For PDF mode, check if it's a valid tuple (never skip PDFs)
3546
+ if is_pdf_mode:
3547
+ row["processing_status"] = "error"
3548
+ elif item is None or (isinstance(item, str) and not item.strip()):
3549
+ row["processing_status"] = "skipped"
3550
+ else:
3551
+ row["processing_status"] = "error"
3552
+ elif any(not s for s in summaries.values()):
3553
+ row["processing_status"] = "partial"
3554
+ else:
3555
+ row["processing_status"] = "success"
3556
+
3557
+ rows.append(row)
3558
+
3559
+ df = pd.DataFrame(rows)
3560
+
3561
+ # Save to file if requested
3562
+ if filename:
3563
+ save_path = os.path.join(save_directory, filename) if save_directory else filename
3564
+ df.to_csv(save_path, index=False)
3565
+ print(f"\nResults saved to: {save_path}")
3566
+
3567
+ return df
3568
+
3569
+
3570
+ def _synthesize_summaries(
3571
+ summaries: dict,
3572
+ original_text: str,
3573
+ synthesis_config: dict,
3574
+ max_retries: int = 3,
3575
+ ) -> str:
3576
+ """
3577
+ Synthesize multiple model summaries into one consensus summary.
3578
+
3579
+ Args:
3580
+ summaries: Dict of {model_name: summary_text}
3581
+ original_text: The original text that was summarized
3582
+ synthesis_config: Model config to use for synthesis
3583
+ max_retries: Max retries for synthesis call
3584
+
3585
+ Returns:
3586
+ Synthesized consensus summary string
3587
+ """
3588
+ if len(summaries) == 1:
3589
+ return list(summaries.values())[0]
3590
+
3591
+ # Build synthesis prompt
3592
+ summaries_text = "\n".join([
3593
+ f"- {model}: \"{summary}\""
3594
+ for model, summary in summaries.items()
3595
+ ])
3596
+
3597
+ # Truncate original text if too long
3598
+ max_original_len = 500
3599
+ original_display = original_text[:max_original_len]
3600
+ if len(original_text) > max_original_len:
3601
+ original_display += "..."
3602
+
3603
+ synthesis_prompt = f"""You are synthesizing multiple AI-generated summaries of the same text into one optimal summary.
3604
+
3605
+ Original text: "{original_display}"
3606
+
3607
+ Summaries from different models:
3608
+ {summaries_text}
3609
+
3610
+ Create a single, comprehensive summary that captures the best insights from all summaries.
3611
+ Resolve any contradictions by focusing on accuracy.
3612
+
3613
+ Provide your answer in JSON format: {{"summary": "your synthesized summary"}}"""
3614
+
3615
+ try:
3616
+ client = UnifiedLLMClient(
3617
+ provider=synthesis_config["provider"],
3618
+ api_key=synthesis_config["api_key"],
3619
+ model=synthesis_config["model"],
3620
+ )
3621
+
3622
+ json_schema = build_summary_json_schema(
3623
+ include_additional_properties=synthesis_config["provider"] != "google"
3624
+ )
3625
+
3626
+ response, _err = client.complete(
3627
+ messages=[{"role": "user", "content": synthesis_prompt}],
3628
+ json_schema=json_schema,
3629
+ creativity=0.3, # Low creativity for synthesis
3630
+ max_retries=max_retries,
3631
+ )
3632
+
3633
+ json_str = extract_json(response)
3634
+ is_valid, summary = extract_summary_from_json(json_str)
3635
+
3636
+ if is_valid:
3637
+ return summary
3638
+ else:
3639
+ # Fallback: return the longest summary
3640
+ return max(summaries.values(), key=len)
3641
+
3642
+ except Exception as e:
3643
+ print(f"Warning: Synthesis failed ({e}), using longest summary as fallback")
3644
+ return max(summaries.values(), key=len)