cat-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cat_stack/__about__.py +10 -0
- cat_stack/__init__.py +128 -0
- cat_stack/_batch.py +1388 -0
- cat_stack/_category_analysis.py +348 -0
- cat_stack/_chunked.py +424 -0
- cat_stack/_embeddings.py +189 -0
- cat_stack/_formatter.py +169 -0
- cat_stack/_providers.py +1048 -0
- cat_stack/_tiebreaker.py +277 -0
- cat_stack/_utils.py +512 -0
- cat_stack/_web_fetch.py +194 -0
- cat_stack/calls/CoVe.py +287 -0
- cat_stack/calls/__init__.py +25 -0
- cat_stack/calls/all_calls.py +622 -0
- cat_stack/calls/image_CoVe.py +386 -0
- cat_stack/calls/image_stepback.py +210 -0
- cat_stack/calls/pdf_CoVe.py +386 -0
- cat_stack/calls/pdf_stepback.py +210 -0
- cat_stack/calls/stepback.py +180 -0
- cat_stack/calls/top_n.py +217 -0
- cat_stack/classify.py +682 -0
- cat_stack/explore.py +111 -0
- cat_stack/extract.py +218 -0
- cat_stack/image_functions.py +2078 -0
- cat_stack/images/circle.png +0 -0
- cat_stack/images/cube.png +0 -0
- cat_stack/images/diamond.png +0 -0
- cat_stack/images/overlapping_pentagons.png +0 -0
- cat_stack/images/rectangles.png +0 -0
- cat_stack/model_reference_list.py +94 -0
- cat_stack/pdf_functions.py +2087 -0
- cat_stack/summarize.py +290 -0
- cat_stack/text_functions.py +1358 -0
- cat_stack/text_functions_ensemble.py +3644 -0
- cat_stack-0.1.0.dist-info/METADATA +150 -0
- cat_stack-0.1.0.dist-info/RECORD +38 -0
- cat_stack-0.1.0.dist-info/WHEEL +4 -0
- cat_stack-0.1.0.dist-info/licenses/LICENSE +672 -0
|
@@ -0,0 +1,3644 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Ensemble text classification functions for CatLLM.
|
|
3
|
+
|
|
4
|
+
This module provides multi-model ensemble classification using parallel execution.
|
|
5
|
+
Multiple LLM models are called simultaneously and results are combined using
|
|
6
|
+
majority voting for more robust classification.
|
|
7
|
+
|
|
8
|
+
MODULE STRUCTURE:
|
|
9
|
+
=================
|
|
10
|
+
Helper Functions:
|
|
11
|
+
- sanitize_model_name(): Convert model names to column-safe suffixes
|
|
12
|
+
- prepare_model_configs(): Validate model configurations, check Ollama
|
|
13
|
+
|
|
14
|
+
Shared Utilities (reusable for image/pdf classification):
|
|
15
|
+
- normalize_model_input(): Convert various input formats to list of tuples
|
|
16
|
+
- gather_stepback_insights(): Get step-back insights from each model
|
|
17
|
+
- prepare_json_schemas(): Build JSON schemas per provider
|
|
18
|
+
- aggregate_results(): Majority voting consensus
|
|
19
|
+
- build_output_dataframes(): Build output DataFrames
|
|
20
|
+
|
|
21
|
+
Text-Specific (swap for image classification):
|
|
22
|
+
- build_text_classification_prompt(): Build prompt for text classification
|
|
23
|
+
|
|
24
|
+
Main Function:
|
|
25
|
+
- multi_class_ensemble(): Main entry point
|
|
26
|
+
- Supports single model (returns DataFrame) or multiple models (returns dict)
|
|
27
|
+
- Supports categories="auto" for auto-detection
|
|
28
|
+
|
|
29
|
+
Chain of Verification (CoVe):
|
|
30
|
+
============================
|
|
31
|
+
CoVe is a 4-step prompting strategy to improve classification accuracy:
|
|
32
|
+
Step 1: Initial classification (existing classify_single)
|
|
33
|
+
Step 2: Generate verification questions about the classification
|
|
34
|
+
Step 3: Answer each verification question (up to 5 questions)
|
|
35
|
+
Step 4: Final corrected classification based on Q&A
|
|
36
|
+
|
|
37
|
+
Usage: Set chain_of_verification=True in multi_class_ensemble()
|
|
38
|
+
Note: CoVe requires ~4x API calls per response. Not recommended for ensemble mode
|
|
39
|
+
due to cost, but supported for single-model usage.
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
# Public API of this module (names resolved at import-* time).
__all__ = [
    "classify_ensemble",
    "multi_class_ensemble",
    "summarize_ensemble",
]
|
|
43
|
+
|
|
44
|
+
import json
|
|
45
|
+
import os
|
|
46
|
+
import re
|
|
47
|
+
import time
|
|
48
|
+
import pandas as pd
|
|
49
|
+
from pathlib import Path
|
|
50
|
+
from tqdm import tqdm
|
|
51
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
52
|
+
from typing import Optional, Callable, Union
|
|
53
|
+
|
|
54
|
+
from .text_functions import (
|
|
55
|
+
UnifiedLLMClient,
|
|
56
|
+
detect_provider,
|
|
57
|
+
build_json_schema,
|
|
58
|
+
extract_json,
|
|
59
|
+
validate_classification_json,
|
|
60
|
+
ollama_two_step_classify,
|
|
61
|
+
check_ollama_running,
|
|
62
|
+
check_ollama_model,
|
|
63
|
+
pull_ollama_model,
|
|
64
|
+
check_claude_cli_available,
|
|
65
|
+
_get_stepback_insight,
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# PDF utility imports
|
|
69
|
+
from .pdf_functions import (
|
|
70
|
+
_load_pdf_files,
|
|
71
|
+
_get_pdf_pages,
|
|
72
|
+
_extract_page_as_image_bytes,
|
|
73
|
+
_extract_page_as_pdf_bytes,
|
|
74
|
+
_extract_page_text,
|
|
75
|
+
_encode_bytes_to_base64,
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
# Image utility imports
|
|
79
|
+
from .image_functions import (
|
|
80
|
+
_load_image_files,
|
|
81
|
+
_encode_image,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# =============================================================================
|
|
86
|
+
# Consensus Threshold Helper
|
|
87
|
+
# =============================================================================
|
|
88
|
+
|
|
89
|
+
def _resolve_consensus_threshold(threshold: Union[str, float, int]) -> float:
    """
    Convert a consensus threshold to a numeric value.

    Accepts both string aliases and numeric values:
      - String values (case-insensitive, surrounding whitespace ignored):
        "majority" (0.5), "two-thirds" / "two_thirds" / "twothirds" (exactly 2/3),
        "unanimous" (1.0).
      - Numeric values: any number between 0 and 1 inclusive.

    "two-thirds" resolves to exactly 2/3 (not a rounded 0.67). Callers compare
    ``positive_rate >= threshold``. With a rounded 0.67, 2 of 3 models (2/3 ≈ 0.6667)
    would fall short and the threshold would silently behave like "unanimous".

    Args:
        threshold: Either a string alias or a numeric value in [0, 1].

    Returns:
        float: The resolved threshold value.

    Raises:
        ValueError: If the string alias is unknown or the number is outside [0, 1]
            (NaN is also rejected).

    Examples:
        >>> _resolve_consensus_threshold("majority")
        0.5
        >>> _resolve_consensus_threshold("two-thirds")
        0.6666666666666666
        >>> _resolve_consensus_threshold(0.75)
        0.75
    """
    if isinstance(threshold, str):
        two_thirds = 2 / 3
        mapping = {
            "majority": 0.5,
            "two-thirds": two_thirds,
            "two_thirds": two_thirds,
            "twothirds": two_thirds,
            "unanimous": 1.0,
        }
        resolved = mapping.get(threshold.lower().strip())
        if resolved is None:
            valid_options = ", ".join(f'"{k}"' for k in ["majority", "two-thirds", "unanimous"])
            raise ValueError(
                f"Invalid consensus_threshold string: '{threshold}'. "
                f"Valid options: {valid_options}, or a numeric value between 0 and 1."
            )
        return resolved

    value = float(threshold)
    # The chained comparison is False for NaN, so NaN is rejected here too.
    if not 0 <= value <= 1:
        raise ValueError(f"consensus_threshold must be between 0 and 1, got {value}")
    return value
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# =============================================================================
|
|
135
|
+
# Test Utilities (for debugging batch retry logic)
|
|
136
|
+
# =============================================================================
|
|
137
|
+
|
|
138
|
+
# Retry-testing switches. Flip _TEST_FORCE_FAILURE to True to make the first
# attempt for each (response prefix, model) pair fail; later attempts succeed.
_TEST_FORCE_FAILURE = False
# (first 50 chars of response, model name) pairs that have already been failed once.
_test_attempted_pairs = set()


def _test_should_force_failure(response_text: str, model_name: str) -> bool:
    """
    Test hook: decide whether this (response, model) attempt should be failed.

    Returns False whenever ``_TEST_FORCE_FAILURE`` is off. When it is on, the
    first call for a given (first 50 characters of ``response_text``,
    ``model_name``) pair records the pair, prints a notice, and returns True.
    Every later call for the same pair returns False, so retries go through.
    Responses sharing the same 50-character prefix are treated as one pair.

    Usage: set ``_TEST_FORCE_FAILURE = True`` at module level to enable.
    """
    if _TEST_FORCE_FAILURE:
        attempt_key = (response_text[:50], model_name)
        if attempt_key in _test_attempted_pairs:
            return False
        _test_attempted_pairs.add(attempt_key)
        print(f" [TEST] Forcing error for: {model_name} on '{response_text[:30]}...'")
        return True
    return False


def _test_reset():
    """Forget every recorded (response, model) pair so the next run starts fresh."""
    global _test_attempted_pairs
    _test_attempted_pairs = set()
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# =============================================================================
|
|
169
|
+
# Input Type Detection
|
|
170
|
+
# =============================================================================
|
|
171
|
+
|
|
172
|
+
# Image file extensions recognised during input-type auto-detection
# (taken from image_functions.py; keep the two lists in sync).
_IMAGE_EXTENSIONS = {
    # bitmap / raster formats
    '.png', '.apng',
    '.jpg', '.jpeg', '.jpe', '.jfif', '.pjpeg', '.pjp',
    '.gif', '.webp', '.avif', '.bmp',
    '.tif', '.tiff', '.heif', '.heic',
    '.ico', '.psd',
    # vector formats
    '.svg', '.svgz',
}


def _detect_input_type(input_data) -> str:
    """
    Classify the input as text strings, PDF files, or image files.

    Rules, applied in order:
      - A single str/Path with extension ``.pdf`` -> 'pdf'.
      - A single str/Path with an extension in ``_IMAGE_EXTENSIONS`` -> 'image'.
      - A single str/Path naming an existing directory -> the type of the first
        entry (alphabetical order) with a PDF or image extension. The result is
        'pdf' when no entry matches or the directory cannot be listed, for
        backward compatibility.
      - Any other single str/Path -> 'text'.
      - An iterable (list, Series, ...) -> decided by its first item that is
        neither None nor NA: a PDF/image extension as above, otherwise 'text'.
        An iterable with no such item -> 'text'.
      - Anything else -> 'text'.

    Extension matching is case-insensitive and looks only at names. No file is
    opened.

    Args:
        input_data: Text strings, PDF paths, image paths, or a directory path.

    Returns:
        'text', 'pdf', or 'image'.
    """
    def _kind_for_name(name: str):
        # Map a name to 'pdf' / 'image' by its extension, or None for neither.
        suffix = os.path.splitext(name)[1].lower()
        if suffix == '.pdf':
            return 'pdf'
        if suffix in _IMAGE_EXTENSIONS:
            return 'image'
        return None

    if isinstance(input_data, (str, Path)):
        path_str = str(input_data)
        kind = _kind_for_name(path_str)
        if kind is not None:
            return kind
        if not os.path.isdir(path_str):
            return 'text'
        try:
            entries = sorted(os.listdir(path_str))
        except OSError:
            entries = []
        for entry in entries:
            kind = _kind_for_name(entry)
            if kind is not None:
                return kind
        # No recognisable file: default to PDF for directories (backward compatibility)
        return 'pdf'

    if not hasattr(input_data, '__iter__'):
        return 'text'

    first_value = next(
        (item for item in input_data if item is not None and not pd.isna(item)),
        None,
    )
    if first_value is None:
        return 'text'
    # Only the first non-null item is inspected; no PDF/image extension means text.
    return _kind_for_name(str(first_value)) or 'text'
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# =============================================================================
|
|
243
|
+
# Helper Functions
|
|
244
|
+
# =============================================================================
|
|
245
|
+
|
|
246
|
+
def sanitize_model_name(model: str, max_length: int = 40) -> str:
    """
    Convert a model name to a valid column suffix.

    Each run of non-alphanumeric characters becomes a single underscore. The
    result is lower-cased, stripped of leading/trailing underscores, and
    truncated to ``max_length`` characters. Any underscore exposed at the end
    by the truncation is also removed.

    Examples:
        gpt-4o -> gpt_4o
        claude-sonnet-4-5-20250929 -> claude_sonnet_4_5_20250929
        llama3.2:latest -> llama3_2_latest

    Args:
        model: The model name string.
        max_length: Maximum length of the result (default 40). ``None``
            disables truncation.

    Returns:
        Sanitized string suitable for use in column names.
    """
    sanitized = re.sub(r'[^a-zA-Z0-9]+', '_', model).lower()
    # Strip before truncating (so the budget isn't spent on underscores) and
    # again after, so a cut in the middle of "a_b" doesn't leave "a_".
    return sanitized.strip('_')[:max_length].rstrip('_')
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _format_creativity_suffix(creativity) -> str:
    """
    Build the column-name suffix that encodes a model's creativity setting.

    The value is scaled by 100 and rounded to an integer: 0 -> '_t0',
    0.25 -> '_t25', 1.0 -> '_t100'. ``None`` (no per-model override) gives
    '_tauto'.
    """
    if creativity is not None:
        percent = int(round(creativity * 100))
        return "_t" + str(percent)
    return "_tauto"
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def prepare_model_configs(models: list, auto_download: bool = False) -> list:
    """
    Validate and prepare model configurations.

    For each entry the provider is resolved with ``detect_provider``. Ollama
    models require a running Ollama server (checked once) and a locally
    available model (pulled via ``pull_ollama_model`` if missing). The
    "claude-code" provider requires the claude CLI. Every other provider
    requires a truthy API key.

    Args:
        models: List of tuples. Each tuple can be:
            - (model, provider, api_key) — 3 elements
            - (model, provider, api_key, options) — 4 elements, where options is a
              dict with per-model overrides (e.g. {"creativity": 0.5})
        auto_download: If True, automatically download missing Ollama models.

    Returns:
        List of config dicts with keys "model", "provider", "api_key",
        "use_two_step" (True for Ollama), "sanitized_name" and "creativity"
        (per-model override or None). In ensemble mode (more than one entry)
        "sanitized_name" carries a creativity suffix. Names are made unique by
        appending "_1", "_2", ..., skipping any name already taken.

    Raises:
        ValueError: If an entry has the wrong length, or an API key is missing
            for a provider that needs one.
        ConnectionError: If Ollama is not running, or the claude CLI is missing,
            when needed.
        RuntimeError: If an Ollama model is not available and could not be pulled.
    """
    rule = "=" * 60
    configs = []
    ollama_checked = False
    is_ensemble = len(models) > 1

    for entry in models:
        if len(entry) == 4:
            model, provider, api_key, options = entry
        elif len(entry) == 3:
            model, provider, api_key = entry
            options = {}
        else:
            raise ValueError(
                f"Each model entry must be a 3-tuple (model, provider, api_key) "
                f"or 4-tuple (model, provider, api_key, options), got {len(entry)} elements"
            )

        detected_provider = detect_provider(model, provider)

        if detected_provider == "ollama":
            # Check Ollama running (once)
            if not ollama_checked:
                if not check_ollama_running():
                    # Explicit "+" everywhere: with implicit literal concatenation
                    # the title would be glued to "=" before "*60" applies.
                    raise ConnectionError(
                        "\n" + rule + "\n"
                        + " OLLAMA NOT RUNNING\n"
                        + rule + "\n\n"
                        + "Ollama must be running to use local models.\n\n"
                        + "To start Ollama:\n"
                        + " macOS: Open the Ollama app, or run 'ollama serve'\n"
                        + " Linux: Run 'ollama serve' in terminal\n"
                        + " Windows: Open the Ollama app\n\n"
                        + rule
                    )
                ollama_checked = True

            # Check model availability
            if not check_ollama_model(model):
                if not pull_ollama_model(model, auto_confirm=auto_download):
                    raise RuntimeError(
                        f"Ollama model '{model}' not available. "
                        f"Run: ollama pull {model}"
                    )
        elif detected_provider == "claude-code":
            if not check_claude_cli_available():
                raise ConnectionError(
                    "\n" + rule + "\n"
                    + " CLAUDE CLI NOT FOUND\n"
                    + rule + "\n\n"
                    + "The claude CLI must be installed to use claude-code as a provider.\n"
                    + "Install: https://docs.anthropic.com/en/docs/claude-code\n"
                    + rule
                )
        elif not api_key:
            # Cloud providers need an API key
            raise ValueError(
                f"API key required for provider '{detected_provider}' (model: {model})"
            )

        # Per-model creativity override (None means use global)
        per_model_creativity = options.get("creativity", None) if options else None

        # Build sanitized column name
        base_name = sanitize_model_name(model)
        if is_ensemble:
            base_name += _format_creativity_suffix(per_model_creativity)

        configs.append({
            "model": model,
            "provider": detected_provider,
            "api_key": api_key,
            "use_two_step": (detected_provider == "ollama"),
            "sanitized_name": base_name,
            "creativity": per_model_creativity,
        })

    # Make sanitized names unique. A generated "<name>_<n>" is checked against
    # every name already taken, so it cannot collide with another model whose
    # own sanitized name happens to be "<name>_<n>".
    taken = set()
    suffix_counts = {}
    for cfg in configs:
        base = cfg["sanitized_name"]
        candidate = base
        while candidate in taken:
            suffix_counts[base] = suffix_counts.get(base, 0) + 1
            candidate = f"{base}_{suffix_counts[base]}"
        taken.add(candidate)
        cfg["sanitized_name"] = candidate

    return configs
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def aggregate_results(
    model_results: dict,
    categories: list,
    consensus_threshold: Union[str, float],
    fail_strategy: str,
    formatter_state: dict = None,
) -> dict:
    """
    Aggregate results from multiple models using threshold-based voting.

    Each successful model's reply is parsed and cleaned. Only keys that are
    1-based category numbers ("1".."N") with a value of 0 or 1 are kept; the
    value may be a string or a number, and surrounding whitespace is ignored.
    A model counts as failed when any of the following holds:
      - it reported an error;
      - its reply is not valid JSON;
      - its reply is not a JSON object;
      - its reply has no valid category key.
    Category keys a model omits count as a "0" vote.

    Args:
        model_results: Dict mapping model name to (json_str, error). A truthy
            ``error`` marks the model as failed.
        categories: List of category names (only its length is used here).
        consensus_threshold: Threshold on the fraction of positive votes. Can be:
            - "majority": 50% agreement (default)
            - "two-thirds": two-thirds agreement
            - "unanimous": 100% agreement
            - float: Custom threshold between 0 and 1
        fail_strategy: "strict" returns an empty result if any model failed;
            any other value (e.g. "partial") votes over the successful models.
        formatter_state: Accepted for interface compatibility; not used here.

    Returns:
        Dict with keys:
          - "per_model": model -> cleaned {key: "0"/"1"}
          - "consensus": key -> "0"/"1", "1" when the positive-vote fraction is
            >= threshold
          - "agreement": key -> fraction of successful models matching the
            consensus, rounded to 3 decimals
          - "failed_models": list of failed model names
          - "missing_keys": model -> number of category keys it omitted
          - "error": None, or a message when nothing could be aggregated

    Raises:
        ValueError: If ``consensus_threshold`` is invalid.
    """
    # Resolve string thresholds to numeric values
    threshold = _resolve_consensus_threshold(consensus_threshold)
    successful = {}
    failed_models = []

    num_cats = len(categories)
    category_keys = [str(i) for i in range(1, num_cats + 1)]
    expected_keys = set(category_keys)

    for model_name, (json_str, error) in model_results.items():
        if error:
            failed_models.append(model_name)
            continue
        try:
            parsed = json.loads(json_str)
        except (json.JSONDecodeError, TypeError):
            # TypeError: no error was reported but json_str is None / not text.
            failed_models.append(model_name)
            continue
        if not isinstance(parsed, dict):
            # e.g. a JSON list, number or null: there are no category keys to read.
            failed_models.append(model_name)
            continue
        # Models may return only present categories (e.g. {"3": "1"}). Keys
        # with invalid values are dropped so they default to 0 downstream.
        cleaned = {}
        for k, v in parsed.items():
            if k in expected_keys:
                value = str(v).strip()
                if value in ("0", "1"):
                    cleaned[k] = value
        if cleaned:
            successful[model_name] = cleaned
        else:
            failed_models.append(model_name)

    # Handle failure strategies
    if fail_strategy == "strict" and failed_models:
        return {
            "per_model": {},
            "consensus": {},
            "agreement": {},
            "failed_models": failed_models,
            "missing_keys": {},
            "error": f"Models failed (strict mode): {failed_models}",
        }

    if not successful:
        return {
            "per_model": {},
            "consensus": {},
            "agreement": {},
            "missing_keys": {},
            "failed_models": failed_models,
            "error": "All models failed",
        }

    # Calculate consensus via threshold vote
    consensus = {}
    agreement_scores = {}
    num_successful = len(successful)  # > 0, guaranteed by the early return above
    # Track missing keys per model (keys in range but absent from response)
    missing_keys_count = {}

    for key in category_keys:
        votes = []
        for model_name, parsed in successful.items():
            if key not in parsed:
                missing_keys_count[model_name] = missing_keys_count.get(model_name, 0) + 1
            # Cleaned values are always "0" or "1", so int() cannot fail here.
            votes.append(int(parsed.get(key, "0")))

        positive_rate = sum(votes) / num_successful
        consensus_val = "1" if positive_rate >= threshold else "0"
        consensus[key] = consensus_val
        # Agreement = fraction of models that match the consensus decision
        consensus_int = int(consensus_val)
        matching = sum(1 for v in votes if v == consensus_int)
        agreement_scores[key] = round(matching / num_successful, 3)

    return {
        "per_model": successful,
        "consensus": consensus,
        "agreement": agreement_scores,
        "failed_models": failed_models,
        "missing_keys": missing_keys_count,
        "error": None,
    }
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
# =============================================================================
|
|
505
|
+
# Shared Utility Functions
|
|
506
|
+
# =============================================================================
|
|
507
|
+
|
|
508
|
+
def normalize_model_input(
    model: str = None,
    api_key: str = None,
    provider: str = "auto",
    models: list = None,
) -> list:
    """
    Coerce the accepted model-specification styles into a list of entries.

    Accepted inputs:
      - Single model:   model="gpt-4o", api_key="sk-...", provider="auto"
      - Single tuple:   models=("gpt-4o", "openai", "sk-...") — a tuple of
                        3 or 4 items whose first item is a string
      - List of tuples: models=[("gpt-4o", "openai", "sk-..."), ...]

    When ``models`` is given it takes precedence over ``model``. Any ``models``
    value that is not a single tuple entry is returned unchanged (same object).

    Args:
        model: Single model name.
        api_key: API key for the single model.
        provider: Provider for the single model (default "auto").
        models: List of tuples, or a single tuple.

    Returns:
        List of tuples: [(model, provider, api_key), ...]

    Raises:
        ValueError: If neither ``model`` nor ``models`` is given.
    """
    if models is None:
        if model is None:
            raise ValueError(
                "No model specified. Use either:\n"
                " - Single model: model='gpt-4o', api_key='sk-...'\n"
                " - Multiple models: models=[('gpt-4o', 'openai', 'sk-...'), ...]"
            )
        return [(model, provider, api_key)]

    looks_like_single_entry = (
        isinstance(models, tuple)
        and len(models) in (3, 4)
        and isinstance(models[0], str)
    )
    return [models] if looks_like_single_entry else models
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def gather_stepback_insights(
    model_configs: list,
    survey_question: str,
    creativity: float = None,
) -> dict:
    """
    Collect a step-back insight from each non-Ollama model.

    Step-back prompting first asks about the underlying factors behind answers
    to the survey question, before classification. Ollama configs are skipped.
    A model's insight is stored only when ``_get_stepback_insight`` reports it
    was added. A config's own "creativity" value, when not None, overrides the
    ``creativity`` argument.

    Args:
        model_configs: List of model configuration dicts (keys used: "provider",
            "api_key", "model", optional "creativity").
        survey_question: The survey question being analyzed.
        creativity: Default temperature setting.

    Returns:
        Dict mapping model name to (stepback_question, insight) tuples.

    Raises:
        TypeError: If ``survey_question`` is empty or missing.
    """
    if not survey_question:
        raise TypeError(
            "survey_question is required when using step_back_prompt. "
            "Please provide the survey question you are analyzing."
        )

    question = (
        "What are the underlying factors or dimensions that explain how people "
        f'typically answer "{survey_question}"?'
    )

    print("Getting step-back insights for each model...")
    insights = {}

    for cfg in model_configs:
        if cfg["provider"] == "ollama":
            continue
        model_creativity = cfg.get("creativity")
        temperature = creativity if model_creativity is None else model_creativity
        insight, added = _get_stepback_insight(
            cfg["provider"], question, cfg["api_key"], cfg["model"], temperature
        )
        if added:
            insights[cfg["model"]] = (question, insight)

    return insights
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def prepare_json_schemas(
    model_configs: list,
    categories: list,
    use_json_schema: bool = True,
) -> dict:
    """
    Build the JSON schema (or None) to send with each model's requests.

    With ``use_json_schema`` enabled, each model gets
    ``build_json_schema(categories, include_additional)``. ``include_additional``
    is False only for the "google" provider. Otherwise every model maps to None.

    Args:
        model_configs: List of model configuration dicts (keys used: "model",
            and "provider" when schemas are enabled).
        categories: List of category names.
        use_json_schema: Whether to use strict JSON schema.

    Returns:
        Dict mapping model name to JSON schema (or None). If the same model
        name appears more than once, the last config wins.
    """
    schemas_by_model = {}
    for cfg in model_configs:
        schema = None
        if use_json_schema:
            # Google doesn't support additionalProperties
            schema = build_json_schema(categories, cfg["provider"] != "google")
        schemas_by_model[cfg["model"]] = schema
    return schemas_by_model
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
# =============================================================================
|
|
625
|
+
# Chain of Verification (CoVe) Functions
|
|
626
|
+
# =============================================================================
|
|
627
|
+
|
|
628
|
+
def build_cove_prompts(original_task: str, response_text: str) -> tuple:
    """
    Build the prompt templates for steps 2-4 of Chain of Verification.

    The returned templates still contain placeholders for the caller to fill:
      - step 2: "<<INITIAL_REPLY>>"
      - step 3: "<<QUESTION>>"
      - step 4: "<<INITIAL_REPLY>>" and "<<VERIFICATION_QA>>"
    ``original_task`` is embedded in steps 2 and 4, and ``response_text`` in
    step 3.

    Args:
        original_task: The original classification prompt/task.
        response_text: The text response being classified.

    Returns:
        Tuple of (step2_prompt, step3_prompt, step4_prompt).
    """
    question_prompt = "\n".join([
        "You provided this initial categorization:",
        "<<INITIAL_REPLY>>",
        "",
        f"Original task: {original_task}",
        "",
        "Generate a focused list of 3-5 verification questions to fact-check your categorization. Each question should:",
        "- Be concise and specific (one sentence)",
        "- Address a distinct aspect of the categorization",
        "- Be answerable independently",
        "",
        "Focus on verifying:",
        "- Whether each category assignment is accurate",
        "- Whether the categories match the criteria in the original task",
        "- Whether there are any logical inconsistencies",
        "",
        "Provide only the verification questions as a numbered list.",
    ])

    answer_prompt = "\n".join([
        "Answer the following verification question based on the text response provided.",
        "",
        f"Text response: {response_text}",
        "",
        "Verification question: <<QUESTION>>",
        "",
        "Provide a brief, direct answer (1-2 sentences maximum).",
        "",
        "Answer:",
    ])

    final_prompt = "\n".join([
        f"Original task: {original_task}",
        "Initial categorization:",
        "<<INITIAL_REPLY>>",
        "Verification questions and answers:",
        "<<VERIFICATION_QA>>",
        'If no categories are present, assign "0" to all categories.',
        "Provide the final corrected categorization in the same JSON format:",
    ])

    return question_prompt, answer_prompt, final_prompt
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
def _remove_numbering(line: str) -> str:
    """
    Strip a leading list marker from one line of CoVe question output.

    The line is whitespace-stripped first, then one of these prefixes is removed:
      - "- Question"
      - "• Question"
      - "1. Question" / "12) Question" (digits followed by "." or ")")
    Lines with no recognised marker are returned stripped but otherwise unchanged.
    """
    text = line.strip()

    for bullet in ('- ', '• '):
        if text.startswith(bullet):
            return text[len(bullet):].strip()

    # Count the leading run of digit characters.
    digit_count = 0
    for ch in text:
        if not ch.isdigit():
            break
        digit_count += 1

    if digit_count and text[digit_count:digit_count + 1] in ('.', ')'):
        return text[digit_count + 1:].strip()
    return text
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
def run_chain_of_verification(
    client,
    initial_reply: str,
    step2_prompt: str,
    step3_prompt: str,
    step4_prompt: str,
    json_schema: dict,
    creativity: float = None,
    max_retries: int = 5,
    max_questions: int = 5,
) -> str:
    """
    Run the Chain of Verification (CoVe) process on an initial classification.

    This is a 4-step process:
      1. Initial classification (already done, passed as ``initial_reply``).
      2. Generate verification questions (free-text reply).
      3. Answer each of the first ``max_questions`` questions (free-text replies).
      4. Final corrected classification (JSON reply, requested with ``json_schema``).

    Failures never raise from here. ``initial_reply`` is returned unchanged when:
      - step 2 or step 4 reports an error or returns no text;
      - the step-4 reply does not yield a JSON object with at least one value
        equal to 0 or 1.
    The 0/1 value may be a string or a number, matching what
    ``aggregate_results`` accepts. A step-3 question whose call fails is
    left out of the Q&A.

    Args:
        client: LLM client whose ``complete(...)`` returns ``(reply, error)``
            (UnifiedLLMClient instance).
        initial_reply: The initial JSON classification result.
        step2_prompt: Prompt template for generating questions ("<<INITIAL_REPLY>>").
        step3_prompt: Prompt template for answering questions ("<<QUESTION>>").
        step4_prompt: Prompt template for the final classification
            ("<<INITIAL_REPLY>>", "<<VERIFICATION_QA>>").
        json_schema: JSON schema for the final classification.
        creativity: Temperature setting.
        max_retries: Maximum retry attempts for each API call.
        max_questions: Maximum number of verification questions answered (default 5).

    Returns:
        Final corrected JSON classification string, or ``initial_reply``.
    """
    # Step 2: Generate verification questions (text response, not JSON)
    step2_filled = step2_prompt.replace("<<INITIAL_REPLY>>", initial_reply)
    questions_reply, err = client.complete(
        messages=[{"role": "user", "content": step2_filled}],
        creativity=creativity,
        force_json=False,  # Text response
        max_retries=max_retries,
    )
    if err or questions_reply is None:
        return initial_reply  # Fall back to initial reply on error

    # One question per non-blank line, with list numbering/bullets removed
    questions = [
        _remove_numbering(line)
        for line in questions_reply.strip().split('\n')
        if line.strip()
    ]

    # Step 3: Answer each verification question (text responses)
    qa_pairs = []
    for question in questions[:max_questions]:
        step3_filled = step3_prompt.replace("<<QUESTION>>", question)
        answer_reply, err = client.complete(
            messages=[{"role": "user", "content": step3_filled}],
            creativity=creativity,
            force_json=False,  # Text response
            max_retries=max_retries,
        )
        if err or answer_reply is None:
            continue  # Drop unanswered questions instead of aborting CoVe
        qa_pairs.append(f"Q: {question}\nA: {answer_reply.strip()}")

    verification_qa = "\n\n".join(qa_pairs)

    # Step 4: Final corrected categorization (JSON response)
    step4_filled = step4_prompt.replace(
        "<<INITIAL_REPLY>>", initial_reply
    ).replace(
        "<<VERIFICATION_QA>>", verification_qa
    )
    final_reply, err = client.complete(
        messages=[{"role": "user", "content": step4_filled}],
        json_schema=json_schema,
        creativity=creativity,
        max_retries=max_retries,
    )
    if err or final_reply is None:
        return initial_reply

    # Extract and validate JSON from the response (critical for providers
    # like HuggingFace that use json_object mode instead of strict json_schema)
    extracted = extract_json(final_reply)
    try:
        parsed = json.loads(extracted)
    except (json.JSONDecodeError, TypeError):
        # TypeError: extract_json gave back something that is not JSON text.
        return initial_reply

    if isinstance(parsed, dict) and any(
        str(value).strip() in ("0", "1") for value in parsed.values()
    ):
        return extracted

    # Fall back to initial reply if validation fails
    return initial_reply
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
# =============================================================================
|
|
798
|
+
# Text-Specific Functions (swap these for image classification)
|
|
799
|
+
# =============================================================================
|
|
800
|
+
|
|
801
|
+
def build_text_classification_prompt(
    response_text: str,
    categories_str: str,
    survey_question_context: str = "",
    examples_text: str = "",
    chain_of_thought: bool = False,
    context_prompt: bool = False,
    step_back_prompt: bool = False,
    stepback_insights: dict = None,
    model_name: str = None,
    multi_label: bool = True,
) -> list:
    """
    Assemble the chat messages that ask a model to classify one text response.

    This is the text-only prompt builder; PDF pages and images have their own
    builders in this module.

    Args:
        response_text: The text to classify (embedded in double quotes).
        categories_str: Pre-formatted, numbered list of categories.
        survey_question_context: Text placed at the very top of the prompt.
        examples_text: Few-shot examples, inserted before the JSON instruction.
        chain_of_thought: Add a three-step "think step by step" scaffold.
        context_prompt: Prepend an expert-researcher preamble.
        step_back_prompt: Replay a stored step-back exchange before the prompt
            when ``stepback_insights`` has an entry for ``model_name``.
        stepback_insights: Mapping from model name to a two-item
            ``(question, insight)`` sequence.
        model_name: Key used to look up ``stepback_insights``.
        multi_label: If True, ask for every category that applies; if False,
            ask for exactly one best-matching category.

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts. The last entry is the
        user classification prompt; it may be preceded by a user/assistant
        step-back pair.
    """
    # Wording that differs between single-label and multi-label tasks.
    if not multi_label:
        label_type = "single-label"
        categorize_instruction = "into the single most appropriate category"
        json_instruction = (
            "Provide your answer in JSON format where the category number is the key. "
            'Assign "1" to the single best matching category and "0" to all others.'
        )
        cot_step3 = "assign 1 to the single best matching category and 0 to all others"
    else:
        label_type = "multi-label"
        categorize_instruction = "into the following categories that apply"
        json_instruction = (
            "Provide your answer in JSON format where the category number is the key "
            'and "1" if present, "0" if not.'
        )
        cot_step3 = "assign 1 to matching categories and 0 to non-matching categories"

    task = (
        f'Categorize this text response "{response_text}" {categorize_instruction}:\n'
        f"{categories_str}"
    )

    if chain_of_thought:
        reasoning = (
            "Let's think step by step:\n"
            "1. First, identify the main themes mentioned in the response\n"
            "2. Then, match each theme to the relevant categories\n"
            f"3. Finally, {cot_step3}"
        )
        # The CoT variant separates every section with a blank line.
        user_prompt = (
            f"{survey_question_context}\n\n{task}\n\n{reasoning}\n\n"
            f"{examples_text}\n\n{json_instruction}"
        )
    else:
        # The plain variant separates sections with single newlines only.
        user_prompt = f"{survey_question_context}\n{task}\n{examples_text}\n{json_instruction}"

    if context_prompt:
        preamble = (
            "You are an expert researcher in text data categorization.\n"
            f"Apply {label_type} classification and base decisions on explicit and implicit meanings.\n"
            "When uncertain, prioritize precision over recall.\n\n"
        )
        user_prompt = preamble + user_prompt

    # Optional step-back exchange replayed as prior conversation turns.
    history = []
    if step_back_prompt and stepback_insights and model_name in stepback_insights:
        prior_question, prior_insight = stepback_insights[model_name]
        history = [
            {"role": "user", "content": prior_question},
            {"role": "assistant", "content": prior_insight},
        ]

    return [*history, {"role": "user", "content": user_prompt}]
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
# =============================================================================
|
|
888
|
+
# Summarization Functions
|
|
889
|
+
# =============================================================================
|
|
890
|
+
|
|
891
|
+
def build_summary_json_schema(include_additional_properties: bool = True) -> dict:
    """
    Return the JSON schema describing a structured summary response.

    The schema is an object with a single required string field, ``summary``.

    Args:
        include_additional_properties: When True, add
            ``"additionalProperties": False`` to forbid extra keys. Pass False
            for providers that reject that keyword (e.g. Google).

    Returns:
        A freshly built schema dict (safe for the caller to mutate).
    """
    summary_field = {
        "type": "string",
        "description": "A concise summary of the input text",
    }
    schema = {
        "type": "object",
        "properties": {"summary": summary_field},
        "required": ["summary"],
    }
    if not include_additional_properties:
        return schema
    schema["additionalProperties"] = False
    return schema
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
def build_text_summarization_prompt(
    response_text: str,
    input_description: str = "",
    summary_instructions: str = "",
    max_length: int = None,
    focus: str = None,
    chain_of_thought: bool = False,
    context_prompt: bool = False,
    step_back_prompt: bool = False,
    stepback_insights: dict = None,
    model_name: str = None,
) -> list:
    """
    Assemble the chat messages that ask a model to summarize one text input.

    Args:
        response_text: The text to summarize (embedded in double quotes).
        input_description: Optional description, rendered as
            "The following text is: ..." above the request.
        summary_instructions: Extra free-form instructions
            (e.g. "bullet points", "one sentence").
        max_length: Optional word limit; ignored when falsy.
        focus: Optional aspect to focus on (e.g. "main arguments").
        chain_of_thought: Add a three-step "think step by step" scaffold.
        context_prompt: Prepend an expert-summarizer preamble.
        step_back_prompt: Replay a stored step-back exchange before the prompt
            when ``stepback_insights`` has an entry for ``model_name``.
        stepback_insights: Mapping from model name to a two-item
            ``(question, insight)`` sequence.
        model_name: Key used to look up ``stepback_insights``.

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts ending with the user
        summarization prompt, which requests ``{"summary": ...}`` JSON.
    """
    # Each optional fragment collapses to "" when its input is falsy.
    description_context = (
        f"The following text is: {input_description}\n\n" if input_description else ""
    )
    focus_instruction = f", focusing on {focus}" if focus else ""
    length_instruction = (
        f"\n\nKeep the summary under {max_length} words." if max_length else ""
    )
    custom_instructions = (
        f"\n\nAdditional instructions: {summary_instructions}" if summary_instructions else ""
    )

    body = (
        f"{description_context}Summarize the following text{focus_instruction}:\n\n"
        f'"{response_text}"'
    )
    if chain_of_thought:
        # The reasoning steps sit between the quoted text and the length and
        # custom instructions.
        body += (
            "\n\nLet's think step by step:\n"
            "1. First, identify the main topic or theme\n"
            "2. Then, extract the key points\n"
            "3. Finally, synthesize into a concise summary"
        )
    user_prompt = (
        f"{body}{length_instruction}{custom_instructions}\n\n"
        'Provide your answer in JSON format: {"summary": "your summary here"}'
    )

    if context_prompt:
        user_prompt = (
            "You are an expert at synthesizing key insights from text.\n"
            "Focus on accuracy, clarity, and identifying the most important themes.\n"
            "Provide concise summaries that capture essential information.\n\n"
            + user_prompt
        )

    history = []
    if step_back_prompt and stepback_insights and model_name in stepback_insights:
        prior_question, prior_insight = stepback_insights[model_name]
        history = [
            {"role": "user", "content": prior_question},
            {"role": "assistant", "content": prior_insight},
        ]

    return [*history, {"role": "user", "content": user_prompt}]
|
|
1006
|
+
|
|
1007
|
+
|
|
1008
|
+
def extract_summary_from_json(json_str: str) -> tuple:
    """
    Pull the ``summary`` string out of a model's JSON reply.

    Args:
        json_str: Raw JSON text expected to look like ``{"summary": "..."}``.

    Returns:
        ``(True, stripped_summary)`` when the input parses to a dict whose
        ``summary`` is a non-blank string; ``(False, None)`` otherwise. Input
        that cannot be parsed (invalid JSON, or a non-string such as None)
        also yields ``(False, None)``.
    """
    try:
        payload = json.loads(json_str)
    except (json.JSONDecodeError, TypeError):
        return False, None

    if not isinstance(payload, dict):
        return False, None

    summary = payload.get("summary")
    if not isinstance(summary, str):
        return False, None

    cleaned = summary.strip()
    return (True, cleaned) if cleaned else (False, None)
|
|
1027
|
+
|
|
1028
|
+
|
|
1029
|
+
def build_pdf_summarization_prompt(
    page_data: dict,
    input_description: str = "",
    summary_instructions: str = "",
    max_length: int = None,
    focus: str = None,
    provider: str = "openai",
    pdf_mode: str = "image",
    chain_of_thought: bool = False,
    context_prompt: bool = False,
    step_back_prompt: bool = False,
    stepback_insights: dict = None,
    model_name: str = None,
) -> list:
    """
    Assemble the chat messages that ask a model to summarize one PDF page.

    Summarization counterpart of ``build_pdf_classification_prompt()``.

    Args:
        page_data: Page payload. Only the ``text``, ``pdf_bytes`` and
            ``image_bytes`` entries are read here; other keys (``pdf_path``,
            ``page_index``, ``page_label``) are ignored.
        input_description: Optional description of the documents.
        summary_instructions: Extra free-form instructions (e.g. "bullet points").
        max_length: Optional word limit; ignored when falsy.
        focus: Optional aspect to focus on.
        provider: Provider name. It selects the attachment format when the
            page is sent visually.
        pdf_mode: "text" sends extracted text only. "both" sends the attachment
            plus extracted text. Any other value is treated as "image".
        chain_of_thought: Add a three-step "analyze step by step" scaffold.
        context_prompt: Prepend an expert-summarizer preamble.
        step_back_prompt: Replay a stored step-back exchange before the prompt
            when ``stepback_insights`` has an entry for ``model_name``.
        stepback_insights: Mapping from model name to a two-item
            ``(question, insight)`` sequence.
        model_name: Key used to look up ``stepback_insights``.

    Returns:
        List of message dicts. The final user message content is one of:
        a plain string (text mode, or no visual data available); a
        [text, PDF] parts list for providers in ``_NATIVE_PDF_PROVIDERS``
        with ``pdf_bytes``; or a [text, image_url] parts list when
        ``image_bytes`` is available.
    """
    examine_instruction = {
        "text": "Examine the following text extracted from a PDF page",
        "both": "Examine the attached PDF page AND the extracted text below",
    }.get(pdf_mode, "Examine the attached PDF page")

    # Each optional fragment collapses to "" when its input is falsy.
    focus_instruction = f", focusing on {focus}" if focus else ""
    document_context = f"Document context: {input_description}" if input_description else ""
    length_instruction = (
        f"\n\nKeep the summary under {max_length} words." if max_length else ""
    )
    custom_instructions = (
        f"\n\nAdditional instructions: {summary_instructions}" if summary_instructions else ""
    )

    prompt = (
        "You are a document summarization assistant.\n"
        f"Task: {examine_instruction} and provide a concise summary{focus_instruction}.\n\n"
        f"{document_context}"
    )
    if chain_of_thought:
        prompt += (
            "\n\nLet's analyze step by step:\n"
            "1. First, identify the main topic or theme of this page\n"
            "2. Then, extract the key points and important information\n"
            "3. Finally, synthesize into a concise summary"
        )
    base_text = (
        f"{prompt}{length_instruction}{custom_instructions}\n\n"
        'Provide your answer in JSON format: {"summary": "your summary here"}'
    )

    # Inline extracted text only for modes that promise it to the model.
    if page_data.get("text") and pdf_mode in ("text", "both"):
        base_text += (
            "\n\n--- EXTRACTED TEXT FROM PAGE ---\n"
            f"{page_data['text']}\n"
            "--- END OF EXTRACTED TEXT ---"
        )

    if context_prompt:
        base_text = (
            "You are an expert at synthesizing key insights from documents.\n"
            "Focus on accuracy, clarity, and identifying the most important themes.\n"
            "Provide concise summaries that capture essential information.\n\n"
            + base_text
        )

    messages = []
    if step_back_prompt and stepback_insights and model_name in stepback_insights:
        prior_question, prior_insight = stepback_insights[model_name]
        messages.extend([
            {"role": "user", "content": prior_question},
            {"role": "assistant", "content": prior_insight},
        ])

    if pdf_mode == "text":
        messages.append({"role": "user", "content": base_text})
        return messages

    text_part = {"type": "text", "text": base_text}

    if provider in _NATIVE_PDF_PROVIDERS and page_data.get("pdf_bytes"):
        encoded_pdf = _encode_bytes_to_base64(page_data["pdf_bytes"])
        if provider == "anthropic":
            pdf_part = {
                "type": "document",
                "source": {
                    "type": "base64",
                    "media_type": "application/pdf",
                    "data": encoded_pdf,
                },
            }
            messages.append({"role": "user", "content": [text_part, pdf_part]})
        elif provider == "google":
            # Google-style inline_data part; presumably translated by the
            # provider layer before sending (verify there).
            pdf_part = {
                "type": "inline_data",
                "mime_type": "application/pdf",
                "data": encoded_pdf,
            }
            messages.append({"role": "user", "content": [text_part, pdf_part]})
    elif page_data.get("image_bytes"):
        # Rendered page for providers that only accept images.
        encoded_image = _encode_bytes_to_base64(page_data["image_bytes"])
        image_part = {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{encoded_image}", "detail": "high"},
        }
        messages.append({"role": "user", "content": [text_part, image_part]})
    else:
        # No visual payload available: send the prompt as plain text.
        messages.append({"role": "user", "content": base_text})

    return messages
|
|
1191
|
+
|
|
1192
|
+
|
|
1193
|
+
# =============================================================================
|
|
1194
|
+
# PDF-Specific Functions
|
|
1195
|
+
# =============================================================================
|
|
1196
|
+
|
|
1197
|
+
# Provider-specific PDF support info
|
|
1198
|
+
_NATIVE_PDF_PROVIDERS = {"anthropic", "google"} # Send native PDF bytes
|
|
1199
|
+
_IMAGE_PROVIDERS = {"openai", "mistral", "xai", "perplexity", "huggingface"} # Convert to image
|
|
1200
|
+
_TEXT_ONLY_PROVIDERS = {"ollama"} # Text extraction only
|
|
1201
|
+
|
|
1202
|
+
|
|
1203
|
+
def build_pdf_classification_prompt(
    page_data: dict,
    categories_str: str,
    input_description: str = "",
    provider: str = "openai",
    pdf_mode: str = "image",
    chain_of_thought: bool = False,
    context_prompt: bool = False,
    step_back_prompt: bool = False,
    stepback_insights: dict = None,
    model_name: str = None,
    example_json: str = None,
    multi_label: bool = True,
) -> list:
    """
    Assemble the chat messages that ask a model to classify one PDF page.

    PDF counterpart of ``build_text_classification_prompt()``.

    Args:
        page_data: Page payload. Only the ``text``, ``pdf_bytes`` and
            ``image_bytes`` entries are read here; other keys (``pdf_path``,
            ``page_index``, ``page_label``) are ignored.
        categories_str: Pre-formatted, numbered list of categories.
        input_description: Optional description of what pages should contain.
        provider: Provider name. It selects the attachment format when the
            page is sent visually.
        pdf_mode: "text" sends extracted text only. "both" sends the attachment
            plus extracted text. Any other value is treated as "image".
        chain_of_thought: Add a three-step "analyze step by step" scaffold.
        context_prompt: Prepend an expert-analyst preamble.
        step_back_prompt: Replay a stored step-back exchange before the prompt
            when ``stepback_insights`` has an entry for ``model_name``.
        stepback_insights: Mapping from model name to a two-item
            ``(question, insight)`` sequence.
        model_name: Key used to look up ``stepback_insights``.
        example_json: Optional example of the expected JSON output.
        multi_label: If True, score every category independently; if False,
            mark only the single best-matching category.

    Returns:
        List of message dicts. The final user message content is one of:
        a plain string (text mode, or no visual data available); a
        [text, PDF] parts list for providers in ``_NATIVE_PDF_PROVIDERS``
        with ``pdf_bytes``; or a [text, image_url] parts list when
        ``image_bytes`` is available.
    """
    examine_instruction = {
        "text": "Examine the following text extracted from a PDF page",
        "both": "Examine the attached PDF page AND the extracted text below",
    }.get(pdf_mode, "Examine the attached PDF page")

    if multi_label:
        label_type = "multi-label"
        task_instruction = (
            f"{examine_instruction} and decide, **for each category below**, "
            "whether it is PRESENT (1) or NOT PRESENT (0)."
        )
        cot_step3 = "assign 1 to matching categories and 0 to non-matching categories"
    else:
        label_type = "single-label"
        task_instruction = (
            f"{examine_instruction} and decide which **single category** below best "
            "describes this document page. Assign PRESENT (1) to only the best match "
            "and NOT PRESENT (0) to all others."
        )
        cot_step3 = "assign 1 to the single best matching category and 0 to all others"

    expected_content = (
        f"Document page is expected to contain: {input_description}" if input_description else ""
    )
    example_line = f"Example JSON format: {example_json}" if example_json else ""

    # One template; the CoT steps are spliced in between the category list
    # and the output-format rules.
    base_text = (
        "You are a document-tagging assistant.\n"
        f"Task: {task_instruction}\n\n"
        f"{expected_content}\n\n"
        "Categories:\n"
        f"{categories_str}\n\n"
    )
    if chain_of_thought:
        base_text += (
            "Let's analyze step by step:\n"
            "1. First, identify the key content elements in the document page\n"
            "2. Then, match each element to the relevant categories\n"
            f"3. Finally, {cot_step3}\n\n"
        )
    base_text += (
        "Output format: Respond with **only** a JSON object whose keys are the quoted "
        "category numbers ('1', '2', ...) and whose values are 1 or 0. "
        "No additional keys, comments, or text.\n\n"
        f"{example_line}"
    )

    # Inline extracted text only for modes that promise it to the model.
    if page_data.get("text") and pdf_mode in ("text", "both"):
        base_text += (
            "\n\n--- EXTRACTED TEXT FROM PAGE ---\n"
            f"{page_data['text']}\n"
            "--- END OF EXTRACTED TEXT ---"
        )

    if context_prompt:
        base_text = (
            "You are an expert document analyst specializing in page categorization.\n"
            f"Apply {label_type} classification based on explicit and implicit content cues.\n"
            "When uncertain, prioritize precision over recall.\n\n"
            + base_text
        )

    messages = []
    if step_back_prompt and stepback_insights and model_name in stepback_insights:
        prior_question, prior_insight = stepback_insights[model_name]
        messages.extend([
            {"role": "user", "content": prior_question},
            {"role": "assistant", "content": prior_insight},
        ])

    if pdf_mode == "text":
        messages.append({"role": "user", "content": base_text})
        return messages

    text_part = {"type": "text", "text": base_text}

    if provider in _NATIVE_PDF_PROVIDERS and page_data.get("pdf_bytes"):
        encoded_pdf = _encode_bytes_to_base64(page_data["pdf_bytes"])
        if provider == "anthropic":
            pdf_part = {
                "type": "document",
                "source": {
                    "type": "base64",
                    "media_type": "application/pdf",
                    "data": encoded_pdf,
                },
            }
            messages.append({"role": "user", "content": [text_part, pdf_part]})
        elif provider == "google":
            # Google-style inline_data part; presumably translated by the
            # provider layer before sending (verify there).
            pdf_part = {
                "type": "inline_data",
                "mime_type": "application/pdf",
                "data": encoded_pdf,
            }
            messages.append({"role": "user", "content": [text_part, pdf_part]})
    elif page_data.get("image_bytes"):
        # Rendered page for providers that only accept images.
        encoded_image = _encode_bytes_to_base64(page_data["image_bytes"])
        image_part = {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{encoded_image}", "detail": "high"},
        }
        messages.append({"role": "user", "content": [text_part, image_part]})
    else:
        # No visual payload available: send the prompt as plain text.
        messages.append({"role": "user", "content": base_text})

    return messages
|
|
1366
|
+
|
|
1367
|
+
|
|
1368
|
+
def _prepare_page_data(
|
|
1369
|
+
pdf_path: str,
|
|
1370
|
+
page_index: int,
|
|
1371
|
+
page_label: str,
|
|
1372
|
+
pdf_mode: str,
|
|
1373
|
+
provider: str,
|
|
1374
|
+
pdf_dpi: int = 150,
|
|
1375
|
+
) -> dict:
|
|
1376
|
+
"""
|
|
1377
|
+
Prepare page data for classification based on mode and provider.
|
|
1378
|
+
|
|
1379
|
+
Args:
|
|
1380
|
+
pdf_path: Path to the PDF file
|
|
1381
|
+
page_index: Page number (0-indexed)
|
|
1382
|
+
page_label: Label for the page (e.g., "document_p1")
|
|
1383
|
+
pdf_mode: "image", "text", or "both"
|
|
1384
|
+
provider: Provider name for determining format
|
|
1385
|
+
pdf_dpi: DPI for image extraction
|
|
1386
|
+
|
|
1387
|
+
Returns:
|
|
1388
|
+
Dict with page data ready for classification
|
|
1389
|
+
"""
|
|
1390
|
+
page_data = {
|
|
1391
|
+
"pdf_path": pdf_path,
|
|
1392
|
+
"page_index": page_index,
|
|
1393
|
+
"page_label": page_label,
|
|
1394
|
+
"text": None,
|
|
1395
|
+
"image_bytes": None,
|
|
1396
|
+
"pdf_bytes": None,
|
|
1397
|
+
"error": None,
|
|
1398
|
+
}
|
|
1399
|
+
|
|
1400
|
+
# Extract text if needed
|
|
1401
|
+
if pdf_mode in ("text", "both"):
|
|
1402
|
+
text, is_valid, error = _extract_page_text(pdf_path, page_index)
|
|
1403
|
+
if is_valid:
|
|
1404
|
+
page_data["text"] = text
|
|
1405
|
+
elif pdf_mode == "text":
|
|
1406
|
+
# Text mode requires text
|
|
1407
|
+
page_data["error"] = f"Failed to extract text: {error}"
|
|
1408
|
+
return page_data
|
|
1409
|
+
|
|
1410
|
+
# Extract visual content if needed
|
|
1411
|
+
if pdf_mode in ("image", "both"):
|
|
1412
|
+
if provider in _NATIVE_PDF_PROVIDERS:
|
|
1413
|
+
# Extract as PDF bytes for native PDF providers
|
|
1414
|
+
pdf_bytes, is_valid = _extract_page_as_pdf_bytes(pdf_path, page_index)
|
|
1415
|
+
if is_valid:
|
|
1416
|
+
page_data["pdf_bytes"] = pdf_bytes
|
|
1417
|
+
else:
|
|
1418
|
+
# Fallback to image
|
|
1419
|
+
image_bytes, is_valid = _extract_page_as_image_bytes(pdf_path, page_index, dpi=pdf_dpi)
|
|
1420
|
+
if is_valid:
|
|
1421
|
+
page_data["image_bytes"] = image_bytes
|
|
1422
|
+
else:
|
|
1423
|
+
page_data["error"] = "Failed to extract page as PDF or image"
|
|
1424
|
+
else:
|
|
1425
|
+
# Extract as image for other providers
|
|
1426
|
+
image_bytes, is_valid = _extract_page_as_image_bytes(pdf_path, page_index, dpi=pdf_dpi)
|
|
1427
|
+
if is_valid:
|
|
1428
|
+
page_data["image_bytes"] = image_bytes
|
|
1429
|
+
else:
|
|
1430
|
+
page_data["error"] = "Failed to extract page as image"
|
|
1431
|
+
|
|
1432
|
+
return page_data
|
|
1433
|
+
|
|
1434
|
+
|
|
1435
|
+
# =============================================================================
|
|
1436
|
+
# Image-Specific Functions
|
|
1437
|
+
# =============================================================================
|
|
1438
|
+
|
|
1439
|
+
def build_image_classification_prompt(
    image_data: dict,
    categories_str: str,
    input_description: str = "",
    provider: str = "openai",
    chain_of_thought: bool = False,
    context_prompt: bool = False,
    step_back_prompt: bool = False,
    stepback_insights: dict = None,
    model_name: str = None,
    example_json: str = None,
    multi_label: bool = True,
) -> list:
    """
    Build the classification prompt for an image.

    This is the image-specific prompt builder, parallel to build_pdf_classification_prompt().
    The text prompt is assembled from shared pieces; enabling chain_of_thought only
    inserts a step-by-step reasoning section before the output-format instructions.

    Args:
        image_data: Dict containing:
            - image_path: Path to source image
            - image_label: Label for the image (filename without extension)
            - encoded_image: Base64 encoded image ("" is used if missing/None)
            - extension: Image file extension (without dot); "png" is used if
              missing/None
        categories_str: Formatted string of categories (numbered list)
        input_description: Description of what the images contain
        provider: Provider name for format-specific handling ("anthropic",
            "google", anything else uses the OpenAI-style image_url part)
        chain_of_thought: Whether to use step-by-step reasoning
        context_prompt: Whether to add expert context prefix
        step_back_prompt: Whether step-back prompting is enabled
        stepback_insights: Dict mapping model name -> (question, insight) pair
        model_name: Current model name (for step-back lookup)
        example_json: Example JSON output format
        multi_label: If True, every category is judged independently; if False,
            the model is told to mark only the single best-matching category.

    Returns:
        List of chat messages. If step-back insights exist for model_name, the
        list starts with the step-back question (user) and insight (assistant);
        the last message is always a user turn holding the text prompt followed
        by the image in the provider-specific format.
    """
    # Label-mode specific wording for the task and the final CoT step.
    if multi_label:
        task_instruction = 'Examine the attached image and decide, **for each category below**, whether it is PRESENT (1) or NOT PRESENT (0).'
        cot_step3 = 'assign 1 to matching categories and 0 to non-matching categories'
    else:
        task_instruction = 'Examine the attached image and decide which **single category** below best describes this image. Assign PRESENT (1) to only the best match and NOT PRESENT (0) to all others.'
        cot_step3 = 'assign 1 to the single best matching category and 0 to all others'

    input_line = (
        f"The image is expected to contain: {input_description}"
        if input_description else ""
    )
    example_line = f"Example JSON format: {example_json}" if example_json else ""

    # The only difference between the CoT and plain prompts is this section,
    # which sits between the category list and the output-format instructions.
    if chain_of_thought:
        reasoning_section = (
            "Let's analyze step by step:\n"
            "1. First, identify the key visual elements in the image\n"
            "2. Then, match each element to the relevant categories\n"
            f"3. Finally, {cot_step3}\n"
            "\n"
        )
    else:
        reasoning_section = ""

    base_text = (
        "You are an image-tagging assistant.\n"
        f"Task: {task_instruction}\n"
        "\n"
        f"{input_line}\n"
        "\n"
        "Categories:\n"
        f"{categories_str}\n"
        "\n"
        f"{reasoning_section}"
        "Output format: Respond with **only** a JSON object whose keys are the quoted category numbers ('1', '2', ...) and whose values are 1 or 0. No additional keys, comments, or text.\n"
        "\n"
        f"{example_line}"
    )

    # Add context prompt prefix if enabled
    if context_prompt:
        label_type = "multi-label" if multi_label else "single-label"
        context = (
            "You are an expert visual analyst specializing in image categorization.\n"
            f"Apply {label_type} classification based on explicit and implicit visual cues.\n"
            "When uncertain, prioritize precision over recall.\n"
            "\n"
        )
        base_text = context + base_text

    messages = []

    # Replay the step-back exchange (if any) before the classification turn.
    if step_back_prompt and stepback_insights and model_name in stepback_insights:
        sb_question, sb_insight = stepback_insights[model_name]
        messages.append({"role": "user", "content": sb_question})
        messages.append({"role": "assistant", "content": sb_insight})

    # `or` (not a .get default) so keys that are present but None -- as
    # produced by _prepare_image_data on failure -- still get a fallback
    # instead of yielding "image/None".
    encoded = image_data.get("encoded_image") or ""
    ext = image_data.get("extension") or "png"
    media_type = f"image/{ext}"

    if provider == "anthropic":
        # Anthropic uses an explicit base64 source with media_type
        image_part = {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": encoded,
            },
        }
    elif provider == "google":
        # Google uses inline_data format
        image_part = {
            "type": "inline_data",
            "mime_type": media_type,
            "data": encoded,
        }
    else:
        # OpenAI, Mistral, xAI, etc. use image_url format with a data URL
        image_part = {
            "type": "image_url",
            "image_url": {"url": f"data:{media_type};base64,{encoded}", "detail": "high"},
        }

    messages.append({
        "role": "user",
        "content": [{"type": "text", "text": base_text}, image_part],
    })
    return messages
|
|
1575
|
+
|
|
1576
|
+
|
|
1577
|
+
def _prepare_image_data(
    image_path: str,
    image_label: str,
) -> dict:
    """
    Encode one image file into the payload dict consumed by the image prompt builder.

    Args:
        image_path: Path to the image file.
        image_label: Label for the image (typically filename without extension).

    Returns:
        Dict with keys image_path, image_label, encoded_image, extension and error.
        On success, encoded_image holds the data returned by _encode_image and
        extension is lower-cased, with the JPEG aliases jpg/jpe/jfif/pjpeg/pjp
        collapsed to "jpeg". On failure, encoded_image and extension stay None
        and error holds a message.
    """
    prepared = dict(
        image_path=image_path,
        image_label=image_label,
        encoded_image=None,
        extension=None,
        error=None,
    )
    jpeg_aliases = ('jpg', 'jpe', 'jfif', 'pjpeg', 'pjp')

    try:
        # _encode_image yields (encoded_data, extension, is_valid).
        payload, reported_ext, ok = _encode_image(image_path)
        if not ok:
            prepared["error"] = "Failed to encode image"
            return prepared

        # Prefer the encoder's extension; otherwise derive it from the path suffix.
        if reported_ext:
            extension = reported_ext.lower()
        else:
            extension = os.path.splitext(image_path)[1].lower().replace('.', '')

        prepared["encoded_image"] = payload
        prepared["extension"] = 'jpeg' if extension in jpeg_aliases else extension
    except Exception as exc:
        # Record the failure in the returned dict instead of raising.
        prepared["error"] = str(exc)

    return prepared
|
|
1617
|
+
|
|
1618
|
+
|
|
1619
|
+
def _save_partial_results(
    all_results: list,
    model_configs: list,
    categories: list,
    filename: str,
    save_directory: str,
) -> None:
    """
    Save partial results to file for safety/incremental saves.

    This is a simplified version of build_output_dataframes that only
    saves the combined DataFrame without building per-model DataFrames.

    The CSV is first written to a sibling temporary file, which is then moved
    over the destination with os.replace(). An interruption mid-write (e.g.
    Ctrl+C during a long run) therefore leaves the previous checkpoint intact
    instead of a truncated file.

    Args:
        all_results: Per-item result dicts. Each has "response" and an
            "aggregated" dict (error, failed_models, per_model, consensus,
            agreement), plus optional "skipped", "pdf_path", "page_index",
            "image_path".
        model_configs: Model config dicts; only "sanitized_name" is used here
            (as the column suffix).
        categories: Category list; only its length is used (columns are numbered).
        filename: Output CSV filename.
        save_directory: Optional directory to place the file in (created if needed).
    """
    model_names = [cfg["sanitized_name"] for cfg in model_configs]
    category_numbers = range(1, len(categories) + 1)

    def _as_int(value, default):
        # Model outputs may be "1"/"0" strings, ints, or junk; fall back on failure.
        try:
            return int(value)
        except (ValueError, TypeError):
            return default

    # Check if any results have PDF or image metadata
    has_pdf_metadata = any(result.get("pdf_path") is not None for result in all_results)
    has_image_metadata = any(result.get("image_path") is not None for result in all_results)

    rows = []
    for result in all_results:
        aggregated = result["aggregated"]
        # Treat a missing or None failed_models list as "no failures".
        failed_models = aggregated.get("failed_models") or []
        is_skipped = result.get("skipped", False)

        # Determine processing status
        if is_skipped:
            status = "skipped"
        elif aggregated.get("error"):
            status = "error"
        elif failed_models:
            status = "partial"
        else:
            status = "success"

        row = {
            "input_data": result["response"],
            "processing_status": status,
            "failed_models": ",".join(failed_models),
        }

        # Add PDF metadata columns if present
        if has_pdf_metadata:
            row["pdf_path"] = result.get("pdf_path", "")
            row["page_index"] = result.get("page_index", "")

        # Add image metadata columns if present
        if has_image_metadata:
            row["image_path"] = result.get("image_path", "")

        # Per-model results
        failed_set = set(failed_models)
        per_model = aggregated.get("per_model") or {}
        for model_name in model_names:
            if is_skipped or model_name in failed_set:
                # Skipped (NaN input) or model failed — mark as NA
                for i in category_numbers:
                    row[f"category_{i}_{model_name}"] = None
                continue
            parsed = per_model.get(model_name) or {}
            for i in category_numbers:
                row[f"category_{i}_{model_name}"] = _as_int(parsed.get(str(i), "0"), 0)

        # Consensus results
        consensus = aggregated.get("consensus") or {}
        agreement = aggregated.get("agreement") or {}
        for i in category_numbers:
            key = str(i)
            # A missing/None consensus value becomes None via the int() TypeError.
            row[f"category_{i}_consensus"] = _as_int(consensus.get(key), None)
            row[f"category_{i}_agreement"] = agreement.get(key)

        rows.append(row)

    # Create DataFrame and save
    partial_df = pd.DataFrame(rows)

    save_path = filename
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
        save_path = os.path.join(save_directory, filename)

    # Keep the original extension last so pandas' compression inference
    # (e.g. ".gz") behaves exactly as it would for the final path.
    root, ext = os.path.splitext(save_path)
    tmp_path = f"{root}.partial-{os.getpid()}{ext}"
    try:
        partial_df.to_csv(tmp_path, index=False)
        os.replace(tmp_path, save_path)
    except BaseException:
        # Clean up the temp file (also on KeyboardInterrupt) and propagate.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
|
1712
|
+
|
|
1713
|
+
|
|
1714
|
+
def classify_ensemble(
|
|
1715
|
+
input_data,
|
|
1716
|
+
categories,
|
|
1717
|
+
# Single model mode (like original multi_class)
|
|
1718
|
+
model: str = None,
|
|
1719
|
+
api_key: str = None,
|
|
1720
|
+
provider: str = "auto",
|
|
1721
|
+
# Multi-model mode
|
|
1722
|
+
models: list = None,
|
|
1723
|
+
# Common parameters
|
|
1724
|
+
survey_question: str = "",
|
|
1725
|
+
example1: str = None,
|
|
1726
|
+
example2: str = None,
|
|
1727
|
+
example3: str = None,
|
|
1728
|
+
example4: str = None,
|
|
1729
|
+
example5: str = None,
|
|
1730
|
+
example6: str = None,
|
|
1731
|
+
creativity: float = None,
|
|
1732
|
+
chain_of_thought: bool = False,
|
|
1733
|
+
chain_of_verification: bool = False,
|
|
1734
|
+
step_back_prompt: bool = False,
|
|
1735
|
+
context_prompt: bool = False,
|
|
1736
|
+
thinking_budget: int = 0,
|
|
1737
|
+
use_json_schema: bool = True,
|
|
1738
|
+
max_workers: int = None,
|
|
1739
|
+
parallel: bool = None,
|
|
1740
|
+
consensus_threshold: Union[str, float] = "unanimous",
|
|
1741
|
+
fail_strategy: str = "partial",
|
|
1742
|
+
safety: bool = False,
|
|
1743
|
+
max_retries: int = 5,
|
|
1744
|
+
batch_retries: int = 2,
|
|
1745
|
+
retry_delay: float = 1.0,
|
|
1746
|
+
row_delay: float = 0.0,
|
|
1747
|
+
filename: str = None,
|
|
1748
|
+
save_directory: str = None,
|
|
1749
|
+
progress_callback: Callable = None,
|
|
1750
|
+
# Auto-category detection parameters
|
|
1751
|
+
max_categories: int = 12,
|
|
1752
|
+
categories_per_chunk: int = 10,
|
|
1753
|
+
divisions: int = 10,
|
|
1754
|
+
research_question: str = None,
|
|
1755
|
+
# PDF-specific parameters (only used when input_data contains PDFs)
|
|
1756
|
+
pdf_mode: str = "image",
|
|
1757
|
+
pdf_dpi: int = 150,
|
|
1758
|
+
input_description: str = "",
|
|
1759
|
+
# Ollama parameters
|
|
1760
|
+
auto_download: bool = False,
|
|
1761
|
+
# JSON formatter fallback
|
|
1762
|
+
formatter_state: dict = None,
|
|
1763
|
+
# Label mode
|
|
1764
|
+
multi_label: bool = True,
|
|
1765
|
+
# Chunked classification
|
|
1766
|
+
categories_per_call: int = None,
|
|
1767
|
+
# Embedding tiebreaker
|
|
1768
|
+
embedding_tiebreaker_state: dict = None,
|
|
1769
|
+
):
|
|
1770
|
+
"""
|
|
1771
|
+
Multi-class classification with support for text AND PDF inputs, single or multiple LLM models.
|
|
1772
|
+
|
|
1773
|
+
This unified function auto-detects whether the input is text or PDF and processes accordingly.
|
|
1774
|
+
|
|
1775
|
+
Input type detection:
|
|
1776
|
+
- If input_data is a directory path -> PDF mode
|
|
1777
|
+
- If input_data contains .pdf file paths -> PDF mode
|
|
1778
|
+
- Otherwise -> Text mode
|
|
1779
|
+
|
|
1780
|
+
This function can work in multiple modes:
|
|
1781
|
+
1. Single model mode: Like the original multi_class function
|
|
1782
|
+
2. Ensemble mode: Call multiple models in parallel with majority voting
|
|
1783
|
+
3. PDF mode: Classify PDF pages instead of text responses
|
|
1784
|
+
|
|
1785
|
+
Args:
|
|
1786
|
+
input_data: Text responses OR PDF paths (auto-detected)
|
|
1787
|
+
- Text mode: List or Series of text strings to classify
|
|
1788
|
+
- PDF mode: Directory path, single PDF path, or list of PDF paths
|
|
1789
|
+
|
|
1790
|
+
categories: List of category names, or "auto" to auto-detect categories
|
|
1791
|
+
|
|
1792
|
+
# Single model mode (use these for simple single-model classification):
|
|
1793
|
+
model: Model name (e.g., "gpt-4o", "claude-sonnet-4-5-20250929")
|
|
1794
|
+
api_key: API key for the provider
|
|
1795
|
+
provider: Provider name or "auto" to detect from model name
|
|
1796
|
+
|
|
1797
|
+
# Multi-model mode (use this for ensemble classification):
|
|
1798
|
+
models: List of tuples (model_name, provider, api_key), or a single tuple
|
|
1799
|
+
Example: [("gpt-4o", "openai", "sk-..."), ("claude-sonnet-4-5-20250929", "anthropic", "sk-ant-...")]
|
|
1800
|
+
|
|
1801
|
+
# Classification parameters:
|
|
1802
|
+
survey_question: Context about what question was asked (required for categories="auto")
|
|
1803
|
+
example1-6: Optional few-shot examples for classification
|
|
1804
|
+
creativity: Temperature setting (None for provider default)
|
|
1805
|
+
chain_of_thought: If True, uses step-by-step reasoning in prompt
|
|
1806
|
+
chain_of_verification: If True, uses 4-step verification to improve accuracy
|
|
1807
|
+
(Note: ~4x API calls per response - expensive for ensemble mode)
|
|
1808
|
+
step_back_prompt: If True, first asks about underlying factors before classifying
|
|
1809
|
+
context_prompt: If True, adds expert context prefix to prompts
|
|
1810
|
+
thinking_budget: Token budget for Google's extended thinking (0 to disable)
|
|
1811
|
+
use_json_schema: Whether to use strict JSON schema (vs just json_object mode)
|
|
1812
|
+
|
|
1813
|
+
# Ensemble parameters:
|
|
1814
|
+
max_workers: Maximum parallel workers (default: min(len(models), 8))
|
|
1815
|
+
consensus_threshold: Threshold for consensus vote. Can be:
|
|
1816
|
+
- "unanimous": 100% agreement (default — best accuracy in empirical testing)
|
|
1817
|
+
- "majority": 50% agreement
|
|
1818
|
+
- "two-thirds": 67% agreement
|
|
1819
|
+
- float: Custom threshold between 0 and 1 (e.g., 0.75 for 75%)
|
|
1820
|
+
fail_strategy: How to handle model failures:
|
|
1821
|
+
- "partial": Continue with successful models
|
|
1822
|
+
- "strict": Fail row if any model fails
|
|
1823
|
+
safety: If True, saves results incrementally during processing to prevent
|
|
1824
|
+
data loss. Requires filename to be set.
|
|
1825
|
+
max_retries: Maximum retry attempts for each API call (handles rate limits,
|
|
1826
|
+
server errors, timeouts). Default 5.
|
|
1827
|
+
batch_retries: Maximum retry passes for failed (row, model) pairs after
|
|
1828
|
+
the batch completes. Default 2 means up to 3 total attempts. Set to 0
|
|
1829
|
+
to disable batch-level retries.
|
|
1830
|
+
retry_delay: Seconds to wait between batch retry passes.
|
|
1831
|
+
|
|
1832
|
+
# Output parameters:
|
|
1833
|
+
filename: Optional CSV filename to save combined results (required if safety=True)
|
|
1834
|
+
save_directory: Optional directory for saved files
|
|
1835
|
+
progress_callback: Optional callback(response_idx, model_name, success, total, completed)
|
|
1836
|
+
|
|
1837
|
+
# Auto-category detection parameters (used when categories="auto"):
|
|
1838
|
+
max_categories: Maximum number of categories to discover (default 12)
|
|
1839
|
+
categories_per_chunk: Categories to extract per data chunk (default 10)
|
|
1840
|
+
divisions: Number of chunks to divide data into (default 10)
|
|
1841
|
+
research_question: Optional research context for category discovery
|
|
1842
|
+
|
|
1843
|
+
# PDF-specific parameters (only used when input_data contains PDFs):
|
|
1844
|
+
pdf_mode: How to process PDF pages. Options:
|
|
1845
|
+
- "image": Render pages as images (best for visual elements)
|
|
1846
|
+
- "text": Extract text only (faster/cheaper for text-heavy docs)
|
|
1847
|
+
- "both": Send both text and image (most comprehensive)
|
|
1848
|
+
pdf_dpi: Resolution for PDF to image conversion (default 150)
|
|
1849
|
+
input_description: Description of what the PDF documents contain
|
|
1850
|
+
|
|
1851
|
+
# Ollama parameters:
|
|
1852
|
+
auto_download: If True, automatically download missing Ollama models
|
|
1853
|
+
|
|
1854
|
+
Returns:
|
|
1855
|
+
- Single model: Returns DataFrame directly (backward compatible with multi_class)
|
|
1856
|
+
- Multiple models: Returns dict containing:
|
|
1857
|
+
- "combined": DataFrame with all per-model and consensus columns
|
|
1858
|
+
- "consensus": DataFrame with only consensus results
|
|
1859
|
+
- "<model_name>": Individual DataFrame for each model
|
|
1860
|
+
|
|
1861
|
+
DataFrame columns:
|
|
1862
|
+
- input_data: Text string OR page label (e.g., "document_p1")
|
|
1863
|
+
- category_N_<model>: Per-model results (0/1)
|
|
1864
|
+
- category_N_consensus: Majority vote result
|
|
1865
|
+
- category_N_agreement: Model agreement score
|
|
1866
|
+
- processing_status: "success", "partial", "error", "skipped"
|
|
1867
|
+
- failed_models: List of failed models
|
|
1868
|
+
|
|
1869
|
+
Additional columns (PDF mode only):
|
|
1870
|
+
- pdf_path: Source PDF file path
|
|
1871
|
+
- page_index: Page number (0-indexed)
|
|
1872
|
+
|
|
1873
|
+
Examples:
|
|
1874
|
+
# TEXT MODE - Single model (returns DataFrame directly):
|
|
1875
|
+
df = multi_class_ensemble(
|
|
1876
|
+
input_data=["I moved for a new job"],
|
|
1877
|
+
categories=["Employment", "Family", "Housing"],
|
|
1878
|
+
model="gpt-4o",
|
|
1879
|
+
api_key="sk-...",
|
|
1880
|
+
survey_question="Why did you move?",
|
|
1881
|
+
)
|
|
1882
|
+
|
|
1883
|
+
# TEXT MODE - Ensemble with multiple models (returns dict):
|
|
1884
|
+
results = multi_class_ensemble(
|
|
1885
|
+
input_data=["I moved for a new job"],
|
|
1886
|
+
categories=["Employment", "Family", "Housing"],
|
|
1887
|
+
models=[
|
|
1888
|
+
("gpt-4o", "openai", "sk-..."),
|
|
1889
|
+
("claude-sonnet-4-5-20250929", "anthropic", "sk-ant-..."),
|
|
1890
|
+
],
|
|
1891
|
+
survey_question="Why did you move?",
|
|
1892
|
+
)
|
|
1893
|
+
combined_df = results["combined"]
|
|
1894
|
+
|
|
1895
|
+
# PDF MODE - Single model (auto-detected from .pdf paths):
|
|
1896
|
+
df = multi_class_ensemble(
|
|
1897
|
+
input_data="reports/", # Directory of PDFs
|
|
1898
|
+
categories=["Has Chart", "Has Table", "Financial Summary"],
|
|
1899
|
+
model="gpt-4o",
|
|
1900
|
+
api_key="sk-...",
|
|
1901
|
+
pdf_mode="image",
|
|
1902
|
+
input_description="Financial reports with charts and tables",
|
|
1903
|
+
)
|
|
1904
|
+
|
|
1905
|
+
# PDF MODE - Ensemble with native PDF support:
|
|
1906
|
+
results = multi_class_ensemble(
|
|
1907
|
+
input_data=["doc1.pdf", "doc2.pdf"],
|
|
1908
|
+
categories=["Diagnosis", "Treatment Plan"],
|
|
1909
|
+
models=[
|
|
1910
|
+
("gpt-4o", "openai", "sk-..."),
|
|
1911
|
+
("claude-sonnet-4-5-20250929", "anthropic", "sk-ant-..."), # Native PDF
|
|
1912
|
+
],
|
|
1913
|
+
pdf_mode="both",
|
|
1914
|
+
consensus_threshold=0.5,
|
|
1915
|
+
)
|
|
1916
|
+
"""
|
|
1917
|
+
# Normalize model input to list of tuples
|
|
1918
|
+
models = normalize_model_input(model, api_key, provider, models)
|
|
1919
|
+
|
|
1920
|
+
# Validate safety parameter
|
|
1921
|
+
if safety and filename is None:
|
|
1922
|
+
raise TypeError(
|
|
1923
|
+
"filename is required when using safety=True. "
|
|
1924
|
+
"Please provide a filename to save incremental results to."
|
|
1925
|
+
)
|
|
1926
|
+
|
|
1927
|
+
# Handle categories="auto" - auto-detect categories from the data
|
|
1928
|
+
if categories == "auto":
|
|
1929
|
+
from .main import extract
|
|
1930
|
+
|
|
1931
|
+
# Detect input type to choose the right extraction path
|
|
1932
|
+
detected_type = _detect_input_type(input_data)
|
|
1933
|
+
|
|
1934
|
+
if detected_type == "text" and survey_question == "":
|
|
1935
|
+
raise TypeError(
|
|
1936
|
+
"survey_question is required when using categories='auto' with text input. "
|
|
1937
|
+
"Please provide the survey question you are analyzing."
|
|
1938
|
+
)
|
|
1939
|
+
|
|
1940
|
+
# Use first model for category discovery
|
|
1941
|
+
first_entry = models[0]
|
|
1942
|
+
first_model, first_provider, first_api_key = first_entry[0], first_entry[1], first_entry[2]
|
|
1943
|
+
detected_provider = detect_provider(first_model, first_provider)
|
|
1944
|
+
|
|
1945
|
+
print(f"Auto-detecting categories using {first_model} (input type: {detected_type})...")
|
|
1946
|
+
auto_result = extract(
|
|
1947
|
+
input_data=input_data,
|
|
1948
|
+
api_key=first_api_key,
|
|
1949
|
+
input_type=detected_type,
|
|
1950
|
+
description=survey_question or input_description,
|
|
1951
|
+
max_categories=max_categories,
|
|
1952
|
+
categories_per_chunk=categories_per_chunk,
|
|
1953
|
+
divisions=divisions,
|
|
1954
|
+
user_model=first_model,
|
|
1955
|
+
model_source=detected_provider,
|
|
1956
|
+
research_question=research_question,
|
|
1957
|
+
mode=pdf_mode,
|
|
1958
|
+
)
|
|
1959
|
+
categories = auto_result["top_categories"]
|
|
1960
|
+
print(f"Discovered {len(categories)} categories: {categories}")
|
|
1961
|
+
|
|
1962
|
+
if not isinstance(categories, list) or len(categories) == 0:
|
|
1963
|
+
raise ValueError("categories must be a non-empty list of category names, or 'auto'")
|
|
1964
|
+
|
|
1965
|
+
# Prepare model configurations
|
|
1966
|
+
print(f"Validating {len(models)} model configuration(s)...")
|
|
1967
|
+
model_configs = prepare_model_configs(models, auto_download=auto_download)
|
|
1968
|
+
|
|
1969
|
+
# Print model info
|
|
1970
|
+
print(f"\nModels to use:")
|
|
1971
|
+
for cfg in model_configs:
|
|
1972
|
+
print(f" - {cfg['model']} ({cfg['provider']}) -> column suffix: {cfg['sanitized_name']}")
|
|
1973
|
+
|
|
1974
|
+
# =============================================================================
|
|
1975
|
+
# DETECT INPUT TYPE: Text vs PDF vs Image
|
|
1976
|
+
# =============================================================================
|
|
1977
|
+
input_type = _detect_input_type(input_data)
|
|
1978
|
+
print(f"\nInput type detected: {input_type.upper()}")
|
|
1979
|
+
|
|
1980
|
+
# Initialize processing variables
|
|
1981
|
+
items_to_process = []
|
|
1982
|
+
is_pdf_mode = (input_type == 'pdf')
|
|
1983
|
+
is_image_mode = (input_type == 'image')
|
|
1984
|
+
|
|
1985
|
+
# Build example JSON for visual modes (PDF/image)
|
|
1986
|
+
category_dict = {str(i+1): "0" for i in range(len(categories))}
|
|
1987
|
+
example_json = json.dumps(category_dict, indent=2)
|
|
1988
|
+
|
|
1989
|
+
if is_image_mode:
|
|
1990
|
+
# =================================================================
|
|
1991
|
+
# IMAGE MODE: Load images
|
|
1992
|
+
# =================================================================
|
|
1993
|
+
print(f"Loading images...")
|
|
1994
|
+
|
|
1995
|
+
# Load image files
|
|
1996
|
+
image_files = _load_image_files(input_data)
|
|
1997
|
+
|
|
1998
|
+
if not image_files:
|
|
1999
|
+
raise ValueError("No images found in the provided input.")
|
|
2000
|
+
|
|
2001
|
+
print(f"Total images to process: {len(image_files)}")
|
|
2002
|
+
|
|
2003
|
+
# items_to_process is list of (image_path, image_label) tuples
|
|
2004
|
+
items_to_process = [
|
|
2005
|
+
(img_path, os.path.splitext(os.path.basename(img_path))[0])
|
|
2006
|
+
for img_path in image_files
|
|
2007
|
+
]
|
|
2008
|
+
|
|
2009
|
+
elif is_pdf_mode:
|
|
2010
|
+
# =================================================================
|
|
2011
|
+
# PDF MODE: Load PDFs and extract all pages
|
|
2012
|
+
# =================================================================
|
|
2013
|
+
# Validate pdf_mode parameter
|
|
2014
|
+
pdf_mode = pdf_mode.lower()
|
|
2015
|
+
if pdf_mode not in {"image", "text", "both"}:
|
|
2016
|
+
raise ValueError(f"pdf_mode must be 'image', 'text', or 'both', got: {pdf_mode}")
|
|
2017
|
+
|
|
2018
|
+
print(f"PDF processing mode: {pdf_mode}")
|
|
2019
|
+
|
|
2020
|
+
# Load PDF files
|
|
2021
|
+
pdf_files = _load_pdf_files(input_data)
|
|
2022
|
+
|
|
2023
|
+
# Extract all pages from all PDFs
|
|
2024
|
+
all_pages = []
|
|
2025
|
+
for pdf_path in pdf_files:
|
|
2026
|
+
pages = _get_pdf_pages(pdf_path)
|
|
2027
|
+
all_pages.extend(pages)
|
|
2028
|
+
|
|
2029
|
+
if not all_pages:
|
|
2030
|
+
raise ValueError("No pages found in the provided PDF files.")
|
|
2031
|
+
|
|
2032
|
+
print(f"Total pages to process: {len(all_pages)}")
|
|
2033
|
+
|
|
2034
|
+
# items_to_process is list of (pdf_path, page_index, page_label) tuples
|
|
2035
|
+
items_to_process = all_pages
|
|
2036
|
+
|
|
2037
|
+
else:
|
|
2038
|
+
# =================================================================
|
|
2039
|
+
# TEXT MODE: input_data is the items to process
|
|
2040
|
+
# =================================================================
|
|
2041
|
+
items_to_process = input_data
|
|
2042
|
+
|
|
2043
|
+
# Auto-resolve parallel mode: sequential for all-local (Ollama), parallel otherwise
|
|
2044
|
+
if parallel is None:
|
|
2045
|
+
all_local = all(cfg["provider"] == "ollama" for cfg in model_configs)
|
|
2046
|
+
parallel = not all_local
|
|
2047
|
+
|
|
2048
|
+
# Set max workers
|
|
2049
|
+
effective_workers = max_workers or min(len(models), 8)
|
|
2050
|
+
if parallel:
|
|
2051
|
+
print(f"\nParallel workers: {effective_workers}")
|
|
2052
|
+
else:
|
|
2053
|
+
print(f"\nSequential mode (models run one at a time per row)")
|
|
2054
|
+
|
|
2055
|
+
# Warn about CoVe cost with ensemble
|
|
2056
|
+
if chain_of_verification:
|
|
2057
|
+
print("\n[Chain of Verification enabled]")
|
|
2058
|
+
print(" - ~4x API calls per response per model")
|
|
2059
|
+
if len(models) > 1:
|
|
2060
|
+
print(" - WARNING: CoVe with ensemble is expensive. Consider single-model mode.")
|
|
2061
|
+
|
|
2062
|
+
# Build shared prompt components
|
|
2063
|
+
categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
|
|
2064
|
+
|
|
2065
|
+
examples = [example1, example2, example3, example4, example5, example6]
|
|
2066
|
+
examples_text = "\n".join(
|
|
2067
|
+
f"Example {i}: {ex}" for i, ex in enumerate(examples, 1) if ex is not None
|
|
2068
|
+
)
|
|
2069
|
+
|
|
2070
|
+
survey_question_context = f"Context: {survey_question}." if survey_question else ""
|
|
2071
|
+
|
|
2072
|
+
# Print categories
|
|
2073
|
+
print(f"\nCategories to classify ({len(categories)} total):")
|
|
2074
|
+
for i, cat in enumerate(categories, 1):
|
|
2075
|
+
print(f" {i}. {cat}")
|
|
2076
|
+
print()
|
|
2077
|
+
|
|
2078
|
+
# Get step-back insights per model (if enabled)
|
|
2079
|
+
stepback_insights = {}
|
|
2080
|
+
if step_back_prompt:
|
|
2081
|
+
stepback_insights = gather_stepback_insights(model_configs, survey_question, creativity)
|
|
2082
|
+
|
|
2083
|
+
# Build JSON schemas per provider
|
|
2084
|
+
json_schemas = prepare_json_schemas(model_configs, categories, use_json_schema)
|
|
2085
|
+
|
|
2086
|
+
# Build original task prompt for CoVe (if enabled)
|
|
2087
|
+
cove_original_task = ""
|
|
2088
|
+
if chain_of_verification:
|
|
2089
|
+
if multi_label:
|
|
2090
|
+
cove_categorize = "into the following categories"
|
|
2091
|
+
cove_json = 'Provide your answer in JSON format where the category number is the key and "1" if present, "0" if not.'
|
|
2092
|
+
else:
|
|
2093
|
+
cove_categorize = "into the single most appropriate category"
|
|
2094
|
+
cove_json = 'Provide your answer in JSON format where the category number is the key. Assign "1" to the single best matching category and "0" to all others.'
|
|
2095
|
+
cove_original_task = f"""{survey_question_context}
|
|
2096
|
+
Categorize text responses {cove_categorize}:
|
|
2097
|
+
{categories_str}
|
|
2098
|
+
{cove_json}"""
|
|
2099
|
+
|
|
2100
|
+
# Formatter fallback helper (only active when formatter_state is provided)
|
|
2101
|
+
def _try_formatter_fallback(json_result, raw_reply, chunk_categories=None):
|
|
2102
|
+
"""Try the JSON formatter if extract_json produced invalid output.
|
|
2103
|
+
|
|
2104
|
+
Args:
|
|
2105
|
+
chunk_categories: When called from chunked classification, the
|
|
2106
|
+
actual chunk category list (not the full list). Needed so the
|
|
2107
|
+
formatter sees the correct category names and count.
|
|
2108
|
+
"""
|
|
2109
|
+
if not formatter_state:
|
|
2110
|
+
return json_result
|
|
2111
|
+
cats = chunk_categories if chunk_categories is not None else categories
|
|
2112
|
+
n = len(cats)
|
|
2113
|
+
is_valid, _ = validate_classification_json(json_result, n)
|
|
2114
|
+
if is_valid:
|
|
2115
|
+
return json_result
|
|
2116
|
+
from ._formatter import run_formatter
|
|
2117
|
+
fixed_output = run_formatter(
|
|
2118
|
+
raw_reply, cats,
|
|
2119
|
+
formatter_state["model"],
|
|
2120
|
+
formatter_state["tokenizer"],
|
|
2121
|
+
formatter_state["device"],
|
|
2122
|
+
)
|
|
2123
|
+
fixed_json = extract_json(fixed_output)
|
|
2124
|
+
fixed_valid, _ = validate_classification_json(fixed_json, n)
|
|
2125
|
+
if fixed_valid:
|
|
2126
|
+
return fixed_json
|
|
2127
|
+
return json_result
|
|
2128
|
+
|
|
2129
|
+
# When chunking is active, extend categories with unified "Other" so
|
|
2130
|
+
# aggregate_results and build_output_dataframes create the column.
|
|
2131
|
+
# Save original list for the actual chunked LLM calls (which add their
|
|
2132
|
+
# own per-chunk "Other" internally).
|
|
2133
|
+
_original_categories = categories
|
|
2134
|
+
_add_unified_other = False
|
|
2135
|
+
if categories_per_call is not None:
|
|
2136
|
+
from ._category_analysis import has_other_category as _has_other
|
|
2137
|
+
if not _has_other(categories):
|
|
2138
|
+
categories = list(categories) + ["Other"]
|
|
2139
|
+
_add_unified_other = True
|
|
2140
|
+
|
|
2141
|
+
# Classification function for single model + single item (text, PDF page, or image)
|
|
2142
|
+
def classify_single(cfg: dict, item) -> tuple:
    """
    Classify a single item with a single configured model.

    Recognised item shapes:
      - ``(image_path, image_label)`` when image mode is on,
      - ``(pdf_path, page_index, page_label)`` when PDF mode is on,
      - anything else is treated as a text response.

    When ``categories_per_call`` is set, the whole job is delegated to
    ``run_chunked_classification`` with the original category list (before
    the unified "Other" was appended). Otherwise one prompt is built for the
    item's mode and sent to the provider. In text mode (non two-step) Chain
    of Verification may also run.

    Args:
        cfg: Model configuration dict.
        item: Text response, PDF-page tuple, or image tuple.

    Returns:
        tuple: (model_name, json_result, error), where model_name is
        ``cfg["sanitized_name"]``. On failure json_result is the sentinel
        '{"1":"e"}' and error carries the message.
    """
    # A per-model creativity override takes precedence over the global one.
    effective_creativity = (
        creativity if cfg["creativity"] is None else cfg["creativity"]
    )

    # Identifier for the retry test hook: image/page label, or the text.
    if is_image_mode and isinstance(item, tuple) and len(item) == 2:
        item_identifier = item[1]
    elif is_pdf_mode and isinstance(item, tuple) and len(item) == 3:
        item_identifier = item[2]
    else:
        item_identifier = str(item) if item else ""

    model_tag = cfg["sanitized_name"]
    failed_json = '{"1":"e"}'

    # Test hook for debugging batch retries (only fires when _TEST_FORCE_FAILURE is on).
    if _test_should_force_failure(item_identifier, model_tag):
        return (model_tag, failed_json, "TEST: Forced first-attempt failure")

    thinking_providers = (
        "google", "openai", "anthropic", "huggingface", "huggingface-together",
    )

    def make_client():
        # A fresh client for every call.
        return UnifiedLLMClient(
            provider=cfg["provider"],
            api_key=cfg["api_key"],
            model=cfg["model"],
        )

    def complete(client, messages):
        # Plain chat completion; the thinking budget is only forwarded to
        # providers listed in thinking_providers.
        return client.complete(
            messages=messages,
            json_schema=json_schemas[cfg["model"]],
            creativity=effective_creativity,
            thinking_budget=(
                thinking_budget if cfg["provider"] in thinking_providers else None
            ),
            max_retries=max_retries,
        )

    def complete_multimodal(client, messages):
        # PDF/image prompts for Google go through the dedicated REST helper.
        if cfg["provider"] != "google":
            return complete(client, messages)
        return _call_google_multimodal(
            client=client,
            messages=messages,
            json_schema=json_schemas[cfg["model"]],
            creativity=effective_creativity,
            thinking_budget=thinking_budget,
            max_retries=max_retries,
        )

    def parse_reply(reply, error):
        # Error sentinel on failure; otherwise extract JSON, then let the
        # formatter fallback repair it if needed.
        if error:
            return failed_json
        return _try_formatter_fallback(extract_json(reply), reply)

    # -----------------------------------------------------------------
    # Chunked path: categories are split into smaller per-call groups.
    # -----------------------------------------------------------------
    if categories_per_call is not None:
        from ._chunked import run_chunked_classification

        try:
            client = make_client()
            json_result, error = run_chunked_classification(
                client=client,
                cfg=cfg,
                item=item,
                categories=_original_categories,
                categories_str=categories_str,
                example_json=example_json,
                json_schema=json_schemas[cfg["model"]],
                cove_original_task=cove_original_task,
                effective_creativity=effective_creativity,
                use_json_schema=use_json_schema,
                survey_question=survey_question,
                survey_question_context=survey_question_context,
                examples_text=examples_text,
                chain_of_thought=chain_of_thought,
                context_prompt=context_prompt,
                step_back_prompt=step_back_prompt,
                stepback_insights=stepback_insights,
                chain_of_verification=chain_of_verification,
                thinking_budget=thinking_budget,
                max_retries=max_retries,
                multi_label=multi_label,
                categories_per_call=categories_per_call,
                add_unified_other=_add_unified_other,
                formatter_fallback_fn=_try_formatter_fallback,
                is_pdf_mode=is_pdf_mode,
                is_image_mode=is_image_mode,
                pdf_mode=pdf_mode,
                pdf_dpi=pdf_dpi,
                input_description=input_description,
                build_text_prompt_fn=build_text_classification_prompt,
                build_pdf_prompt_fn=build_pdf_classification_prompt,
                build_image_prompt_fn=build_image_classification_prompt,
                google_multimodal_fn=_call_google_multimodal,
                prepare_page_data_fn=_prepare_page_data,
                prepare_image_data_fn=_prepare_image_data,
                build_cove_prompts_fn=build_cove_prompts,
                run_cove_fn=run_chain_of_verification,
            )
            return (model_tag, json_result, error)
        except Exception as e:
            return (model_tag, failed_json, str(e))

    try:
        client = make_client()

        if is_pdf_mode and isinstance(item, tuple):
            # ---- PDF page -------------------------------------------------
            pdf_path, page_index, page_label = item
            page_data = _prepare_page_data(
                pdf_path=pdf_path,
                page_index=page_index,
                page_label=page_label,
                pdf_mode=pdf_mode,
                provider=cfg["provider"],
                pdf_dpi=pdf_dpi,
            )
            if page_data.get("error"):
                return (model_tag, failed_json, page_data["error"])

            messages = build_pdf_classification_prompt(
                page_data=page_data,
                categories_str=categories_str,
                input_description=input_description,
                provider=cfg["provider"],
                pdf_mode=pdf_mode,
                chain_of_thought=chain_of_thought,
                context_prompt=context_prompt,
                step_back_prompt=step_back_prompt,
                stepback_insights=stepback_insights,
                model_name=cfg["model"],
                example_json=example_json,
                multi_label=multi_label,
            )
            reply, error = complete_multimodal(client, messages)
            json_result = parse_reply(reply, error)
            # CoVe is not implemented for PDF pages (the page would have to be
            # re-attached to every verification prompt).

        elif is_image_mode and isinstance(item, tuple):
            # ---- Image ----------------------------------------------------
            image_path, image_label = item
            image_data = _prepare_image_data(image_path, image_label)
            if image_data.get("error"):
                return (model_tag, failed_json, image_data["error"])

            messages = build_image_classification_prompt(
                image_data=image_data,
                categories_str=categories_str,
                input_description=input_description,
                provider=cfg["provider"],
                chain_of_thought=chain_of_thought,
                context_prompt=context_prompt,
                step_back_prompt=step_back_prompt,
                stepback_insights=stepback_insights,
                model_name=cfg["model"],
                example_json=example_json,
                multi_label=multi_label,
            )
            reply, error = complete_multimodal(client, messages)
            json_result = parse_reply(reply, error)
            # CoVe is not implemented for images.

        else:
            # ---- Text -----------------------------------------------------
            response_text = item

            if cfg["use_two_step"]:
                # Ollama two-step flow; CoVe is not applied here.
                json_result, error = ollama_two_step_classify(
                    client=client,
                    response_text=response_text,
                    categories=categories,
                    categories_str=categories_str,
                    survey_question=survey_question,
                    creativity=effective_creativity,
                    max_retries=max_retries,
                )
                if not error:
                    json_result = _try_formatter_fallback(json_result, json_result)
            else:
                messages = build_text_classification_prompt(
                    response_text=response_text,
                    categories_str=categories_str,
                    survey_question_context=survey_question_context,
                    examples_text=examples_text,
                    chain_of_thought=chain_of_thought,
                    context_prompt=context_prompt,
                    step_back_prompt=step_back_prompt,
                    stepback_insights=stepback_insights,
                    model_name=cfg["model"],
                    multi_label=multi_label,
                )
                reply, error = complete(client, messages)
                json_result = parse_reply(reply, error)

                if chain_of_verification and not error:
                    step2, step3, step4 = build_cove_prompts(
                        cove_original_task, response_text
                    )
                    json_result = run_chain_of_verification(
                        client=client,
                        initial_reply=json_result,
                        step2_prompt=step2,
                        step3_prompt=step3,
                        step4_prompt=step4,
                        json_schema=json_schemas[cfg["model"]],
                        creativity=effective_creativity,
                        max_retries=max_retries,
                    )
                    json_result = _try_formatter_fallback(json_result, json_result)

        return (model_tag, json_result, error)

    except Exception as e:
        return (model_tag, failed_json, str(e))
|
|
2424
|
+
|
|
2425
|
+
# Helper function for Google multimodal API calls
|
|
2426
|
+
def _call_google_multimodal(client, messages, json_schema, creativity, thinking_budget, max_retries):
|
|
2427
|
+
"""
|
|
2428
|
+
Handle Google's multimodal API format for PDF/image content.
|
|
2429
|
+
"""
|
|
2430
|
+
import requests
|
|
2431
|
+
|
|
2432
|
+
# Extract the content from messages
|
|
2433
|
+
user_msg = messages[-1] # Last message should be the user message
|
|
2434
|
+
content = user_msg.get("content", [])
|
|
2435
|
+
|
|
2436
|
+
# Build Google-format parts
|
|
2437
|
+
parts = []
|
|
2438
|
+
for part in content:
|
|
2439
|
+
if part.get("type") == "text":
|
|
2440
|
+
parts.append({"text": part["text"]})
|
|
2441
|
+
elif part.get("type") == "inline_data":
|
|
2442
|
+
parts.append({
|
|
2443
|
+
"inline_data": {
|
|
2444
|
+
"mime_type": part["mime_type"],
|
|
2445
|
+
"data": part["data"]
|
|
2446
|
+
}
|
|
2447
|
+
})
|
|
2448
|
+
elif part.get("type") == "image_url":
|
|
2449
|
+
# Convert image URL to inline_data format
|
|
2450
|
+
url = part["image_url"]["url"]
|
|
2451
|
+
if url.startswith("data:image/png;base64,"):
|
|
2452
|
+
data = url.replace("data:image/png;base64,", "")
|
|
2453
|
+
parts.append({
|
|
2454
|
+
"inline_data": {
|
|
2455
|
+
"mime_type": "image/png",
|
|
2456
|
+
"data": data
|
|
2457
|
+
}
|
|
2458
|
+
})
|
|
2459
|
+
|
|
2460
|
+
# Get model name from client
|
|
2461
|
+
model_name = client.model
|
|
2462
|
+
|
|
2463
|
+
url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent"
|
|
2464
|
+
headers = {
|
|
2465
|
+
"x-goog-api-key": client.api_key,
|
|
2466
|
+
"Content-Type": "application/json"
|
|
2467
|
+
}
|
|
2468
|
+
|
|
2469
|
+
payload = {
|
|
2470
|
+
"contents": [{"parts": parts}],
|
|
2471
|
+
"generationConfig": {
|
|
2472
|
+
"responseMimeType": "application/json",
|
|
2473
|
+
**({"temperature": creativity} if creativity is not None else {}),
|
|
2474
|
+
**({"thinkingConfig": {"thinkingBudget": thinking_budget}} if thinking_budget else {})
|
|
2475
|
+
}
|
|
2476
|
+
}
|
|
2477
|
+
|
|
2478
|
+
for attempt in range(max_retries):
|
|
2479
|
+
try:
|
|
2480
|
+
response = requests.post(url, headers=headers, json=payload, timeout=120)
|
|
2481
|
+
response.raise_for_status()
|
|
2482
|
+
result = response.json()
|
|
2483
|
+
|
|
2484
|
+
if "candidates" in result and result["candidates"]:
|
|
2485
|
+
reply = result["candidates"][0]["content"]["parts"][0]["text"]
|
|
2486
|
+
return reply, None
|
|
2487
|
+
else:
|
|
2488
|
+
return None, "No response generated"
|
|
2489
|
+
|
|
2490
|
+
except requests.exceptions.HTTPError as e:
|
|
2491
|
+
status_code = e.response.status_code
|
|
2492
|
+
retryable_errors = [429, 500, 502, 503, 504]
|
|
2493
|
+
|
|
2494
|
+
if status_code in retryable_errors and attempt < max_retries - 1:
|
|
2495
|
+
import time
|
|
2496
|
+
wait_time = 10 * (2 ** attempt) if status_code == 429 else 2 * (2 ** attempt)
|
|
2497
|
+
time.sleep(wait_time)
|
|
2498
|
+
else:
|
|
2499
|
+
return None, f"HTTP error {status_code}: {str(e)}"
|
|
2500
|
+
|
|
2501
|
+
except Exception as e:
|
|
2502
|
+
if attempt < max_retries - 1:
|
|
2503
|
+
import time
|
|
2504
|
+
time.sleep(2 * (2 ** attempt))
|
|
2505
|
+
else:
|
|
2506
|
+
return None, str(e)
|
|
2507
|
+
|
|
2508
|
+
return None, "Max retries exceeded"
|
|
2509
|
+
|
|
2510
|
+
# Process all items (text responses, PDF pages, or images)
|
|
2511
|
+
all_results = []
|
|
2512
|
+
completed_calls = [0] # Mutable for closure
|
|
2513
|
+
total_calls = len(items_to_process) * len(model_configs)
|
|
2514
|
+
|
|
2515
|
+
# Set progress description based on mode
|
|
2516
|
+
if is_image_mode:
|
|
2517
|
+
progress_desc = "Classifying images"
|
|
2518
|
+
elif is_pdf_mode:
|
|
2519
|
+
progress_desc = "Classifying PDF pages"
|
|
2520
|
+
else:
|
|
2521
|
+
progress_desc = "Classifying responses"
|
|
2522
|
+
|
|
2523
|
+
# Disable tqdm when progress_callback is provided (e.g., for Streamlit/GUI apps)
|
|
2524
|
+
use_tqdm = progress_callback is None
|
|
2525
|
+
for idx, item in enumerate(tqdm(items_to_process, desc=progress_desc, disable=not use_tqdm)):
|
|
2526
|
+
skipped_nan = False
|
|
2527
|
+
|
|
2528
|
+
# Determine the display identifier and metadata for this item
|
|
2529
|
+
if is_image_mode and isinstance(item, tuple) and len(item) == 2:
|
|
2530
|
+
image_path, image_label = item
|
|
2531
|
+
display_id = image_label
|
|
2532
|
+
pdf_metadata = None
|
|
2533
|
+
image_metadata = {"image_path": image_path}
|
|
2534
|
+
elif is_pdf_mode and isinstance(item, tuple) and len(item) == 3:
|
|
2535
|
+
pdf_path, page_index, page_label = item
|
|
2536
|
+
display_id = page_label
|
|
2537
|
+
pdf_metadata = {"pdf_path": pdf_path, "page_index": page_index}
|
|
2538
|
+
image_metadata = None
|
|
2539
|
+
else:
|
|
2540
|
+
display_id = item
|
|
2541
|
+
pdf_metadata = None
|
|
2542
|
+
image_metadata = None
|
|
2543
|
+
|
|
2544
|
+
# Check for NaN (text mode only)
|
|
2545
|
+
if not is_pdf_mode and not is_image_mode and pd.isna(item):
|
|
2546
|
+
# Handle NaN - mark as skipped, bypass classification entirely
|
|
2547
|
+
skipped_nan = True
|
|
2548
|
+
model_results = {}
|
|
2549
|
+
# Build a clean aggregated result so downstream code doesn't
|
|
2550
|
+
# list every model as failed for a legitimately skipped row.
|
|
2551
|
+
skipped_aggregated = {
|
|
2552
|
+
"per_model": {},
|
|
2553
|
+
"consensus": {},
|
|
2554
|
+
"agreement": {},
|
|
2555
|
+
"failed_models": [],
|
|
2556
|
+
"missing_keys": {},
|
|
2557
|
+
"error": None,
|
|
2558
|
+
}
|
|
2559
|
+
else:
|
|
2560
|
+
# Classification across models
|
|
2561
|
+
model_results = {}
|
|
2562
|
+
if parallel:
|
|
2563
|
+
with ThreadPoolExecutor(max_workers=effective_workers) as executor:
|
|
2564
|
+
futures = {
|
|
2565
|
+
executor.submit(classify_single, cfg, item): cfg["sanitized_name"]
|
|
2566
|
+
for cfg in model_configs
|
|
2567
|
+
}
|
|
2568
|
+
|
|
2569
|
+
for future in as_completed(futures):
|
|
2570
|
+
model_name, json_result, error = future.result()
|
|
2571
|
+
model_results[model_name] = (json_result, error)
|
|
2572
|
+
|
|
2573
|
+
# Update progress (for multi-model detailed callbacks only)
|
|
2574
|
+
completed_calls[0] += 1
|
|
2575
|
+
else:
|
|
2576
|
+
for cfg in model_configs:
|
|
2577
|
+
model_name, json_result, error = classify_single(cfg, item)
|
|
2578
|
+
model_results[model_name] = (json_result, error)
|
|
2579
|
+
completed_calls[0] += 1
|
|
2580
|
+
|
|
2581
|
+
# Aggregate results with majority voting (skip for NaN rows)
|
|
2582
|
+
if skipped_nan:
|
|
2583
|
+
aggregated = skipped_aggregated
|
|
2584
|
+
else:
|
|
2585
|
+
aggregated = aggregate_results(
|
|
2586
|
+
model_results,
|
|
2587
|
+
categories,
|
|
2588
|
+
consensus_threshold,
|
|
2589
|
+
fail_strategy
|
|
2590
|
+
)
|
|
2591
|
+
|
|
2592
|
+
# Build result entry
|
|
2593
|
+
result_entry = {
|
|
2594
|
+
"response": display_id, # Page label for PDF, text for text mode
|
|
2595
|
+
"model_results": model_results,
|
|
2596
|
+
"aggregated": aggregated,
|
|
2597
|
+
"skipped": skipped_nan,
|
|
2598
|
+
}
|
|
2599
|
+
|
|
2600
|
+
# Add PDF metadata if in PDF mode
|
|
2601
|
+
if pdf_metadata:
|
|
2602
|
+
result_entry["pdf_path"] = pdf_metadata["pdf_path"]
|
|
2603
|
+
result_entry["page_index"] = pdf_metadata["page_index"]
|
|
2604
|
+
|
|
2605
|
+
# Add image metadata if in image mode
|
|
2606
|
+
if image_metadata:
|
|
2607
|
+
result_entry["image_path"] = image_metadata["image_path"]
|
|
2608
|
+
|
|
2609
|
+
# Store the original item for retry logic
|
|
2610
|
+
result_entry["_original_item"] = item
|
|
2611
|
+
|
|
2612
|
+
all_results.append(result_entry)
|
|
2613
|
+
|
|
2614
|
+
# Call progress callback with simple item-level signature
|
|
2615
|
+
# This is for UI integrations (like Streamlit) that want per-item updates
|
|
2616
|
+
if progress_callback:
|
|
2617
|
+
try:
|
|
2618
|
+
# Try simple signature: (current_idx, total, item_label)
|
|
2619
|
+
progress_callback(idx, len(items_to_process), display_id)
|
|
2620
|
+
except TypeError:
|
|
2621
|
+
# Fallback: callback might expect keyword args (old signature)
|
|
2622
|
+
pass
|
|
2623
|
+
|
|
2624
|
+
# Safety incremental save
|
|
2625
|
+
if safety:
|
|
2626
|
+
_save_partial_results(
|
|
2627
|
+
all_results,
|
|
2628
|
+
model_configs,
|
|
2629
|
+
categories,
|
|
2630
|
+
filename,
|
|
2631
|
+
save_directory,
|
|
2632
|
+
)
|
|
2633
|
+
|
|
2634
|
+
# Per-row delay to avoid rate limits when models share a provider
|
|
2635
|
+
if row_delay > 0 and idx < len(items_to_process) - 1:
|
|
2636
|
+
time.sleep(row_delay)
|
|
2637
|
+
|
|
2638
|
+
# Retry logic for failed (row, model) pairs
|
|
2639
|
+
if batch_retries > 0:
|
|
2640
|
+
num_cats = len(categories)
|
|
2641
|
+
expected_keys = {str(i) for i in range(1, num_cats + 1)}
|
|
2642
|
+
|
|
2643
|
+
for retry_num in range(1, batch_retries + 1):
|
|
2644
|
+
# Find all failed (row_idx, model_config) pairs
|
|
2645
|
+
failed_pairs = []
|
|
2646
|
+
for row_idx, result in enumerate(all_results):
|
|
2647
|
+
# Skip rows that were NaN inputs
|
|
2648
|
+
if result.get("skipped"):
|
|
2649
|
+
continue
|
|
2650
|
+
# Check each model for this row
|
|
2651
|
+
for cfg in model_configs:
|
|
2652
|
+
model_name = cfg["sanitized_name"]
|
|
2653
|
+
json_str, error = result["model_results"].get(model_name, (None, "Missing"))
|
|
2654
|
+
if error is not None:
|
|
2655
|
+
failed_pairs.append((row_idx, cfg))
|
|
2656
|
+
else:
|
|
2657
|
+
# Check JSON parsing AND schema validation
|
|
2658
|
+
try:
|
|
2659
|
+
parsed = json.loads(json_str)
|
|
2660
|
+
# At least one valid numbered key with 0/1 value
|
|
2661
|
+
valid_count = sum(
|
|
2662
|
+
1 for k, v in parsed.items()
|
|
2663
|
+
if k in expected_keys and str(v).strip() in ("0", "1")
|
|
2664
|
+
)
|
|
2665
|
+
if valid_count == 0:
|
|
2666
|
+
failed_pairs.append((row_idx, cfg))
|
|
2667
|
+
except (json.JSONDecodeError, TypeError):
|
|
2668
|
+
failed_pairs.append((row_idx, cfg))
|
|
2669
|
+
|
|
2670
|
+
if not failed_pairs:
|
|
2671
|
+
break # All successful, no retries needed
|
|
2672
|
+
|
|
2673
|
+
print(f"\n[Batch retry {retry_num}/{batch_retries}] Retrying {len(failed_pairs)} failed (row, model) pairs...")
|
|
2674
|
+
|
|
2675
|
+
# Wait before retrying
|
|
2676
|
+
if retry_delay > 0:
|
|
2677
|
+
time.sleep(retry_delay)
|
|
2678
|
+
|
|
2679
|
+
# Retry failed pairs
|
|
2680
|
+
successes_this_round = 0
|
|
2681
|
+
if parallel:
|
|
2682
|
+
with ThreadPoolExecutor(max_workers=effective_workers) as executor:
|
|
2683
|
+
futures = {
|
|
2684
|
+
executor.submit(classify_single, cfg, all_results[row_idx]["_original_item"]): (row_idx, cfg)
|
|
2685
|
+
for row_idx, cfg in failed_pairs
|
|
2686
|
+
}
|
|
2687
|
+
|
|
2688
|
+
for future in as_completed(futures):
|
|
2689
|
+
row_idx, cfg = futures[future]
|
|
2690
|
+
model_name, json_result, error = future.result()
|
|
2691
|
+
|
|
2692
|
+
# Update the result in place
|
|
2693
|
+
all_results[row_idx]["model_results"][model_name] = (json_result, error)
|
|
2694
|
+
|
|
2695
|
+
if error is None:
|
|
2696
|
+
# Verify JSON is valid and has correct schema
|
|
2697
|
+
try:
|
|
2698
|
+
parsed = json.loads(json_result)
|
|
2699
|
+
valid_count = sum(
|
|
2700
|
+
1 for k, v in parsed.items()
|
|
2701
|
+
if k in expected_keys and str(v).strip() in ("0", "1")
|
|
2702
|
+
)
|
|
2703
|
+
if valid_count > 0:
|
|
2704
|
+
successes_this_round += 1
|
|
2705
|
+
except (json.JSONDecodeError, TypeError):
|
|
2706
|
+
pass
|
|
2707
|
+
else:
|
|
2708
|
+
for row_idx, cfg in failed_pairs:
|
|
2709
|
+
model_name, json_result, error = classify_single(cfg, all_results[row_idx]["_original_item"])
|
|
2710
|
+
all_results[row_idx]["model_results"][model_name] = (json_result, error)
|
|
2711
|
+
if error is None:
|
|
2712
|
+
try:
|
|
2713
|
+
parsed = json.loads(json_result)
|
|
2714
|
+
valid_count = sum(
|
|
2715
|
+
1 for k, v in parsed.items()
|
|
2716
|
+
if k in expected_keys and str(v).strip() in ("0", "1")
|
|
2717
|
+
)
|
|
2718
|
+
if valid_count > 0:
|
|
2719
|
+
successes_this_round += 1
|
|
2720
|
+
except (json.JSONDecodeError, TypeError):
|
|
2721
|
+
pass
|
|
2722
|
+
|
|
2723
|
+
# Recalculate aggregation for all affected rows
|
|
2724
|
+
affected_rows = set(row_idx for row_idx, _ in failed_pairs)
|
|
2725
|
+
for row_idx in affected_rows:
|
|
2726
|
+
all_results[row_idx]["aggregated"] = aggregate_results(
|
|
2727
|
+
all_results[row_idx]["model_results"],
|
|
2728
|
+
categories,
|
|
2729
|
+
consensus_threshold,
|
|
2730
|
+
fail_strategy
|
|
2731
|
+
)
|
|
2732
|
+
|
|
2733
|
+
print(f" -> {successes_this_round}/{len(failed_pairs)} pairs succeeded on retry")
|
|
2734
|
+
|
|
2735
|
+
# Safety save after retry
|
|
2736
|
+
if safety:
|
|
2737
|
+
_save_partial_results(
|
|
2738
|
+
all_results,
|
|
2739
|
+
model_configs,
|
|
2740
|
+
categories,
|
|
2741
|
+
filename,
|
|
2742
|
+
save_directory,
|
|
2743
|
+
)
|
|
2744
|
+
|
|
2745
|
+
# Early exit if ALL retries failed (likely server down or out of credits)
|
|
2746
|
+
if successes_this_round == 0:
|
|
2747
|
+
print(f" -> All retries failed. Stopping retry loop (possible server issue).")
|
|
2748
|
+
break
|
|
2749
|
+
|
|
2750
|
+
# Print summary of filled/failed values
|
|
2751
|
+
total_filled = 0
|
|
2752
|
+
filled_by_model = {}
|
|
2753
|
+
total_schema_failures = 0
|
|
2754
|
+
failures_by_model = {}
|
|
2755
|
+
for result in all_results:
|
|
2756
|
+
if result.get("skipped"):
|
|
2757
|
+
continue
|
|
2758
|
+
agg = result["aggregated"]
|
|
2759
|
+
# Count missing keys that were filled with 0
|
|
2760
|
+
for model_name, count in agg.get("missing_keys", {}).items():
|
|
2761
|
+
total_filled += count
|
|
2762
|
+
filled_by_model[model_name] = filled_by_model.get(model_name, 0) + count
|
|
2763
|
+
# Count models that still failed after retries
|
|
2764
|
+
for model_name in agg.get("failed_models", []):
|
|
2765
|
+
total_schema_failures += 1
|
|
2766
|
+
failures_by_model[model_name] = failures_by_model.get(model_name, 0) + 1
|
|
2767
|
+
|
|
2768
|
+
if total_filled > 0 or total_schema_failures > 0:
|
|
2769
|
+
print(f"\n--- Classification Quality Summary ---")
|
|
2770
|
+
if total_filled > 0:
|
|
2771
|
+
print(f" Missing JSON keys filled with 0: {total_filled} values")
|
|
2772
|
+
for model_name, count in sorted(filled_by_model.items()):
|
|
2773
|
+
print(f" {model_name}: {count} filled")
|
|
2774
|
+
if total_schema_failures > 0:
|
|
2775
|
+
print(f" Schema failures (after retries): {total_schema_failures}")
|
|
2776
|
+
for model_name, count in sorted(failures_by_model.items()):
|
|
2777
|
+
print(f" {model_name}: {count} failed rows")
|
|
2778
|
+
print()
|
|
2779
|
+
|
|
2780
|
+
# Embedding tiebreaker: resolve true ties using centroids
|
|
2781
|
+
if embedding_tiebreaker_state is not None and len(model_configs) > 1:
|
|
2782
|
+
from ._tiebreaker import resolve_ties_with_centroids
|
|
2783
|
+
|
|
2784
|
+
resolve_ties_with_centroids(
|
|
2785
|
+
all_results,
|
|
2786
|
+
categories,
|
|
2787
|
+
embedding_tiebreaker_state["model"],
|
|
2788
|
+
embedding_tiebreaker_state["threshold"],
|
|
2789
|
+
embedding_tiebreaker_state.get("min_centroid_size", 3),
|
|
2790
|
+
)
|
|
2791
|
+
|
|
2792
|
+
# Build output DataFrames
|
|
2793
|
+
print("Building output DataFrames...")
|
|
2794
|
+
return build_output_dataframes(
|
|
2795
|
+
all_results,
|
|
2796
|
+
model_configs,
|
|
2797
|
+
categories,
|
|
2798
|
+
filename,
|
|
2799
|
+
save_directory,
|
|
2800
|
+
)
|
|
2801
|
+
|
|
2802
|
+
|
|
2803
|
+
def build_output_dataframes(
    all_results: list,
    model_configs: list,
    categories: list,
    filename: str,
    save_directory: str,
) -> pd.DataFrame:
    """
    Build the output DataFrame from ensemble classification results.

    One row is produced per entry in ``all_results``. Column layout:

    - ``input_data``, ``processing_status``, ``failed_models``
    - ``pdf_path``/``page_index`` and ``image_path`` when any result carries
      that metadata
    - ``category_{i}_{model}`` for every model and category (NA when the model
      is listed in the row's ``failed_models``; 0 when a value is missing or
      not an integer)
    - ``category_{i}_consensus`` and ``category_{i}_agreement``
    - ``category_{i}_resolved_by`` when any non-skipped row has tiebreaker data

    With a single model, the consensus/agreement/resolved_by/failed_models
    columns are dropped and ``category_{i}_{model}`` is renamed to
    ``category_{i}`` (backward compatible with ``multi_class``).

    Args:
        all_results: Per-row result dicts. Each needs ``"response"`` and
            ``"aggregated"`` (with ``"error"``, ``"failed_models"``,
            ``"per_model"``, ``"consensus"``, ``"agreement"``); optional keys
            are ``"skipped"``, ``"pdf_path"``, ``"page_index"``, ``"image_path"``.
        model_configs: Model config dicts; only ``"sanitized_name"`` is read here.
        categories: Category labels; only their count is used here.
        filename: If truthy, the returned DataFrame is also written to this CSV.
        save_directory: Optional directory for ``filename`` (created if missing).

    Returns:
        The combined DataFrame (multi-model) or the simplified DataFrame
        (single model). Category and consensus columns use the nullable
        ``Int64`` dtype.
    """
    model_names = [cfg["sanitized_name"] for cfg in model_configs]
    category_numbers = range(1, len(categories) + 1)

    def _as_int(value, default):
        """Return ``int(value)``, or ``default`` when it cannot be converted."""
        try:
            return int(value)
        except (ValueError, TypeError):
            return default

    # Optional metadata columns appear only if at least one row carries them.
    has_pdf_metadata = any(result.get("pdf_path") is not None for result in all_results)
    has_image_metadata = any(result.get("image_path") is not None for result in all_results)
    # Tiebreaker columns appear only if a non-skipped row was touched by the tiebreaker.
    has_tiebreaker = any(
        "tiebreaker_resolved" in result.get("aggregated", {})
        for result in all_results
        if not result.get("skipped")
    )

    # Insertion order of this dict is the column order of the output.
    combined_data = {
        "input_data": [],
        "processing_status": [],
        "failed_models": [],
    }
    if has_pdf_metadata:
        combined_data["pdf_path"] = []
        combined_data["page_index"] = []
    if has_image_metadata:
        combined_data["image_path"] = []
    for model_name in model_names:
        for i in category_numbers:
            combined_data[f"category_{i}_{model_name}"] = []
    for i in category_numbers:
        combined_data[f"category_{i}_consensus"] = []
        combined_data[f"category_{i}_agreement"] = []
    if has_tiebreaker:
        for i in category_numbers:
            combined_data[f"category_{i}_resolved_by"] = []

    for result in all_results:
        aggregated = result["aggregated"]
        combined_data["input_data"].append(result["response"])

        if has_pdf_metadata:
            combined_data["pdf_path"].append(result.get("pdf_path", ""))
            combined_data["page_index"].append(result.get("page_index", ""))
        if has_image_metadata:
            combined_data["image_path"].append(result.get("image_path", ""))

        if result.get("skipped"):
            status = "skipped"
        elif aggregated["error"]:
            status = "error"
        elif aggregated["failed_models"]:
            status = "partial"
        else:
            status = "success"
        combined_data["processing_status"].append(status)

        failed_models = aggregated["failed_models"]
        combined_data["failed_models"].append(",".join(failed_models) if failed_models else "")

        # Per-model results
        failed_set = set(aggregated.get("failed_models", []))
        for model_name in model_names:
            if model_name in failed_set:
                # Model failed validation entirely: record NA rather than 0.
                for i in category_numbers:
                    combined_data[f"category_{i}_{model_name}"].append(None)
                continue
            parsed = aggregated["per_model"].get(model_name, {})
            for i in category_numbers:
                combined_data[f"category_{i}_{model_name}"].append(
                    _as_int(parsed.get(str(i), "0"), 0)
                )

        # Consensus results (missing or non-integer consensus becomes NA)
        for i in category_numbers:
            key = str(i)
            combined_data[f"category_{i}_consensus"].append(
                _as_int(aggregated["consensus"].get(key), None)
            )
            combined_data[f"category_{i}_agreement"].append(aggregated["agreement"].get(key))

        # Resolved-by metadata (tiebreaker)
        if has_tiebreaker:
            tiebreaker_data = aggregated.get("tiebreaker_resolved", {})
            for i in category_numbers:
                combined_data[f"category_{i}_resolved_by"].append(tiebreaker_data.get(str(i), ""))

    combined_df = pd.DataFrame(combined_data)

    # Convert category columns (per-model and consensus) to nullable Int64.
    int_cols = [
        c for c in combined_df.columns
        if c.startswith("category_") and not c.endswith(("_agreement", "_resolved_by"))
    ]
    for col in int_cols:
        combined_df[col] = pd.to_numeric(combined_df[col], errors="coerce").astype("Int64")

    if len(model_names) == 1:
        # Single model: consensus/agreement/resolved_by/failed_models are redundant.
        # Exact column names are used so a model name that happens to be a
        # substring of other column names (e.g. "1") cannot corrupt them.
        model_name = model_names[0]
        drop_cols = ["failed_models"]
        for i in category_numbers:
            drop_cols += [f"category_{i}_consensus", f"category_{i}_agreement"]
            if has_tiebreaker:
                drop_cols.append(f"category_{i}_resolved_by")
        rename_map = {f"category_{i}_{model_name}": f"category_{i}" for i in category_numbers}
        output_df = combined_df.drop(columns=drop_cols).rename(columns=rename_map)
    else:
        # Multi-model: keep every model's columns plus consensus.
        output_df = combined_df

    if filename:
        save_path = filename
        if save_directory:
            os.makedirs(save_directory, exist_ok=True)
            save_path = os.path.join(save_directory, filename)
        output_df.to_csv(save_path, index=False)
        print(f"\nCombined results saved to {save_path}")

    return output_df
|
|
3008
|
+
|
|
3009
|
+
|
|
3010
|
+
# Backward compatibility alias
|
|
3011
|
+
multi_class_ensemble = classify_ensemble  # older public name; kept so existing callers keep working
|
|
3012
|
+
|
|
3013
|
+
|
|
3014
|
+
# =============================================================================
|
|
3015
|
+
# Summarization helpers
|
|
3016
|
+
# =============================================================================
|
|
3017
|
+
|
|
3018
|
+
def _save_partial_summarize_results(all_results, model_configs, model_names, is_pdf_mode, filename, save_directory):
    """
    Save partial summarization results to CSV for safety/incremental saves.

    Each entry of ``all_results`` becomes one row. PDF pages (3-tuples
    ``(pdf_path, page_index, page_label)``) are expanded into ``input_data``,
    ``pdf_path`` and ``page_index`` columns. Summaries go to
    ``summary_<model>`` columns when more than one model is configured, or to a
    single ``summary`` column otherwise. Invalid JSON results are written as "".

    ``processing_status`` is "success" when the entry has no recorded errors.
    Otherwise it is "error" when every model's summary is empty, and "partial"
    when at least one model produced text.

    Args:
        all_results: Result entries with ``input_data``, ``model_results`` and ``errors``.
        model_configs: Configured models; only their count is used here.
        model_names: Sanitized model names; they fix the column order, so the
            order does not depend on which parallel call finished first.
        is_pdf_mode: Whether inputs are PDF page tuples.
        filename: CSV filename to write.
        save_directory: Optional directory for ``filename`` (created if missing).
    """
    multi_model = len(model_configs) > 1
    rows = []
    for entry in all_results:
        item = entry["input_data"]
        if is_pdf_mode and isinstance(item, tuple) and len(item) == 3:
            pdf_path, page_index, page_label = item
            row = {
                "input_data": page_label,
                "pdf_path": pdf_path,
                "page_index": page_index,
            }
        else:
            row = {"input_data": item}

        model_results = entry["model_results"]
        ordered_names = [name for name in model_names if name in model_results]
        # Defensive: keep any result whose model is not in model_names.
        ordered_names += [name for name in model_results if name not in ordered_names]

        raw_texts = []
        for model_name in ordered_names:
            # Parse each result once and reuse it for both the column and the status.
            is_valid, summary_text = extract_summary_from_json(model_results[model_name])
            raw_texts.append(summary_text)
            column = f"summary_{model_name}" if multi_model else "summary"
            row[column] = summary_text if is_valid else ""

        if entry["errors"]:
            row["processing_status"] = "error" if all(not text for text in raw_texts) else "partial"
        else:
            row["processing_status"] = "success"

        rows.append(row)

    df = pd.DataFrame(rows)
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
        save_path = os.path.join(save_directory, filename)
    else:
        save_path = filename
    df.to_csv(save_path, index=False)
|
|
3054
|
+
|
|
3055
|
+
|
|
3056
|
+
# =============================================================================
|
|
3057
|
+
# Summarization Ensemble Function
|
|
3058
|
+
# =============================================================================
|
|
3059
|
+
|
|
3060
|
+
def summarize_ensemble(
    input_data,
    api_key: str = None,
    input_description: str = "",
    summary_instructions: str = "",
    max_length: int = None,
    focus: str = None,
    user_model: str = "gpt-4o",
    model_source: str = "auto",
    pdf_mode: str = "image",
    pdf_dpi: int = 150,
    creativity: float = None,
    chain_of_thought: bool = False,
    context_prompt: bool = False,
    step_back_prompt: bool = False,
    max_retries: int = 5,
    batch_retries: int = 2,
    retry_delay: float = 1.0,
    row_delay: float = 0.0,
    fail_strategy: str = "partial",
    safety: bool = False,
    filename: str = None,
    save_directory: str = None,
    thinking_budget: int = 0,
    progress_callback: Optional[Callable] = None,
    # Multi-model parameters
    models: list = None,
    max_workers: int = None,
    parallel: bool = None,
    auto_download: bool = False,
) -> pd.DataFrame:
    """
    Summarize text or PDF inputs using LLMs with optional multi-model ensemble.

    Supports single-model and multi-model modes. In multi-model mode,
    summaries from all models are collected and synthesized into a consensus
    summary using an LLM. The synthesis uses the first configured model that
    produced a summary for that row. Input type is auto-detected from the data.

    Args:
        input_data: Data to summarize. Can be:
            - Text: list of strings, pandas Series, or single string
            - PDF: directory path, single PDF path, or list of PDF paths
        api_key: API key for single-model mode
        input_description: Description of what the content contains (provides context)
        summary_instructions: Specific summarization instructions (e.g., "bullet points")
        max_length: Maximum summary length in words
        focus: What to focus on (e.g., "main arguments", "emotional content")
        user_model: Model to use (default "gpt-4o")
        model_source: Provider - "auto", "openai", "anthropic", "google", etc.
        pdf_mode: PDF processing mode (only used for PDF input):
            - "image" (default): Render pages as images
            - "text": Extract text only
            - "both": Send both image and extracted text
        pdf_dpi: DPI for PDF page rendering (default 150)
        creativity: Temperature setting (None uses provider default)
        chain_of_thought: Enable step-by-step reasoning (default False)
        context_prompt: Add expert context prefix
        step_back_prompt: Enable step-back prompting
        max_retries: Max retries per API call
        batch_retries: Number of batch retry passes for failed items
        retry_delay: Delay before each batch retry pass, in seconds
        row_delay: Delay in seconds between processing each row (default 0.0)
        fail_strategy: How to handle failures. "partial" (default) keeps whatever
            succeeded. "strict" blanks every summary of a row if any model failed on it.
        thinking_budget: Token budget for extended thinking/reasoning (default 0)
        safety: Save progress after each item (requires filename)
        filename: Output CSV filename
        save_directory: Directory to save results (created if missing)
        progress_callback: Optional callback, called as ``callback(done, total)``
        models: For multi-model mode, list of (model, provider, api_key) tuples
        max_workers: Maximum parallel workers (default: min(len(models), 8))
        parallel: Controls concurrent vs sequential execution (None=auto-detect)
        auto_download: Auto-download missing Ollama models (default False)

    Returns:
        DataFrame with columns:
            - input_data: Original text or page label (for PDFs)
            - summary: Generated summary (or consensus summary for multi-model)
            - summary_<model>: Per-model summaries (multi-model only)
            - processing_status: "success", "partial", "error", or "skipped"
            - failed_models: Comma-separated list (multi-model only)
            - pdf_path: Path to source PDF (PDF mode only)
            - page_index: Page number, 0-indexed (PDF mode only)

    Examples:
        >>> import cat_stack as cat
        >>>
        >>> # Single model text summarization
        >>> results = cat.summarize(
        ...     input_data=df['responses'],
        ...     description="Customer feedback",
        ...     api_key=api_key
        ... )
        >>>
        >>> # PDF summarization (auto-detected)
        >>> results = cat.summarize(
        ...     input_data="/path/to/pdfs/",
        ...     description="Research papers",
        ...     mode="image",
        ...     api_key=api_key
        ... )
        >>>
        >>> # Multi-model with synthesis
        >>> results = cat.summarize(
        ...     input_data=df['responses'],
        ...     models=[
        ...         ("gpt-4o", "openai", "sk-..."),
        ...         ("claude-sonnet-4-5-20250929", "anthropic", "sk-ant-..."),
        ...     ],
        ... )
    """
    # Safety validation
    if safety and filename is None:
        raise TypeError("filename is required when using safety=True.")

    empty_summary_json = '{"summary": ""}'
    # Providers that receive thinking_budget; every other provider gets None.
    thinking_providers = ("google", "openai", "anthropic", "huggingface", "huggingface-together")

    def _is_blank_text(value) -> bool:
        """Return True for None, NaN/NA scalars, and empty or whitespace-only strings."""
        if value is None:
            return True
        if isinstance(value, str):
            return not value.strip()
        try:
            return bool(pd.isna(value))
        except (TypeError, ValueError):
            # pd.isna on a non-scalar yields an array whose truth value is ambiguous.
            return False

    # Detect input type: Text vs PDF
    input_type = _detect_input_type(input_data)
    is_pdf_mode = (input_type == 'pdf')

    if is_pdf_mode:
        # Validate pdf_mode parameter
        pdf_mode = pdf_mode.lower()
        if pdf_mode not in {"image", "text", "both"}:
            raise ValueError(f"pdf_mode must be 'image', 'text', or 'both', got: {pdf_mode}")

        print(f"\nInput type detected: PDF")
        print(f"PDF processing mode: {pdf_mode}")

        # Expand every PDF into its pages.
        pdf_files = _load_pdf_files(input_data)
        all_pages = []
        for pdf_path in pdf_files:
            all_pages.extend(_get_pdf_pages(pdf_path))

        if not all_pages:
            raise ValueError("No pages found in the provided PDF files.")

        items_to_process = all_pages
        print(f"Total PDF pages to summarize: {len(items_to_process)}")
    else:
        # TEXT MODE: Normalize input to list
        print(f"\nInput type detected: TEXT")
        if isinstance(input_data, str):
            input_data = [input_data]
        elif hasattr(input_data, 'tolist'):
            input_data = input_data.tolist()
        else:
            input_data = list(input_data)

        items_to_process = input_data
        print(f"Total texts to summarize: {len(items_to_process)}")

    # Normalize model input to list of tuples
    models = normalize_model_input(user_model, api_key, model_source, models)

    # Validate and prepare model configs
    print(f"Validating {len(models)} model configuration(s)...")
    model_configs = prepare_model_configs(models, auto_download=auto_download)

    if not model_configs:
        raise ValueError("No valid model configurations found.")

    model_names = [cfg["sanitized_name"] for cfg in model_configs]
    multi_model = len(model_configs) > 1
    print(f"\nModels to use:")
    for cfg in model_configs:
        print(f" - {cfg['model']} ({cfg['provider']}) -> column suffix: {cfg['sanitized_name']}")

    # Build JSON schemas per model (Google does not get additional properties)
    json_schemas = {
        cfg["sanitized_name"]: build_summary_json_schema(cfg["provider"] != "google")
        for cfg in model_configs
    }

    # Gather step-back insights if enabled
    stepback_insights = {}
    if step_back_prompt:
        print("\nGathering step-back insights...")
        stepback_insights = gather_stepback_insights(
            model_configs=model_configs,
            context=input_description or "text summarization",
            question=f"What are the key factors to consider when summarizing text{f' with a focus on {focus}' if focus else ''}?"
        )

    all_results = []  # One dict per input item, in input order (index == idx)
    failed_pairs = []  # (idx, model_name) pairs that failed

    def summarize_single_item(item, idx, cfg):
        """
        Summarize one text item or PDF page with one model.

        Returns ``(model_name, json_str, error)``. ``error`` is None on success,
        "skipped" for blank text input, or an error message string.
        """
        model_name = cfg["sanitized_name"]
        provider = cfg["provider"]
        is_pdf_page = is_pdf_mode and isinstance(item, tuple) and len(item) == 3

        # Blank text never reaches the API.
        if not is_pdf_page and _is_blank_text(item):
            return (model_name, empty_summary_json, "skipped")

        try:
            if is_pdf_page:
                # PDF mode: item is (pdf_path, page_index, page_label)
                pdf_path, page_index, page_label = item
                page_data = _prepare_page_data(
                    pdf_path=pdf_path,
                    page_index=page_index,
                    page_label=page_label,
                    pdf_mode=pdf_mode,
                    provider=provider,
                    pdf_dpi=pdf_dpi,
                )
                if page_data.get("error"):
                    return (model_name, empty_summary_json, page_data["error"])

                messages = build_pdf_summarization_prompt(
                    page_data=page_data,
                    input_description=input_description,
                    summary_instructions=summary_instructions,
                    max_length=max_length,
                    focus=focus,
                    provider=provider,
                    pdf_mode=pdf_mode,
                    chain_of_thought=chain_of_thought,
                    context_prompt=context_prompt,
                    step_back_prompt=step_back_prompt,
                    stepback_insights=stepback_insights,
                    model_name=model_name,
                )
            else:
                messages = build_text_summarization_prompt(
                    response_text=str(item),
                    input_description=input_description,
                    summary_instructions=summary_instructions,
                    max_length=max_length,
                    focus=focus,
                    chain_of_thought=chain_of_thought,
                    context_prompt=context_prompt,
                    step_back_prompt=step_back_prompt,
                    stepback_insights=stepback_insights,
                    model_name=model_name,
                )

            client = UnifiedLLMClient(
                provider=provider,
                api_key=cfg["api_key"],
                model=cfg["model"],
            )
            json_schema = json_schemas[model_name]
            effective_thinking = thinking_budget if provider in thinking_providers else None

            # Google multimodal page input goes through a dedicated call path.
            if is_pdf_page and provider == "google" and pdf_mode != "text":
                response = _call_google_multimodal(
                    client=client,
                    messages=messages,
                    json_schema=json_schema,
                    creativity=creativity,
                    thinking_budget=effective_thinking or 0,
                    max_retries=max_retries,
                )
            else:
                response, _err = client.complete(
                    messages=messages,
                    json_schema=json_schema,
                    creativity=creativity,
                    thinking_budget=effective_thinking,
                    max_retries=max_retries,
                )

            return (model_name, extract_json(response), None)

        except Exception as e:
            return (model_name, empty_summary_json, str(e))

    # Process all items
    progress_desc = "Summarizing PDF pages" if is_pdf_mode else "Summarizing texts"
    print(f"\n{progress_desc}...")

    # Auto-resolve parallel mode: sequential for all-local (Ollama), parallel otherwise
    if parallel is None:
        parallel = not all(cfg["provider"] == "ollama" for cfg in model_configs)

    effective_workers = max_workers or min(len(model_configs), 8)
    run_concurrently = parallel and multi_model
    total_items = len(items_to_process)

    for idx, item in enumerate(tqdm(items_to_process, desc=progress_desc)):
        if run_concurrently:
            with ThreadPoolExecutor(max_workers=effective_workers) as executor:
                futures = [
                    executor.submit(summarize_single_item, item, idx, cfg)
                    for cfg in model_configs
                ]
                outcomes = [future.result() for future in as_completed(futures)]
        else:
            outcomes = [summarize_single_item(item, idx, cfg) for cfg in model_configs]

        item_results = {}
        item_errors = {}
        for model_name, json_result, error in outcomes:
            item_results[model_name] = json_result
            # "skipped" marks blank input; it is not a failure worth retrying.
            if error and error != "skipped":
                item_errors[model_name] = error
                failed_pairs.append((idx, model_name))

        all_results.append({
            "idx": idx,
            "input_data": item,
            "model_results": item_results,
            "errors": item_errors,
        })

        # Safety: incremental save after each item
        if safety:
            _save_partial_summarize_results(all_results, model_configs, model_names, is_pdf_mode, filename, save_directory)

        if progress_callback:
            progress_callback(idx + 1, total_items)

        if row_delay > 0 and idx < total_items - 1:
            time.sleep(row_delay)

    # Batch retries for failed pairs
    for retry_pass in range(batch_retries):
        if not failed_pairs:
            break

        print(f"\n[Batch retry {retry_pass + 1}/{batch_retries}] Retrying {len(failed_pairs)} failed (row, model) pairs...")
        time.sleep(retry_delay)

        retry_success = 0
        still_failed = []

        for idx, model_name in failed_pairs:
            cfg = next((c for c in model_configs if c["sanitized_name"] == model_name), None)
            if not cfg:
                continue

            _, json_result, error = summarize_single_item(items_to_process[idx], idx, cfg)
            entry = all_results[idx]

            if error and error != "skipped":
                # Keep the most recent error message for reporting.
                entry["errors"][model_name] = error
                still_failed.append((idx, model_name))
                continue

            entry["model_results"][model_name] = json_result
            entry["errors"].pop(model_name, None)
            retry_success += 1

        failed_pairs = still_failed
        print(f" -> {retry_success}/{len(failed_pairs) + retry_success} pairs succeeded on retry")

        # Safety: save after each batch retry pass
        if safety:
            _save_partial_summarize_results(all_results, model_configs, model_names, is_pdf_mode, filename, save_directory)

    # Build output DataFrame
    print("\nBuilding output DataFrame...")

    rows = []
    for entry in all_results:
        item = entry["input_data"]
        if is_pdf_mode and isinstance(item, tuple) and len(item) == 3:
            pdf_path, page_index, page_label = item
            row = {
                "input_data": page_label,
                "pdf_path": pdf_path,
                "page_index": page_index,
            }
            original_text_for_synthesis = page_label  # page label is the synthesis context
        else:
            row = {"input_data": item}
            original_text_for_synthesis = item

        # Extract summaries in configured model order, so synthesis input does
        # not depend on which parallel call finished first.
        model_results = entry["model_results"]
        summaries = {}
        for model_name in model_names:
            if model_name in model_results:
                is_valid, summary_text = extract_summary_from_json(model_results[model_name])
                summaries[model_name] = summary_text if is_valid else ""

        # fail_strategy="strict": blank out results if any model failed
        if fail_strategy == "strict" and entry["errors"]:
            summaries = {k: "" for k in summaries}

        if multi_model:
            for model_name in model_names:
                row[f"summary_{model_name}"] = summaries.get(model_name, "")

            valid_summaries = {k: v for k, v in summaries.items() if v}
            if valid_summaries:
                # Synthesize with the first configured model that succeeded on this row.
                synthesis_cfg = next(
                    (cfg for cfg in model_configs if cfg["sanitized_name"] in valid_summaries),
                    model_configs[0],
                )
                row["summary"] = _synthesize_summaries(
                    summaries=valid_summaries,
                    original_text=str(original_text_for_synthesis),
                    synthesis_config=synthesis_cfg,
                    max_retries=max_retries,
                )
            else:
                row["summary"] = ""

            row["failed_models"] = ",".join(entry["errors"].keys()) if entry["errors"] else ""
        else:
            row["summary"] = summaries.get(model_names[0], "")

        # Processing status. Blank/NaN text is "skipped", matching the
        # skip check used before calling the API. PDF pages are never skipped.
        if all(not s for s in summaries.values()):
            if not is_pdf_mode and _is_blank_text(item):
                row["processing_status"] = "skipped"
            else:
                row["processing_status"] = "error"
        elif any(not s for s in summaries.values()):
            row["processing_status"] = "partial"
        else:
            row["processing_status"] = "success"

        rows.append(row)

    df = pd.DataFrame(rows)

    # Save to file if requested (create the directory like the partial save does)
    if filename:
        save_path = filename
        if save_directory:
            os.makedirs(save_directory, exist_ok=True)
            save_path = os.path.join(save_directory, filename)
        df.to_csv(save_path, index=False)
        print(f"\nResults saved to: {save_path}")

    return df
|
|
3568
|
+
|
|
3569
|
+
|
|
3570
|
+
def _synthesize_summaries(
|
|
3571
|
+
summaries: dict,
|
|
3572
|
+
original_text: str,
|
|
3573
|
+
synthesis_config: dict,
|
|
3574
|
+
max_retries: int = 3,
|
|
3575
|
+
) -> str:
|
|
3576
|
+
"""
|
|
3577
|
+
Synthesize multiple model summaries into one consensus summary.
|
|
3578
|
+
|
|
3579
|
+
Args:
|
|
3580
|
+
summaries: Dict of {model_name: summary_text}
|
|
3581
|
+
original_text: The original text that was summarized
|
|
3582
|
+
synthesis_config: Model config to use for synthesis
|
|
3583
|
+
max_retries: Max retries for synthesis call
|
|
3584
|
+
|
|
3585
|
+
Returns:
|
|
3586
|
+
Synthesized consensus summary string
|
|
3587
|
+
"""
|
|
3588
|
+
if len(summaries) == 1:
|
|
3589
|
+
return list(summaries.values())[0]
|
|
3590
|
+
|
|
3591
|
+
# Build synthesis prompt
|
|
3592
|
+
summaries_text = "\n".join([
|
|
3593
|
+
f"- {model}: \"{summary}\""
|
|
3594
|
+
for model, summary in summaries.items()
|
|
3595
|
+
])
|
|
3596
|
+
|
|
3597
|
+
# Truncate original text if too long
|
|
3598
|
+
max_original_len = 500
|
|
3599
|
+
original_display = original_text[:max_original_len]
|
|
3600
|
+
if len(original_text) > max_original_len:
|
|
3601
|
+
original_display += "..."
|
|
3602
|
+
|
|
3603
|
+
synthesis_prompt = f"""You are synthesizing multiple AI-generated summaries of the same text into one optimal summary.
|
|
3604
|
+
|
|
3605
|
+
Original text: "{original_display}"
|
|
3606
|
+
|
|
3607
|
+
Summaries from different models:
|
|
3608
|
+
{summaries_text}
|
|
3609
|
+
|
|
3610
|
+
Create a single, comprehensive summary that captures the best insights from all summaries.
|
|
3611
|
+
Resolve any contradictions by focusing on accuracy.
|
|
3612
|
+
|
|
3613
|
+
Provide your answer in JSON format: {{"summary": "your synthesized summary"}}"""
|
|
3614
|
+
|
|
3615
|
+
try:
|
|
3616
|
+
client = UnifiedLLMClient(
|
|
3617
|
+
provider=synthesis_config["provider"],
|
|
3618
|
+
api_key=synthesis_config["api_key"],
|
|
3619
|
+
model=synthesis_config["model"],
|
|
3620
|
+
)
|
|
3621
|
+
|
|
3622
|
+
json_schema = build_summary_json_schema(
|
|
3623
|
+
include_additional_properties=synthesis_config["provider"] != "google"
|
|
3624
|
+
)
|
|
3625
|
+
|
|
3626
|
+
response, _err = client.complete(
|
|
3627
|
+
messages=[{"role": "user", "content": synthesis_prompt}],
|
|
3628
|
+
json_schema=json_schema,
|
|
3629
|
+
creativity=0.3, # Low creativity for synthesis
|
|
3630
|
+
max_retries=max_retries,
|
|
3631
|
+
)
|
|
3632
|
+
|
|
3633
|
+
json_str = extract_json(response)
|
|
3634
|
+
is_valid, summary = extract_summary_from_json(json_str)
|
|
3635
|
+
|
|
3636
|
+
if is_valid:
|
|
3637
|
+
return summary
|
|
3638
|
+
else:
|
|
3639
|
+
# Fallback: return the longest summary
|
|
3640
|
+
return max(summaries.values(), key=len)
|
|
3641
|
+
|
|
3642
|
+
except Exception as e:
|
|
3643
|
+
print(f"Warning: Synthesis failed ({e}), using longest summary as fallback")
|
|
3644
|
+
return max(summaries.values(), key=len)
|