cat-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cat_stack/__about__.py +10 -0
- cat_stack/__init__.py +128 -0
- cat_stack/_batch.py +1388 -0
- cat_stack/_category_analysis.py +348 -0
- cat_stack/_chunked.py +424 -0
- cat_stack/_embeddings.py +189 -0
- cat_stack/_formatter.py +169 -0
- cat_stack/_providers.py +1048 -0
- cat_stack/_tiebreaker.py +277 -0
- cat_stack/_utils.py +512 -0
- cat_stack/_web_fetch.py +194 -0
- cat_stack/calls/CoVe.py +287 -0
- cat_stack/calls/__init__.py +25 -0
- cat_stack/calls/all_calls.py +622 -0
- cat_stack/calls/image_CoVe.py +386 -0
- cat_stack/calls/image_stepback.py +210 -0
- cat_stack/calls/pdf_CoVe.py +386 -0
- cat_stack/calls/pdf_stepback.py +210 -0
- cat_stack/calls/stepback.py +180 -0
- cat_stack/calls/top_n.py +217 -0
- cat_stack/classify.py +682 -0
- cat_stack/explore.py +111 -0
- cat_stack/extract.py +218 -0
- cat_stack/image_functions.py +2078 -0
- cat_stack/images/circle.png +0 -0
- cat_stack/images/cube.png +0 -0
- cat_stack/images/diamond.png +0 -0
- cat_stack/images/overlapping_pentagons.png +0 -0
- cat_stack/images/rectangles.png +0 -0
- cat_stack/model_reference_list.py +94 -0
- cat_stack/pdf_functions.py +2087 -0
- cat_stack/summarize.py +290 -0
- cat_stack/text_functions.py +1358 -0
- cat_stack/text_functions_ensemble.py +3644 -0
- cat_stack-0.1.0.dist-info/METADATA +150 -0
- cat_stack-0.1.0.dist-info/RECORD +38 -0
- cat_stack-0.1.0.dist-info/WHEEL +4 -0
- cat_stack-0.1.0.dist-info/licenses/LICENSE +672 -0
cat_stack/classify.py
ADDED
|
@@ -0,0 +1,682 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Classification functions for CatLLM.
|
|
3
|
+
|
|
4
|
+
This module provides unified classification for text, image, and PDF inputs,
|
|
5
|
+
supporting both single-model and multi-model (ensemble) classification.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import math
|
|
9
|
+
import warnings
|
|
10
|
+
from typing import Union, Callable
|
|
11
|
+
|
|
12
|
+
# Public API of this module: the names exported by
# ``from cat_stack.classify import *``. The deprecated entries are
# re-exports of the legacy text/image/PDF functions imported below.
__all__ = [
    # Main entry point
    "classify",
    # Ensemble function
    "classify_ensemble",
    # Deprecated functions (kept for backward compatibility)
    "multi_class",
    "image_multi_class",
    "pdf_multi_class",
]
|
|
22
|
+
|
|
23
|
+
# Import provider infrastructure
|
|
24
|
+
from ._providers import (
|
|
25
|
+
UnifiedLLMClient,
|
|
26
|
+
detect_provider,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
# Category analysis
|
|
30
|
+
from ._category_analysis import has_other_category, check_category_verbosity
|
|
31
|
+
|
|
32
|
+
# Import the implementation functions from existing modules
|
|
33
|
+
from .text_functions_ensemble import (
|
|
34
|
+
classify_ensemble,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
# Import deprecated functions for backward compatibility
|
|
38
|
+
from .text_functions import multi_class
|
|
39
|
+
from .image_functions import image_multi_class
|
|
40
|
+
from .pdf_functions import pdf_multi_class
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def classify(
    input_data,
    categories,
    api_key=None,
    input_type="text",
    description="",
    user_model="gpt-4o",
    mode="image",
    creativity=None,
    safety=False,
    chain_of_verification=False,
    chain_of_thought=False,
    step_back_prompt=False,
    context_prompt=False,
    thinking_budget=0,
    example1=None,
    example2=None,
    example3=None,
    example4=None,
    example5=None,
    example6=None,
    filename=None,
    save_directory=None,
    model_source="auto",
    max_categories=12,
    categories_per_chunk=10,
    divisions=10,
    research_question=None,
    progress_callback=None,
    # Batch mode parameters
    batch_mode: bool = False,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
    # Multi-model parameters
    models=None,
    consensus_threshold: Union[str, float] = "unanimous",
    # Parameters previously only on classify_ensemble
    survey_question: str = "",
    use_json_schema: bool = True,
    max_workers: int = None,
    parallel: bool = None,
    fail_strategy: str = "partial",
    max_retries: int = 5,
    batch_retries: int = 2,
    retry_delay: float = 1.0,
    row_delay: float = 0.0,
    pdf_dpi: int = 150,
    auto_download: bool = False,
    add_other="prompt",
    check_verbosity: bool = True,
    json_formatter: bool = False,
    embeddings: bool = False,
    category_descriptions: dict = None,
    embedding_tiebreaker: bool = False,
    min_centroid_size: int = 3,
    multi_label: bool = True,
    categories_per_call: int = None,
):
    """
    Unified classification function for text, image, and PDF inputs.

    Supports single-model and multi-model (ensemble) classification. Input type
    is auto-detected from the data (text strings, image paths, or PDF paths).

    Args:
        input_data: The data to classify. Can be:
            - For text: list of text responses or pandas Series
            - For image: directory path or list of image file paths
            - For pdf: directory path or list of PDF file paths
        categories (list or str): List of category names for classification,
            or the string "auto" (see survey_question). "auto" is not
            supported with batch_mode=True.
        api_key (str): API key for the model provider (single-model mode).
        input_type (str): DEPRECATED and ignored - input type is auto-detected.
            Kept for backward compatibility; passing any value other than
            "text" emits a DeprecationWarning.
        description (str): Description of the input data context.
        user_model (str): Model name to use. Default "gpt-4o".
        mode (str): PDF processing mode:
            - "image" (default): Render pages as images
            - "text": Extract text only
            - "both": Send both image and extracted text
            Any other value falls back to "image".
        creativity (float): Temperature setting. None uses model default.
        safety (bool): If True, saves progress after each item.
        chain_of_verification (bool): Enable Chain of Verification for accuracy.
        chain_of_thought (bool): Enable step-by-step reasoning. Default False.
        step_back_prompt (bool): Enable step-back prompting.
        context_prompt (bool): Add expert context to prompts.
        thinking_budget (int): Controls reasoning behavior per provider:
            Google: token budget for extended thinking (0=off, >0=on).
            OpenAI: maps to reasoning_effort (0="minimal", >0="high").
            Anthropic: enables extended thinking (0=off, >0=on, min 1024).
        example1-6 (str): Example categorizations for few-shot learning.
        filename (str): Output filename for CSV.
        save_directory (str): Directory to save results.
        model_source (str): Provider - "auto", "openai", "anthropic", "google",
            "mistral", "perplexity", "huggingface", "xai".
        max_categories, categories_per_chunk, divisions, research_question:
            Forwarded unchanged to classify_ensemble (presumably used for
            category discovery when categories="auto"; see classify_ensemble).
            Not used in batch mode.
        progress_callback: Optional callback for progress updates.
        batch_mode (bool): If True, use async batch API (50% cost savings, higher rate limits).
            Supported providers: openai, anthropic, google, mistral, xai.
            Not supported: huggingface, perplexity, ollama.
            Ensemble mode: supported. Each model submits its own batch job concurrently.
            Providers without batch API (HuggingFace, Perplexity, Ollama) fall back to
            synchronous calls and are merged in with the batch results.
            PDF/image input and categories="auto" raise ValueError.
            progress_callback, json_formatter and embedding_tiebreaker are
            ignored (a notice is printed).
        batch_poll_interval (float): Seconds between batch job status checks. Default 30.
        batch_timeout (float): Max seconds to wait for batch completion. Default 86400 (24h).
        models (list): For multi-model mode, list of (model, provider, api_key) tuples.
            If provided, overrides user_model/api_key/model_source. Must not
            be empty.
        consensus_threshold (str or float): For multi-model mode, agreement threshold.
            - "unanimous": 100% agreement (default — empirically produces
              the highest accuracy by aggressively eliminating false positives)
            - "majority": 50% agreement
            - "two-thirds": 67% agreement
            - float: Custom threshold between 0 and 1
        survey_question (str): The survey question (used when categories="auto").
        use_json_schema (bool): Use JSON schema for structured output. Default True.
        max_workers (int): Max parallel workers for API calls. None = auto.
        parallel (bool): Controls concurrent vs sequential model execution.
            - None (default): auto-detect. Sequential for local models (Ollama),
              parallel for cloud providers.
            - True: force parallel execution.
            - False: force sequential execution.
            Sequential mode is useful for resource-constrained environments
            (e.g., Ollama on limited hardware) or debugging.
        fail_strategy (str): How to handle failures - "partial" (default) or "strict".
        max_retries (int): Max retries per API call. Default 5.
        batch_retries (int): Max retries for batch-level failures. Default 2.
        retry_delay (float): Delay between retries in seconds. Default 1.0.
        row_delay (float): Delay in seconds between processing each row. Useful
            when multiple models share the same API provider/key to avoid rate
            limits. Default 0.0 (no delay).
        pdf_dpi (int): DPI for PDF page rendering. Default 150.
        auto_download (bool): Auto-download Ollama models. Default False.
        add_other (str or bool): Controls auto-addition of an "Other" catch-all
            category when none is detected. An "Other" category improves accuracy
            by preventing the model from forcing ambiguous responses into
            ill-fitting categories.
            - "prompt" (default): Ask the user to accept or reject the suggestion.
            - True: Silently add "Other" without prompting.
            - False: Never add "Other".
        check_verbosity (bool): Check whether each category has a description
            and examples (1 API call, best-effort; failures print a notice and
            never block classification). Verbose categories with descriptions
            and examples significantly improve classification accuracy over
            bare labels. Default True. Set to False to skip.
        json_formatter (bool): If True, use a local fine-tuned model to fix
            malformed JSON output from classification LLMs before marking
            responses as failed. The formatter runs only when extract_json()
            produces invalid output — zero cost on the happy path. On first
            use, the model (~1GB) is downloaded from HuggingFace Hub.
            Not loaded in batch mode.
            Requires: pip install cat-llm[formatter]. Default False.
        embeddings (bool): If True, add embedding-based similarity scores
            alongside binary 0/1 classifications. Uses a local sentence-
            transformer model (BAAI/bge-small-en-v1.5, ~130MB) to compute
            cosine similarity between each input text and each category.
            Scores are independent per (text, category) pair — no softmax.
            On first use, the model is downloaded from HuggingFace Hub.
            Only works with text input (skipped, without loading the model,
            for PDF/image).
            Requires: pip install cat-llm[embeddings]. Default False.
        category_descriptions (dict): Optional dict mapping category names
            to richer text descriptions for embedding similarity. E.g.,
            {"Past_Support": "References to help received from family"}.
            Only used when embeddings=True.
        embedding_tiebreaker (bool): If True, use embedding centroids to
            resolve true ties in ensemble consensus. Builds per-category
            centroids from unanimously-agreed rows and compares tied texts
            to those centroids. Only applies to multi-model ensemble mode
            with text input and batch_mode=False.
            Requires: pip install cat-llm[embeddings]. Default False.
        min_centroid_size (int): Minimum number of unanimously-agreed rows
            needed to build a centroid for a category. Categories with fewer
            confident rows fall back to vote-based consensus. Default 3.
        multi_label (bool): If True (default), allow multiple categories per
            input (multi-label classification). If False, the prompt instructs
            the model to pick the single best category (single-label mode).
            The output format is unchanged — still one 0/1 column per category,
            but exactly one column will be "1" per row in single-label mode.
        categories_per_call (int): Maximum number of categories to send per
            LLM call. When set, the category list is split into chunks of this
            size, each chunk gets its own LLM call with local 1..N numbering,
            and results are merged back into global numbering. This reduces
            prompt complexity per call and can improve accuracy for large
            category sets (e.g., 20+). Default None (all categories in one call).
            Must be a positive int (bools are rejected).
            Not supported with batch_mode=True.

    Returns:
        pd.DataFrame: Results with classification columns.

    Raises:
        ValueError: If ``models`` is an empty list; if ``categories_per_call``
            is not a positive integer or is combined with ``batch_mode=True``;
            if ``batch_mode=True`` is combined with ``categories="auto"``,
            PDF/image input, or (single-model) a provider without a batch API.
            Errors from the underlying classification helpers propagate.

    Warns:
        DeprecationWarning: If ``input_type`` is given a value other than "text".

    Examples:
        >>> import cat_stack as cat
        >>>
        >>> # Single model classification
        >>> results = cat.classify(
        ...     input_data=df['responses'],
        ...     categories=["Positive", "Negative", "Neutral"],
        ...     description="Customer feedback survey",
        ...     api_key="your-api-key"
        ... )
        >>>
        >>> # Multi-model ensemble
        >>> results = cat.classify(
        ...     input_data=df['responses'],
        ...     categories=["Positive", "Negative"],
        ...     models=[
        ...         ("gpt-4o", "openai", "sk-..."),
        ...         ("claude-sonnet-4-5-20250929", "anthropic", "sk-ant-..."),
        ...     ],
        ...     consensus_threshold="unanimous",  # or "majority", "two-thirds", or 0.75
        ... )
    """
    # The parameter is documented as deprecated but used to be silently
    # ignored; tell callers that their value has no effect.
    if input_type != "text":
        warnings.warn(
            "classify(input_type=...) is deprecated and has no effect; the "
            "input type is auto-detected from input_data.",
            DeprecationWarning,
            stacklevel=2,
        )

    # Build models list
    if models is None:
        # Single model mode - build models list from individual params
        models = [(user_model, model_source, api_key)]
    if not models:
        raise ValueError(
            "models must contain at least one (model, provider, api_key) entry."
        )

    auto_categories = _is_auto_categories(categories)

    # Fail fast on invalid argument combinations, before the interactive
    # "Other" prompt, the verbosity API call, or any local model download.
    _validate_categories_per_call(categories_per_call, batch_mode)
    if batch_mode and auto_categories:
        # The batch path enumerates `categories` into a numbered prompt list;
        # with the string "auto" that would number its characters.
        raise ValueError(
            'batch_mode=True does not support categories="auto"; pass an '
            "explicit list of categories or set batch_mode=False."
        )

    # Detect the input type once; batch mode and both embedding features are
    # text-only, so they all need it.
    detected_type = None
    if batch_mode or embeddings or embedding_tiebreaker:
        from .text_functions_ensemble import _detect_input_type

        detected_type = _detect_input_type(input_data)
    if batch_mode and detected_type in ("pdf", "image"):
        raise ValueError(
            f"batch_mode=True only supports text input, but detected input type is '{detected_type}'. "
            "Set batch_mode=False for PDF/image classification."
        )

    # Auto-append "Other" catch-all category if missing
    categories = _maybe_add_other_category(categories, add_other)

    # Check category verbosity (1 API call, best-effort)
    if check_verbosity and categories and not auto_categories:
        _report_category_verbosity(categories, models[0])

    # Must run after "Other" may have been appended, since that changes the
    # category count used to decide whether chunking is needed at all.
    categories_per_call = _plan_category_chunks(categories, categories_per_call)

    # Evidence-based warnings for prompting strategies
    examples = [example1, example2, example3, example4, example5, example6]
    _print_strategy_warnings(
        chain_of_verification=chain_of_verification,
        examples=examples,
        thinking_budget=thinking_budget,
        chain_of_thought=chain_of_thought,
        step_back_prompt=step_back_prompt,
    )

    # JSON formatter fallback (opt-in). The batch path never receives the
    # formatter, so do not download/load the model there.
    formatter_state = None
    if json_formatter:
        if batch_mode:
            print(
                "[CatLLM] NOTE: json_formatter is not used in batch_mode; "
                "skipping the formatter model."
            )
        else:
            formatter_state = _load_formatter_state()

    # Embedding-based probability scores (opt-in, text input only)
    embedding_state = None
    if embeddings:
        if detected_type in ("pdf", "image"):
            print(
                f"[CatLLM] Embedding scores skipped — not supported for {detected_type} input."
            )
        else:
            embedding_state = _load_embedding_state(category_descriptions)

    # Embedding tiebreaker setup (opt-in)
    tiebreaker_state = None
    if embedding_tiebreaker:
        tiebreaker_state = _build_tiebreaker_state(
            models=models,
            detected_type=detected_type,
            batch_mode=batch_mode,
            embedding_state=embedding_state,
            consensus_threshold=consensus_threshold,
            min_centroid_size=min_centroid_size,
        )

    # Batch mode — bypass classify_ensemble entirely
    if batch_mode:
        result = _run_batch_mode(
            input_data=input_data,
            categories=categories,
            models=models,
            examples=examples,
            survey_question=survey_question,
            chain_of_thought=chain_of_thought,
            context_prompt=context_prompt,
            step_back_prompt=step_back_prompt,
            creativity=creativity,
            thinking_budget=thinking_budget,
            multi_label=multi_label,
            use_json_schema=use_json_schema,
            auto_download=auto_download,
            progress_callback=progress_callback,
            consensus_threshold=consensus_threshold,
            fail_strategy=fail_strategy,
            filename=filename,
            save_directory=save_directory,
            batch_poll_interval=batch_poll_interval,
            batch_timeout=batch_timeout,
        )
        return _apply_embeddings(result, categories, embedding_state)

    # Map mode to pdf_mode; unrecognized values fall back to "image".
    pdf_mode = mode if mode in ("image", "text", "both") else "image"

    result = classify_ensemble(
        input_data=input_data,
        categories=categories,
        models=models,
        input_description=description,
        survey_question=survey_question,
        pdf_mode=pdf_mode,
        pdf_dpi=pdf_dpi,
        creativity=creativity,
        safety=safety,
        chain_of_thought=chain_of_thought,
        chain_of_verification=chain_of_verification,
        step_back_prompt=step_back_prompt,
        context_prompt=context_prompt,
        thinking_budget=thinking_budget,
        use_json_schema=use_json_schema,
        max_workers=max_workers,
        parallel=parallel,
        fail_strategy=fail_strategy,
        max_retries=max_retries,
        batch_retries=batch_retries,
        retry_delay=retry_delay,
        row_delay=row_delay,
        auto_download=auto_download,
        example1=example1,
        example2=example2,
        example3=example3,
        example4=example4,
        example5=example5,
        example6=example6,
        consensus_threshold=consensus_threshold,
        max_categories=max_categories,
        categories_per_chunk=categories_per_chunk,
        divisions=divisions,
        research_question=research_question,
        filename=filename,
        save_directory=save_directory,
        progress_callback=progress_callback,
        formatter_state=formatter_state,
        multi_label=multi_label,
        categories_per_call=categories_per_call,
        embedding_tiebreaker_state=tiebreaker_state,
    )
    return _apply_embeddings(result, categories, embedding_state)


def _is_auto_categories(categories):
    """Return True when *categories* is the literal string "auto"."""
    return isinstance(categories, str) and categories == "auto"


def _validate_categories_per_call(categories_per_call, batch_mode):
    """Raise ValueError if *categories_per_call* is invalid or used with batch mode.

    ``bool`` is rejected explicitly because ``isinstance(True, int)`` is True
    and would otherwise be accepted as a chunk size of 1.
    """
    if categories_per_call is None:
        return
    if (
        isinstance(categories_per_call, bool)
        or not isinstance(categories_per_call, int)
        or categories_per_call < 1
    ):
        raise ValueError(
            f"categories_per_call must be a positive integer, got {categories_per_call!r}"
        )
    if batch_mode:
        raise ValueError(
            "categories_per_call is not supported with batch_mode=True. "
            "Set batch_mode=False to use categories_per_call."
        )


def _plan_category_chunks(categories, categories_per_call):
    """Return the effective categories_per_call for an already-validated value.

    Returns None (no chunking) when every category fits in one call; otherwise
    prints the planned chunk count and returns the value unchanged. Empty or
    "auto" categories leave the value untouched.
    """
    if categories_per_call is None or not categories or _is_auto_categories(categories):
        return categories_per_call
    if categories_per_call >= len(categories):
        return None  # no-op, all categories fit in one call
    num_chunks = math.ceil(len(categories) / categories_per_call)
    print(
        f"[CatLLM] categories_per_call={categories_per_call}: "
        f"splitting {len(categories)} categories into {num_chunks} chunks"
    )
    return categories_per_call


def _maybe_add_other_category(categories, add_other):
    """Return *categories*, possibly extended with a catch-all "Other" entry.

    Nothing is added when ``add_other`` is falsy, when ``categories`` is empty
    or "auto", or when ``has_other_category`` already detects a catch-all.
    With ``add_other == "prompt"`` the user is asked on stdin (EOF/Ctrl-C is
    treated as "no"); any other truthy value adds "Other" silently. The
    caller's list is never mutated — a new list is returned when extended.
    """
    if not add_other or not categories or _is_auto_categories(categories):
        return categories
    if has_other_category(categories):
        return categories

    if add_other != "prompt":
        # add_other=True — silently add
        categories = list(categories) + ["Other"]
        print(
            f"[CatLLM] Auto-added 'Other' catch-all category. "
            f"Categories are now: {categories} "
            f"(set add_other=False to disable)"
        )
        return categories

    print(
        "\n[CatLLM] It looks like your categories may not include a catch-all\n"
        " 'Other' option. Adding one can improve accuracy by giving the\n"
        " model an outlet for ambiguous responses instead of forcing them\n"
        " into ill-fitting categories.\n"
        " (If you already have a catch-all under a different name, choose 'n'.)\n"
    )
    try:
        answer = input(" Add 'Other' to your categories? (Y/n): ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        answer = "n"
    if answer in ("", "y", "yes"):
        categories = list(categories) + ["Other"]
        print(f" -> Categories are now: {categories}\n")
    else:
        print(" -> Keeping original categories.\n")
    return categories


def _report_category_verbosity(categories, first_entry):
    """Print which categories lack a description and/or examples (best-effort).

    Uses the API key and provider of *first_entry* (a ``models`` entry, read
    positionally as (model, provider, api_key)); skipped when no key is given.
    Any failure is reported in one line and never blocks classification.
    """
    check_key = first_entry[2] if len(first_entry) >= 3 else None
    check_source = first_entry[1] if len(first_entry) >= 2 else "auto"
    if not check_key:
        return

    try:
        verbosity = check_category_verbosity(
            categories,
            api_key=check_key,
            model_source=check_source,
        )
        lacking = [r for r in verbosity if not r["is_verbose"]]
        if not lacking:
            return

        print(
            "\n[CatLLM] Category verbosity check (set check_verbosity=False to skip):"
        )
        for r in lacking:
            issues = []
            if not r["has_description"]:
                issues.append("description")
            if not r["has_examples"]:
                issues.append("examples")
            print(f' - "{r["category"]}" (missing: {", ".join(issues)})')

        print(
            "\n Verbose categories with descriptions and examples significantly\n"
            " improve classification accuracy over bare labels.\n"
            "\n"
            " Instead of:\n"
            ' "Positive"\n'
            " Consider:\n"
            ' "Positive: The response expresses satisfaction, approval, or\n'
            " happiness (e.g., 'I love this product', 'Great experience',\n"
            " 'Very pleased with the result')\"\n"
        )
    except Exception as exc:  # Non-critical — don't block classification
        print(
            f"[CatLLM] Category verbosity check skipped "
            f"({type(exc).__name__}: {exc})."
        )


def _print_strategy_warnings(
    *,
    chain_of_verification,
    examples,
    thinking_budget,
    chain_of_thought,
    step_back_prompt,
):
    """Print evidence-based notes for the enabled prompting strategies.

    Based on empirical findings from Soria et al. (2026) comparing prompting
    strategies across 4 representative models and 4 survey tasks.
    """
    strategy_warnings = []

    if chain_of_verification:
        strategy_warnings.append(
            "[CatLLM] WARNING: chain_of_verification=True is enabled.\n"
            " Empirical evidence shows CoVe DEGRADES accuracy by ~2 pp and\n"
            " sensitivity by up to 12 pp for structured classification tasks.\n"
            " The verification step causes models to retract correct classifications.\n"
            " Cost: ~4x API calls per response.\n"
            " This feature is provided for research purposes only — it is not\n"
            " recommended for improving classification accuracy."
        )

    n_examples = sum(1 for ex in examples if ex is not None)
    if n_examples > 0:
        strategy_warnings.append(
            f"[CatLLM] NOTE: {n_examples} few-shot example(s) provided.\n"
            " Empirical evidence shows few-shot examples DEGRADE accuracy by\n"
            " ~1.1-1.2 pp on average. Examples encourage over-classification\n"
            " (sensitivity up, but precision drops ~2-3 pp), amplifying false\n"
            " positives. This feature is provided for research purposes — for\n"
            " best results, use verbose category definitions instead."
        )

    if thinking_budget and thinking_budget > 0:
        strategy_warnings.append(
            f"[CatLLM] NOTE: thinking_budget={thinking_budget} is enabled.\n"
            " Empirical evidence shows reasoning/thinking modes produce negligible\n"
            " accuracy gains (<1 pp) for classification tasks, while significantly\n"
            " increasing latency, token usage, and failure rates (up to 40% timeouts\n"
            " observed for some models). Consider thinking_budget=0 unless you are\n"
            " specifically researching reasoning effects."
        )

    if chain_of_thought:
        strategy_warnings.append(
            "[CatLLM] NOTE: chain_of_thought=True is enabled.\n"
            " Empirical evidence shows CoT has no measurable effect on structured\n"
            " classification accuracy (~0 pp change). When categories are well-defined\n"
            " with verbose descriptions, explicit reasoning steps add no value.\n"
            " This won't hurt results, but it won't help either."
        )

    if step_back_prompt:
        strategy_warnings.append(
            "[CatLLM] NOTE: step_back_prompt=True is enabled.\n"
            " Empirical evidence shows step-back prompting produces small, inconsistent\n"
            " gains (+0.6 pp average) and actually degrades top-tier model performance.\n"
            " Cost: ~2x API calls per response."
        )

    if strategy_warnings:
        print()
        print("\n\n".join(strategy_warnings))
        print()


def _load_formatter_state():
    """Load the local JSON-formatter model, or return None if unavailable."""
    try:
        from ._formatter import ensure_formatter_available, load_formatter

        if not ensure_formatter_available():
            print("[CatLLM] Continuing without JSON formatter fallback.")
            return None
        fmt_model, fmt_tokenizer, fmt_device = load_formatter()
    except ImportError as e:
        print(f"[CatLLM] JSON formatter unavailable: {e}")
        print("[CatLLM] Continuing without JSON formatter fallback.")
        return None
    return {
        "model": fmt_model,
        "tokenizer": fmt_tokenizer,
        "device": fmt_device,
    }


def _load_embedding_state(category_descriptions):
    """Load the embedding model for similarity scores, or return None if unavailable."""
    try:
        from ._embeddings import ensure_embeddings_available, load_embedding_model

        if not ensure_embeddings_available():
            print("[CatLLM] Continuing without embedding scores.")
            return None
        model = load_embedding_model()
    except ImportError as e:
        print(f"[CatLLM] Embeddings unavailable: {e}")
        print("[CatLLM] Continuing without embedding scores.")
        return None
    return {
        "model": model,
        "category_descriptions": category_descriptions,
    }


def _build_tiebreaker_state(
    *,
    models,
    detected_type,
    batch_mode,
    embedding_state,
    consensus_threshold,
    min_centroid_size,
):
    """Return the embedding-tiebreaker config, or None when it cannot apply.

    The tiebreaker needs several models, text input and the non-batch path;
    each unmet condition prints why it was skipped, before any model load.
    An embedding model already loaded for ``embeddings=True`` is reused.
    """
    if len(models) == 1:
        print("[CatLLM] Embedding tiebreaker skipped — not applicable for single-model mode.")
        return None
    if detected_type in ("pdf", "image"):
        print(
            f"[CatLLM] Embedding tiebreaker skipped — not supported for {detected_type} input."
        )
        return None
    if batch_mode:
        print(
            "[CatLLM] WARNING: embedding_tiebreaker is not supported in batch_mode. "
            "The tiebreaker will be skipped for this run."
        )
        return None

    try:
        from ._embeddings import ensure_embeddings_available, load_embedding_model

        # Reuse embedding model if embeddings=True already loaded it
        if embedding_state is not None:
            tb_model = embedding_state["model"]
        elif ensure_embeddings_available():
            tb_model = load_embedding_model()
        else:
            print("[CatLLM] Continuing without embedding tiebreaker.")
            return None

        # Resolve threshold to numeric for the tiebreaker
        from .text_functions_ensemble import _resolve_consensus_threshold

        return {
            "model": tb_model,
            "threshold": _resolve_consensus_threshold(consensus_threshold),
            "min_centroid_size": min_centroid_size,
        }
    except ImportError as e:
        print(f"[CatLLM] Embedding tiebreaker unavailable: {e}")
        print("[CatLLM] Continuing without embedding tiebreaker.")
        return None


def _run_batch_mode(
    *,
    input_data,
    categories,
    models,
    examples,
    survey_question,
    chain_of_thought,
    context_prompt,
    step_back_prompt,
    creativity,
    thinking_budget,
    multi_label,
    use_json_schema,
    auto_download,
    progress_callback,
    consensus_threshold,
    fail_strategy,
    filename,
    save_directory,
    batch_poll_interval,
    batch_timeout,
):
    """Classify text input through provider batch APIs and return the result.

    Single-model runs use ``run_batch_classify`` (unsupported providers raise
    ValueError); multi-model runs use ``run_batch_ensemble_classify``. Input
    type and categories are assumed to have been validated by the caller.
    """
    from ._batch import UNSUPPORTED_BATCH_PROVIDERS, run_batch_classify
    from .text_functions_ensemble import prepare_json_schemas, prepare_model_configs

    if progress_callback is not None:
        print(
            "[CatLLM] WARNING: progress_callback is not available in batch_mode "
            "(no per-item progress until the job completes). Ignoring callback."
        )

    # Build prompt components (mirrors what classify_ensemble does)
    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
    survey_question_context = f"Context: {survey_question}." if survey_question else ""
    examples_text = "\n".join(
        f"Example {i}: {ex}" for i, ex in enumerate(examples, 1) if ex is not None
    )

    model_configs = prepare_model_configs(models, auto_download=auto_download)
    json_schemas = prepare_json_schemas(model_configs, categories, use_json_schema)
    items = input_data if isinstance(input_data, list) else list(input_data)

    def prompt_params_for(cfg):
        # Shared by both paths so a per-model creativity override is honored
        # consistently; otherwise fall back to the call-level setting.
        model_creativity = cfg.get("creativity")
        return {
            "categories_str": categories_str,
            "survey_question_context": survey_question_context,
            "examples_text": examples_text,
            "chain_of_thought": chain_of_thought,
            "context_prompt": context_prompt,
            "step_back_prompt": step_back_prompt,
            "stepback_insights": {},
            "json_schema": json_schemas[cfg["model"]],
            "creativity": model_creativity if model_creativity is not None else creativity,
            "thinking_budget": thinking_budget,
            "multi_label": multi_label,
        }

    if len(models) == 1:
        cfg = model_configs[0]
        if cfg["provider"] in UNSUPPORTED_BATCH_PROVIDERS:
            raise ValueError(
                f"batch_mode=True is not supported for provider '{cfg['provider']}'. "
                "Supported providers: openai, anthropic, google, mistral, xai."
            )
        return run_batch_classify(
            items=items,
            cfg=cfg,
            categories=categories,
            prompt_params=prompt_params_for(cfg),
            filename=filename,
            save_directory=save_directory,
            batch_poll_interval=batch_poll_interval,
            batch_timeout=batch_timeout,
            fail_strategy=fail_strategy,
        )

    # Ensemble batch path: one job per model, run concurrently
    print(
        "[CatLLM] NOTE: batch_mode=True with multiple models is experimental. "
        "Each model submits a separate batch job concurrently. Providers without "
        "a batch API (HuggingFace, Perplexity, Ollama) fall back to synchronous calls."
    )
    from ._batch import run_batch_ensemble_classify

    return run_batch_ensemble_classify(
        items=items,
        model_configs=model_configs,
        categories=categories,
        prompt_params_per_model={cfg["model"]: prompt_params_for(cfg) for cfg in model_configs},
        consensus_threshold=consensus_threshold,
        fail_strategy=fail_strategy,
        filename=filename,
        save_directory=save_directory,
        batch_poll_interval=batch_poll_interval,
        batch_timeout=batch_timeout,
    )


def _apply_embeddings(result, categories, embedding_state):
    """Add embedding similarity scores to a DataFrame result when enabled."""
    if embedding_state is None:
        return result
    from ._embeddings import apply_embedding_scores
    import pandas as _pd

    if isinstance(result, _pd.DataFrame):
        return apply_embedding_scores(
            result,
            categories,
            embedding_state["model"],
            embedding_state["category_descriptions"],
        )
    return result
|