cat-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cat_stack/classify.py ADDED
@@ -0,0 +1,682 @@
1
+ """
2
+ Classification functions for CatLLM.
3
+
4
+ This module provides unified classification for text, image, and PDF inputs,
5
+ supporting both single-model and multi-model (ensemble) classification.
6
+ """
7
+
8
+ import math
9
+ import warnings
10
+ from typing import Union, Callable
11
+
12
# Names exported by ``from cat_stack.classify import *``.
__all__ = [
    "classify",  # unified single-/multi-model entry point
    "classify_ensemble",  # multi-model (ensemble) classification
    # Legacy entry points, re-exported only so older callers keep working.
    "multi_class",
    "image_multi_class",
    "pdf_multi_class",
]
22
+
23
+ # Import provider infrastructure
24
+ from ._providers import (
25
+ UnifiedLLMClient,
26
+ detect_provider,
27
+ )
28
+
29
+ # Category analysis
30
+ from ._category_analysis import has_other_category, check_category_verbosity
31
+
32
+ # Import the implementation functions from existing modules
33
+ from .text_functions_ensemble import (
34
+ classify_ensemble,
35
+ )
36
+
37
+ # Import deprecated functions for backward compatibility
38
+ from .text_functions import multi_class
39
+ from .image_functions import image_multi_class
40
+ from .pdf_functions import pdf_multi_class
41
+
42
+
43
def _maybe_add_other_category(categories, add_other):
    """Return ``categories``, possibly extended with a catch-all ``"Other"``.

    Only explicit, non-empty category lists are considered. ``"auto"``, empty
    inputs and lists that ``has_other_category`` already accepts are returned
    unchanged.

    Args:
        categories: The user-supplied categories (list-like or ``"auto"``).
        add_other: ``"prompt"`` asks interactively (EOF / Ctrl-C counts as
            "no"). Any other truthy value adds ``"Other"`` silently. A falsy
            value disables the feature.

    Returns:
        The original ``categories`` object, or a new list with ``"Other"``
        appended.
    """
    if not add_other or not categories or categories == "auto":
        return categories
    if has_other_category(categories):
        return categories

    if add_other != "prompt":
        # add_other=True (or another truthy value): add without asking.
        categories = list(categories) + ["Other"]
        print(
            f"[CatLLM] Auto-added 'Other' catch-all category. "
            f"Categories are now: {categories} "
            f"(set add_other=False to disable)"
        )
        return categories

    print(
        "\n[CatLLM] It looks like your categories may not include a catch-all\n"
        " 'Other' option. Adding one can improve accuracy by giving the\n"
        " model an outlet for ambiguous responses instead of forcing them\n"
        " into ill-fitting categories.\n"
        " (If you already have a catch-all under a different name, choose 'n'.)\n"
    )
    try:
        answer = input(" Add 'Other' to your categories? (Y/n): ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        # Non-interactive session or aborted prompt: keep the original list.
        answer = "n"
    if answer in ("", "y", "yes"):
        categories = list(categories) + ["Other"]
        print(f" -> Categories are now: {categories}\n")
    else:
        print(" -> Keeping original categories.\n")
    return categories


def _report_category_verbosity(categories, models):
    """Print advice for categories that lack a description and/or examples.

    Runs a single ``check_category_verbosity`` call using the provider and API
    key of the first ``models`` entry. The check is skipped when that entry
    carries no API key. It is advisory only: any failure while running or
    interpreting the check prints a one-line notice and never blocks
    classification.
    """
    first_entry = models[0]
    check_key = first_entry[2] if len(first_entry) >= 3 else None
    check_source = first_entry[1] if len(first_entry) >= 2 else "auto"
    if not check_key:
        return

    try:
        report = check_category_verbosity(
            categories,
            api_key=check_key,
            model_source=check_source,
        )
        # Build all output lines first so a malformed report prints nothing
        # partial.
        problem_lines = []
        for entry in report:
            if entry["is_verbose"]:
                continue
            missing = []
            if not entry["has_description"]:
                missing.append("description")
            if not entry["has_examples"]:
                missing.append("examples")
            problem_lines.append(
                f' - "{entry["category"]}" (missing: {", ".join(missing)})'
            )
    except Exception as exc:  # advisory check: report, never block classification
        print(
            f"[CatLLM] Category verbosity check skipped "
            f"({type(exc).__name__}: {exc})."
        )
        return

    if not problem_lines:
        return

    print("\n[CatLLM] Category verbosity check (set check_verbosity=False to skip):")
    for line in problem_lines:
        print(line)
    print(
        "\n Verbose categories with descriptions and examples significantly\n"
        " improve classification accuracy over bare labels.\n"
        "\n"
        " Instead of:\n"
        ' "Positive"\n'
        " Consider:\n"
        ' "Positive: The response expresses satisfaction, approval, or\n'
        " happiness (e.g., 'I love this product', 'Great experience',\n"
        " 'Very pleased with the result')\"\n"
    )


def _validate_categories_per_call(categories_per_call, batch_mode):
    """Raise ``ValueError`` if ``categories_per_call`` is unusable.

    ``None`` is always accepted. Otherwise the value must be a positive
    ``int`` and ``batch_mode`` must be off. ``bool`` is rejected explicitly
    because it is a subclass of ``int`` (``True`` would silently mean 1).
    """
    if categories_per_call is None:
        return
    if (
        isinstance(categories_per_call, bool)
        or not isinstance(categories_per_call, int)
        or categories_per_call < 1
    ):
        raise ValueError(
            f"categories_per_call must be a positive integer, got {categories_per_call!r}"
        )
    if batch_mode:
        raise ValueError(
            "categories_per_call is not supported with batch_mode=True. "
            "Set batch_mode=False to use categories_per_call."
        )


def _resolve_categories_per_call(categories, categories_per_call):
    """Return the effective chunk size for an explicit category list.

    Returns ``None`` (a single call) when every category fits in one chunk.
    Otherwise it prints how many chunks will be used and returns the size
    unchanged. ``"auto"`` / empty categories are passed through untouched.
    """
    if categories_per_call is None or not categories or categories == "auto":
        return categories_per_call
    if categories_per_call >= len(categories):
        return None  # no-op: all categories fit in one call
    num_chunks = math.ceil(len(categories) / categories_per_call)
    print(
        f"[CatLLM] categories_per_call={categories_per_call}: "
        f"splitting {len(categories)} categories into {num_chunks} chunks"
    )
    return categories_per_call


def _collect_strategy_warnings(
    chain_of_verification, n_examples, thinking_budget, chain_of_thought, step_back_prompt
):
    """Return advisory messages for the enabled prompting strategies.

    The wording summarizes empirical findings from Soria et al. (2026), which
    compared prompting strategies across 4 representative models and 4 survey
    tasks. Returns an empty list when no such strategy is enabled.
    """
    messages = []
    if chain_of_verification:
        messages.append(
            "[CatLLM] WARNING: chain_of_verification=True is enabled.\n"
            " Empirical evidence shows CoVe DEGRADES accuracy by ~2 pp and\n"
            " sensitivity by up to 12 pp for structured classification tasks.\n"
            " The verification step causes models to retract correct classifications.\n"
            " Cost: ~4x API calls per response.\n"
            " This feature is provided for research purposes only — it is not\n"
            " recommended for improving classification accuracy."
        )
    if n_examples > 0:
        messages.append(
            f"[CatLLM] NOTE: {n_examples} few-shot example(s) provided.\n"
            " Empirical evidence shows few-shot examples DEGRADE accuracy by\n"
            " ~1.1-1.2 pp on average. Examples encourage over-classification\n"
            " (sensitivity up, but precision drops ~2-3 pp), amplifying false\n"
            " positives. This feature is provided for research purposes — for\n"
            " best results, use verbose category definitions instead."
        )
    if thinking_budget and thinking_budget > 0:
        messages.append(
            f"[CatLLM] NOTE: thinking_budget={thinking_budget} is enabled.\n"
            " Empirical evidence shows reasoning/thinking modes produce negligible\n"
            " accuracy gains (<1 pp) for classification tasks, while significantly\n"
            " increasing latency, token usage, and failure rates (up to 40% timeouts\n"
            " observed for some models). Consider thinking_budget=0 unless you are\n"
            " specifically researching reasoning effects."
        )
    if chain_of_thought:
        messages.append(
            "[CatLLM] NOTE: chain_of_thought=True is enabled.\n"
            " Empirical evidence shows CoT has no measurable effect on structured\n"
            " classification accuracy (~0 pp change). When categories are well-defined\n"
            " with verbose descriptions, explicit reasoning steps add no value.\n"
            " This won't hurt results, but it won't help either."
        )
    if step_back_prompt:
        messages.append(
            "[CatLLM] NOTE: step_back_prompt=True is enabled.\n"
            " Empirical evidence shows step-back prompting produces small, inconsistent\n"
            " gains (+0.6 pp average) and actually degrades top-tier model performance.\n"
            " Cost: ~2x API calls per response."
        )
    return messages


def _load_formatter_state():
    """Load the local JSON-repair model, or return ``None`` if unavailable.

    Returns a dict with ``"model"``, ``"tokenizer"`` and ``"device"`` keys
    (the three values returned by ``load_formatter``). Availability problems
    are reported with a printed message, not raised.
    """
    try:
        from ._formatter import ensure_formatter_available, load_formatter

        if not ensure_formatter_available():
            print("[CatLLM] Continuing without JSON formatter fallback.")
            return None
        fmt_model, fmt_tokenizer, fmt_device = load_formatter()
    except ImportError as e:
        print(f"[CatLLM] JSON formatter unavailable: {e}")
        print("[CatLLM] Continuing without JSON formatter fallback.")
        return None
    return {"model": fmt_model, "tokenizer": fmt_tokenizer, "device": fmt_device}


def _load_embedding_state(category_descriptions):
    """Load the sentence-embedding model, or return ``None`` if unavailable.

    Returns a dict with the loaded ``"model"`` and the caller's
    ``"category_descriptions"``. Availability problems are printed, not raised.
    """
    try:
        from ._embeddings import ensure_embeddings_available, load_embedding_model

        if not ensure_embeddings_available():
            print("[CatLLM] Continuing without embedding scores.")
            return None
        embedding_model = load_embedding_model()
    except ImportError as e:
        print(f"[CatLLM] Embeddings unavailable: {e}")
        print("[CatLLM] Continuing without embedding scores.")
        return None
    return {"model": embedding_model, "category_descriptions": category_descriptions}


def _setup_embedding_tiebreaker(
    models, detected_type, batch_mode, embedding_state, consensus_threshold, min_centroid_size
):
    """Build the embedding-tiebreaker state, or return ``None`` when not applicable.

    The tiebreaker is skipped (with a printed message) for single-model runs,
    for PDF/image input and in batch mode. When ``embedding_state`` is given,
    its already-loaded model is reused instead of loading a second copy.
    """
    if len(models) == 1:
        print("[CatLLM] Embedding tiebreaker skipped — not applicable for single-model mode.")
        return None
    if detected_type in ("pdf", "image"):
        print(
            f"[CatLLM] Embedding tiebreaker skipped — not supported for {detected_type} input."
        )
        return None
    if batch_mode:
        # Checked before loading anything so batch runs don't pay for a model
        # they would discard.
        print(
            "[CatLLM] WARNING: embedding_tiebreaker is not supported in batch_mode. "
            "The tiebreaker will be skipped for this run."
        )
        return None

    try:
        from ._embeddings import ensure_embeddings_available, load_embedding_model
        from .text_functions_ensemble import _resolve_consensus_threshold

        if embedding_state is not None:
            tb_model = embedding_state["model"]
        elif ensure_embeddings_available():
            tb_model = load_embedding_model()
        else:
            print("[CatLLM] Continuing without embedding tiebreaker.")
            return None
    except ImportError as e:
        print(f"[CatLLM] Embedding tiebreaker unavailable: {e}")
        print("[CatLLM] Continuing without embedding tiebreaker.")
        return None

    return {
        "model": tb_model,
        # The tiebreaker needs the numeric form of "unanimous"/"majority"/...
        "threshold": _resolve_consensus_threshold(consensus_threshold),
        "min_centroid_size": min_centroid_size,
    }


def _apply_embeddings_if_enabled(result, categories, embedding_state):
    """Add embedding similarity scores to ``result`` when embeddings are enabled.

    Results that are not a pandas DataFrame, or runs without an embedding
    state, are returned unchanged.

    NOTE(review): when ``classify`` is called with ``categories="auto"``, the
    literal string ``"auto"`` reaches ``apply_embedding_scores`` here. Confirm
    that it handles this case.
    """
    if embedding_state is None:
        return result
    import pandas as _pd

    if not isinstance(result, _pd.DataFrame):
        return result
    from ._embeddings import apply_embedding_scores

    return apply_embedding_scores(
        result,
        categories,
        embedding_state["model"],
        embedding_state["category_descriptions"],
    )


def classify(
    input_data,
    categories,
    api_key=None,
    input_type="text",
    description="",
    user_model="gpt-4o",
    mode="image",
    creativity=None,
    safety=False,
    chain_of_verification=False,
    chain_of_thought=False,
    step_back_prompt=False,
    context_prompt=False,
    thinking_budget=0,
    example1=None,
    example2=None,
    example3=None,
    example4=None,
    example5=None,
    example6=None,
    filename=None,
    save_directory=None,
    model_source="auto",
    max_categories=12,
    categories_per_chunk=10,
    divisions=10,
    research_question=None,
    progress_callback: Union[Callable, None] = None,
    # Batch mode parameters
    batch_mode: bool = False,
    batch_poll_interval: float = 30.0,
    batch_timeout: float = 86400.0,
    # Multi-model parameters
    models=None,
    consensus_threshold: Union[str, float] = "unanimous",
    # Parameters previously only on classify_ensemble
    survey_question: str = "",
    use_json_schema: bool = True,
    max_workers: Union[int, None] = None,
    parallel: Union[bool, None] = None,
    fail_strategy: str = "partial",
    max_retries: int = 5,
    batch_retries: int = 2,
    retry_delay: float = 1.0,
    row_delay: float = 0.0,
    pdf_dpi: int = 150,
    auto_download: bool = False,
    add_other="prompt",
    check_verbosity: bool = True,
    json_formatter: bool = False,
    embeddings: bool = False,
    category_descriptions: Union[dict, None] = None,
    embedding_tiebreaker: bool = False,
    min_centroid_size: int = 3,
    multi_label: bool = True,
    categories_per_call: Union[int, None] = None,
):
    """
    Unified classification function for text, image, and PDF inputs.

    Supports single-model and multi-model (ensemble) classification. Input type
    is auto-detected from the data (text strings, image paths, or PDF paths).

    Args:
        input_data: The data to classify. Can be:
            - For text: list of text responses or pandas Series
            - For image: directory path or list of image file paths
            - For pdf: directory path or list of PDF file paths
        categories (list): List of category names for classification, or
            "auto" to let classify_ensemble discover them.
        api_key (str): API key for the model provider (single-model mode).
        input_type (str): DEPRECATED - input type is now auto-detected.
            Kept for backward compatibility.
        description (str): Description of the input data context.
        user_model (str): Model name to use. Default "gpt-4o".
        mode (str): PDF processing mode. Unknown values fall back to "image":
            - "image" (default): Render pages as images
            - "text": Extract text only
            - "both": Send both image and extracted text
        creativity (float): Temperature setting. None uses model default.
        safety (bool): If True, saves progress after each item.
        chain_of_verification (bool): Enable Chain of Verification for accuracy.
        chain_of_thought (bool): Enable step-by-step reasoning. Default False.
        step_back_prompt (bool): Enable step-back prompting.
        context_prompt (bool): Add expert context to prompts.
        thinking_budget (int): Controls reasoning behavior per provider:
            Google: token budget for extended thinking (0=off, >0=on).
            OpenAI: maps to reasoning_effort (0="minimal", >0="high").
            Anthropic: enables extended thinking (0=off, >0=on, min 1024).
        example1-6 (str): Example categorizations for few-shot learning.
        filename (str): Output filename for CSV.
        save_directory (str): Directory to save results.
        model_source (str): Provider - "auto", "openai", "anthropic", "google",
            "mistral", "perplexity", "huggingface", "xai".
        progress_callback: Optional callback for progress updates.
        batch_mode (bool): If True, use async batch API (50% cost savings, higher
            rate limits).
            Supported providers: openai, anthropic, google, mistral, xai.
            Not supported: huggingface, perplexity, ollama.
            Ensemble mode: supported. Each model submits its own batch job
            concurrently. Providers without batch API (HuggingFace, Perplexity,
            Ollama) fall back to synchronous calls and are merged in with the
            batch results.
            Requires an explicit category list (categories="auto" is not
            supported). Incompatible with: PDF/image input, progress_callback,
            categories_per_call, embedding_tiebreaker, json_formatter.
        batch_poll_interval (float): Seconds between batch job status checks.
            Default 30.
        batch_timeout (float): Max seconds to wait for batch completion.
            Default 86400 (24h).
        models (list): For multi-model mode, list of (model, provider, api_key)
            tuples. If provided, overrides user_model/api_key/model_source.
            Must not be empty.
        consensus_threshold (str or float): For multi-model mode, agreement
            threshold.
            - "unanimous": 100% agreement (default — empirically produces
              the highest accuracy by aggressively eliminating false positives)
            - "majority": 50% agreement
            - "two-thirds": 67% agreement
            - float: Custom threshold between 0 and 1
        survey_question (str): The survey question (used when categories="auto").
        use_json_schema (bool): Use JSON schema for structured output. Default True.
        max_workers (int): Max parallel workers for API calls. None = auto.
        parallel (bool): Controls concurrent vs sequential model execution.
            - None (default): auto-detect. Sequential for local models (Ollama),
              parallel for cloud providers.
            - True: force parallel execution.
            - False: force sequential execution.
            Sequential mode is useful for resource-constrained environments
            (e.g., Ollama on limited hardware) or debugging.
        fail_strategy (str): How to handle failures - "partial" (default) or "strict".
        max_retries (int): Max retries per API call. Default 5.
        batch_retries (int): Max retries for batch-level failures. Default 2.
        retry_delay (float): Delay between retries in seconds. Default 1.0.
        row_delay (float): Delay in seconds between processing each row. Useful
            when multiple models share the same API provider/key to avoid rate
            limits. Default 0.0 (no delay).
        pdf_dpi (int): DPI for PDF page rendering. Default 150.
        auto_download (bool): Auto-download Ollama models. Default False.
        add_other (str or bool): Controls auto-addition of an "Other" catch-all
            category when none is detected. An "Other" category improves accuracy
            by preventing the model from forcing ambiguous responses into
            ill-fitting categories.
            - "prompt" (default): Ask the user to accept or reject the suggestion.
            - True: Silently add "Other" without prompting.
            - False: Never add "Other".
        check_verbosity (bool): Check whether each category has a description
            and examples (1 API call). Verbose categories with descriptions and
            examples significantly improve classification accuracy over bare
            labels. If the check itself fails, a one-line notice is printed and
            classification continues. Default True. Set to False to skip.
        json_formatter (bool): If True, use a local fine-tuned model to fix
            malformed JSON output from classification LLMs before marking
            responses as failed. The formatter runs only when extract_json()
            produces invalid output — zero cost on the happy path. On first
            use, the model (~1GB) is downloaded from HuggingFace Hub. Not used
            (and not loaded) in batch_mode.
            Requires: pip install cat-llm[formatter]. Default False.
        embeddings (bool): If True, add embedding-based similarity scores
            alongside binary 0/1 classifications. Uses a local sentence-
            transformer model (BAAI/bge-small-en-v1.5, ~130MB) to compute
            cosine similarity between each input text and each category.
            Scores are independent per (text, category) pair — no softmax.
            On first use, the model is downloaded from HuggingFace Hub.
            Only works with text input (skipped for PDF/image).
            Requires: pip install cat-llm[embeddings]. Default False.
        category_descriptions (dict): Optional dict mapping category names
            to richer text descriptions for embedding similarity. E.g.,
            {"Past_Support": "References to help received from family"}.
            Only used when embeddings=True.
        embedding_tiebreaker (bool): If True, use embedding centroids to
            resolve true ties in ensemble consensus. Builds per-category
            centroids from unanimously-agreed rows and compares tied texts
            to those centroids. Only applies to multi-model ensemble mode
            with text input (not batch mode).
            Requires: pip install cat-llm[embeddings]. Default False.
        min_centroid_size (int): Minimum number of unanimously-agreed rows
            needed to build a centroid for a category. Categories with fewer
            confident rows fall back to vote-based consensus. Default 3.
        multi_label (bool): If True (default), allow multiple categories per
            input (multi-label classification). If False, the prompt instructs
            the model to pick the single best category (single-label mode).
            The output format is unchanged — still one 0/1 column per category,
            but exactly one column will be "1" per row in single-label mode.
        categories_per_call (int): Maximum number of categories to send per
            LLM call. When set, the category list is split into chunks of this
            size, each chunk gets its own LLM call with local 1..N numbering,
            and results are merged back into global numbering. This reduces
            prompt complexity per call and can improve accuracy for large
            category sets (e.g., 20+). Must be a positive int (booleans are
            rejected). Default None (all categories in one call).
            Not supported with batch_mode=True.

    Returns:
        pd.DataFrame: Results with classification columns.

    Raises:
        ValueError: If ``models`` is an empty sequence; if
            ``categories_per_call`` is not a positive integer or is combined
            with ``batch_mode=True``; if ``batch_mode=True`` is used with
            ``categories="auto"`` (or no categories) or with PDF/image input;
            or if a single batch-mode model uses a provider without a batch API.

    Examples:
        >>> import cat_stack as cat
        >>>
        >>> # Single model classification
        >>> results = cat.classify(
        ...     input_data=df['responses'],
        ...     categories=["Positive", "Negative", "Neutral"],
        ...     description="Customer feedback survey",
        ...     api_key="your-api-key"
        ... )
        >>>
        >>> # Multi-model ensemble
        >>> results = cat.classify(
        ...     input_data=df['responses'],
        ...     categories=["Positive", "Negative"],
        ...     models=[
        ...         ("gpt-4o", "openai", "sk-..."),
        ...         ("claude-sonnet-4-5-20250929", "anthropic", "sk-ant-..."),
        ...     ],
        ...     consensus_threshold="unanimous",  # or "majority", "two-thirds", or 0.75
        ... )
    """
    # =========================================================================
    # Normalize models and validate options. Everything here runs before any
    # interactive prompt, API call or model download, so bad arguments fail
    # fast.
    # =========================================================================
    if models is None:
        # Single-model mode: build the models list from the individual params.
        models = [(user_model, model_source, api_key)]
    elif len(models) == 0:
        raise ValueError(
            "models must contain at least one (model, provider, api_key) entry."
        )

    _validate_categories_per_call(categories_per_call, batch_mode)

    if batch_mode and (not categories or categories == "auto"):
        # The batch path numbers `categories` directly; with "auto" it would
        # enumerate the characters of the string.
        raise ValueError(
            "batch_mode=True requires an explicit list of categories "
            f"(got {categories!r}). Automatic category discovery "
            "(categories='auto') is only available with batch_mode=False."
        )

    # The input type is only needed by the batch / embedding guards; detect once.
    detected_type = None
    if batch_mode or embeddings or embedding_tiebreaker:
        from .text_functions_ensemble import _detect_input_type

        detected_type = _detect_input_type(input_data)

    if batch_mode and detected_type in ("pdf", "image"):
        raise ValueError(
            f"batch_mode=True only supports text input, but detected input type is '{detected_type}'. "
            "Set batch_mode=False for PDF/image classification."
        )

    # =========================================================================
    # Category preparation
    # =========================================================================
    categories = _maybe_add_other_category(categories, add_other)

    if check_verbosity and categories and categories != "auto":
        _report_category_verbosity(categories, models)

    categories_per_call = _resolve_categories_per_call(categories, categories_per_call)

    # =========================================================================
    # Evidence-based warnings for prompting strategies
    # =========================================================================
    examples = [example1, example2, example3, example4, example5, example6]
    n_examples = sum(ex is not None for ex in examples)
    strategy_warnings = _collect_strategy_warnings(
        chain_of_verification, n_examples, thinking_budget, chain_of_thought, step_back_prompt
    )
    if strategy_warnings:
        print()
        print("\n\n".join(strategy_warnings))
        print()

    # =========================================================================
    # Optional local models (JSON formatter, embeddings, tiebreaker)
    # =========================================================================
    formatter_state = None
    if json_formatter:
        if batch_mode:
            # The batch path never receives the formatter; avoid a ~1GB load.
            print("[CatLLM] NOTE: json_formatter is not used in batch_mode; skipping formatter load.")
        else:
            formatter_state = _load_formatter_state()

    embedding_state = None
    if embeddings:
        if detected_type in ("pdf", "image"):
            print(
                f"[CatLLM] Embedding scores skipped — not supported for {detected_type} input."
            )
        else:
            embedding_state = _load_embedding_state(category_descriptions)

    tiebreaker_state = None
    if embedding_tiebreaker:
        tiebreaker_state = _setup_embedding_tiebreaker(
            models,
            detected_type,
            batch_mode,
            embedding_state,
            consensus_threshold,
            min_centroid_size,
        )

    # =========================================================================
    # Batch mode — bypass classify_ensemble entirely
    # =========================================================================
    if batch_mode:
        from ._batch import UNSUPPORTED_BATCH_PROVIDERS, run_batch_classify
        from .text_functions_ensemble import prepare_json_schemas, prepare_model_configs

        if progress_callback is not None:
            print(
                "[CatLLM] WARNING: progress_callback is not available in batch_mode "
                "(no per-item progress until the job completes). Ignoring callback."
            )

        # Prompt components (mirrors what classify_ensemble does).
        categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
        survey_question_context = f"Context: {survey_question}." if survey_question else ""
        examples_text = "\n".join(
            f"Example {i}: {ex}" for i, ex in enumerate(examples, 1) if ex is not None
        )

        model_configs = prepare_model_configs(models, auto_download=auto_download)
        json_schemas = prepare_json_schemas(model_configs, categories, use_json_schema)
        items = input_data if isinstance(input_data, list) else list(input_data)

        def _prompt_params(cfg):
            # A per-model creativity wins over the global one. This rule is
            # shared by the single-model and ensemble batch paths.
            model_creativity = cfg.get("creativity")
            return {
                "categories_str": categories_str,
                "survey_question_context": survey_question_context,
                "examples_text": examples_text,
                "chain_of_thought": chain_of_thought,
                "context_prompt": context_prompt,
                "step_back_prompt": step_back_prompt,
                "stepback_insights": {},
                "json_schema": json_schemas[cfg["model"]],
                "creativity": model_creativity if model_creativity is not None else creativity,
                "thinking_budget": thinking_budget,
                "multi_label": multi_label,
            }

        if len(models) == 1:
            cfg = model_configs[0]
            if cfg["provider"] in UNSUPPORTED_BATCH_PROVIDERS:
                raise ValueError(
                    f"batch_mode=True is not supported for provider '{cfg['provider']}'. "
                    "Supported providers: openai, anthropic, google, mistral, xai."
                )
            result = run_batch_classify(
                items=items,
                cfg=cfg,
                categories=categories,
                prompt_params=_prompt_params(cfg),
                filename=filename,
                save_directory=save_directory,
                batch_poll_interval=batch_poll_interval,
                batch_timeout=batch_timeout,
                fail_strategy=fail_strategy,
            )
            return _apply_embeddings_if_enabled(result, categories, embedding_state)

        # Ensemble batch path: one job per model, run concurrently.
        print(
            "[CatLLM] NOTE: batch_mode=True with multiple models is experimental. "
            "Each model submits a separate batch job concurrently. Providers without "
            "a batch API (HuggingFace, Perplexity, Ollama) fall back to synchronous calls."
        )
        from ._batch import run_batch_ensemble_classify

        # NOTE(review): keyed by model name, so two entries sharing a model name
        # (e.g. different temperatures) collide and the last one wins. Confirm
        # this matches run_batch_ensemble_classify's expectations.
        prompt_params_per_model = {cfg["model"]: _prompt_params(cfg) for cfg in model_configs}
        result = run_batch_ensemble_classify(
            items=items,
            model_configs=model_configs,
            categories=categories,
            prompt_params_per_model=prompt_params_per_model,
            consensus_threshold=consensus_threshold,
            fail_strategy=fail_strategy,
            filename=filename,
            save_directory=save_directory,
            batch_poll_interval=batch_poll_interval,
            batch_timeout=batch_timeout,
        )
        return _apply_embeddings_if_enabled(result, categories, embedding_state)

    # =========================================================================
    # Synchronous path (single model or ensemble)
    # =========================================================================
    pdf_mode = mode if mode in ("image", "text", "both") else "image"

    result = classify_ensemble(
        input_data=input_data,
        categories=categories,
        models=models,
        input_description=description,
        survey_question=survey_question,
        pdf_mode=pdf_mode,
        pdf_dpi=pdf_dpi,
        creativity=creativity,
        safety=safety,
        chain_of_thought=chain_of_thought,
        chain_of_verification=chain_of_verification,
        step_back_prompt=step_back_prompt,
        context_prompt=context_prompt,
        thinking_budget=thinking_budget,
        use_json_schema=use_json_schema,
        max_workers=max_workers,
        parallel=parallel,
        fail_strategy=fail_strategy,
        max_retries=max_retries,
        batch_retries=batch_retries,
        retry_delay=retry_delay,
        row_delay=row_delay,
        auto_download=auto_download,
        example1=example1,
        example2=example2,
        example3=example3,
        example4=example4,
        example5=example5,
        example6=example6,
        consensus_threshold=consensus_threshold,
        max_categories=max_categories,
        categories_per_chunk=categories_per_chunk,
        divisions=divisions,
        research_question=research_question,
        filename=filename,
        save_directory=save_directory,
        progress_callback=progress_callback,
        formatter_state=formatter_state,
        multi_label=multi_label,
        categories_per_call=categories_per_call,
        embedding_tiebreaker_state=tiebreaker_state,
    )
    return _apply_embeddings_if_enabled(result, categories, embedding_state)