cat-stack 0.1.0__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {cat_stack-0.1.0 → cat_stack-0.2.0}/PKG-INFO +4 -2
  2. {cat_stack-0.1.0 → cat_stack-0.2.0}/pyproject.toml +2 -1
  3. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/__about__.py +1 -1
  4. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/classify.py +15 -1
  5. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/extract.py +12 -2
  6. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/summarize.py +13 -1
  7. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/text_functions_ensemble.py +529 -24
  8. {cat_stack-0.1.0 → cat_stack-0.2.0}/.gitignore +0 -0
  9. {cat_stack-0.1.0 → cat_stack-0.2.0}/LICENSE +0 -0
  10. {cat_stack-0.1.0 → cat_stack-0.2.0}/README.md +0 -0
  11. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/__init__.py +0 -0
  12. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/_batch.py +0 -0
  13. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/_category_analysis.py +0 -0
  14. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/_chunked.py +0 -0
  15. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/_embeddings.py +0 -0
  16. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/_formatter.py +0 -0
  17. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/_providers.py +0 -0
  18. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/_tiebreaker.py +0 -0
  19. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/_utils.py +0 -0
  20. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/_web_fetch.py +0 -0
  21. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/calls/CoVe.py +0 -0
  22. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/calls/__init__.py +0 -0
  23. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/calls/all_calls.py +0 -0
  24. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/calls/image_CoVe.py +0 -0
  25. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/calls/image_stepback.py +0 -0
  26. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/calls/pdf_CoVe.py +0 -0
  27. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/calls/pdf_stepback.py +0 -0
  28. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/calls/stepback.py +0 -0
  29. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/calls/top_n.py +0 -0
  30. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/explore.py +0 -0
  31. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/image_functions.py +0 -0
  32. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/images/circle.png +0 -0
  33. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/images/cube.png +0 -0
  34. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/images/diamond.png +0 -0
  35. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/images/overlapping_pentagons.png +0 -0
  36. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/images/rectangles.png +0 -0
  37. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/model_reference_list.py +0 -0
  38. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/pdf_functions.py +0 -0
  39. {cat_stack-0.1.0 → cat_stack-0.2.0}/src/cat_stack/text_functions.py +0 -0
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cat-stack
3
- Version: 0.1.0
4
- Summary: Domain-agnostic text, image, and PDF classification engine powered by LLMs
3
+ Version: 0.2.0
4
+ Summary: Domain-agnostic text, image, PDF, and DOCX classification engine powered by LLMs
5
5
  Project-URL: Documentation, https://github.com/chrissoria/cat-stack#readme
6
6
  Project-URL: Issues, https://github.com/chrissoria/cat-stack/issues
7
7
  Project-URL: Source, https://github.com/chrissoria/cat-stack
@@ -25,6 +25,8 @@ Requires-Dist: pandas
25
25
  Requires-Dist: perplexityai
26
26
  Requires-Dist: requests
27
27
  Requires-Dist: tqdm
28
+ Provides-Extra: docx
29
+ Requires-Dist: python-docx>=1.0.0; extra == 'docx'
28
30
  Provides-Extra: embeddings
29
31
  Requires-Dist: sentence-transformers>=2.2.0; extra == 'embeddings'
30
32
  Provides-Extra: formatter
@@ -5,7 +5,7 @@ build-backend = "hatchling.build"
5
5
  [project]
6
6
  name = "cat-stack"
7
7
  dynamic = ["version"]
8
- description = "Domain-agnostic text, image, and PDF classification engine powered by LLMs"
8
+ description = "Domain-agnostic text, image, PDF, and DOCX classification engine powered by LLMs"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.8"
11
11
  license = "GPL-3.0-or-later"
@@ -35,6 +35,7 @@ dependencies = [
35
35
 
36
36
  [project.optional-dependencies]
37
37
  pdf = ["PyMuPDF>=1.23.0"]
38
+ docx = ["python-docx>=1.0.0"]
38
39
  formatter = ["torch>=2.0.0", "transformers>=4.40.0", "accelerate>=0.27.0"]
39
40
  embeddings = ["sentence-transformers>=2.2.0"]
40
41
 
@@ -1,7 +1,7 @@
1
1
  # SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
2
2
  #
3
3
  # SPDX-License-Identifier: GPL-3.0-or-later
4
- __version__ = "0.1.0"
4
+ __version__ = "0.2.0"
5
5
  __author__ = "Chris Soria"
6
6
  __email__ = "chrissoria@berkeley.edu"
7
7
  __title__ = "cat-stack"
@@ -48,6 +48,7 @@ def classify(
48
48
  description="",
49
49
  user_model="gpt-4o",
50
50
  mode="image",
51
+ input_mode=None,
51
52
  creativity=None,
52
53
  safety=False,
53
54
  chain_of_verification=False,
@@ -115,7 +116,14 @@ def classify(
115
116
  Kept for backward compatibility.
116
117
  description (str): Description of the input data context.
117
118
  user_model (str): Model name to use. Default "gpt-4o".
118
- mode (str): PDF processing mode:
119
+ input_mode (str): What you want the model to do with the input. Default None.
120
+ - None: Auto-select based on file type (text/docx→"text", image→"visual",
121
+ pdf→uses mode param or "visual")
122
+ - "text": Classify text content, regardless of source format. For images
123
+ and scanned PDFs, uses LLM-based OCR to extract text first.
124
+ - "visual": Classify visual features of images/rendered PDFs. Not
125
+ compatible with text or DOCX input.
126
+ mode (str): PDF processing mode (legacy, use input_mode instead):
119
127
  - "image" (default): Render pages as images
120
128
  - "text": Extract text only
121
129
  - "both": Send both image and extracted text
@@ -534,6 +542,11 @@ def classify(
534
542
  from .text_functions_ensemble import _detect_input_type
535
543
  detected_type = _detect_input_type(input_data)
536
544
  if detected_type in ("pdf", "image"):
545
+ if input_mode == "text":
546
+ raise ValueError(
547
+ "batch_mode=True does not support OCR (input_mode='text' on image/PDF input). "
548
+ "Set batch_mode=False to use OCR-based classification."
549
+ )
537
550
  raise ValueError(
538
551
  f"batch_mode=True only supports text input, but detected input type is '{detected_type}'. "
539
552
  "Set batch_mode=False for PDF/image classification."
@@ -678,5 +691,6 @@ def classify(
678
691
  multi_label=multi_label,
679
692
  categories_per_call=categories_per_call,
680
693
  embedding_tiebreaker_state=_embedding_tiebreaker_state,
694
+ input_mode=input_mode,
681
695
  )
682
696
  return _maybe_apply_embeddings(result)
@@ -40,7 +40,7 @@ from .pdf_functions import (
40
40
  def extract(
41
41
  input_data,
42
42
  api_key,
43
- input_type="text",
43
+ input_type="auto",
44
44
  survey_question="",
45
45
  description=None,
46
46
  max_categories=12,
@@ -59,6 +59,7 @@ def extract(
59
59
  progress_callback=None,
60
60
  chunk_delay: float = 0.0,
61
61
  auto_download: bool = False,
62
+ input_mode=None,
62
63
  ):
63
64
  """
64
65
  Unified category extraction function for text, image, and PDF inputs.
@@ -74,7 +75,8 @@ def extract(
74
75
  - For pdf: directory path, single file, or list of PDF paths
75
76
  api_key (str): API key for the model provider.
76
77
  input_type (str): Type of input data. Options:
77
- - "text" (default): Text responses
78
+ - "auto" (default): Auto-detect from file extensions
79
+ - "text": Text responses
78
80
  - "image": Image files
79
81
  - "pdf": PDF documents
80
82
  survey_question (str): The survey question or description of the text data.
@@ -142,6 +144,14 @@ def extract(
142
144
  """
143
145
  input_type = input_type.lower().rstrip('s') # Normalize: "texts" -> "text", "images" -> "image", "pdfs" -> "pdf"
144
146
 
147
+ # Auto-detect input type if set to "auto"
148
+ if input_type == "auto":
149
+ from .text_functions_ensemble import _detect_input_type
150
+ input_type = _detect_input_type(input_data)
151
+ # docx → text for extraction purposes
152
+ if input_type == "docx":
153
+ input_type = "text"
154
+
145
155
  # survey_question is the canonical name; description is kept for backward compatibility
146
156
  resolved_survey_question = survey_question if survey_question else (description or "")
147
157
 
@@ -36,6 +36,8 @@ def summarize(
36
36
  user_model: str = "gpt-4o",
37
37
  model_source: str = "auto",
38
38
  mode: str = "image",
39
+ input_mode: str = None,
40
+ input_type: str = "auto",
39
41
  pdf_dpi: int = 150,
40
42
  creativity: float = None,
41
43
  thinking_budget: int = 0,
@@ -79,7 +81,15 @@ def summarize(
79
81
  focus (str): What to focus on (e.g., "main arguments", "emotional content")
80
82
  user_model (str): Model to use (default "gpt-4o")
81
83
  model_source (str): Provider - "auto", "openai", "anthropic", "google", etc.
82
- mode (str): PDF processing mode (only used for PDF input):
84
+ input_mode (str): What you want the model to do with the input. Default None.
85
+ - None: Auto-select based on file type (text→"text", image→"visual",
86
+ pdf→uses mode param or "visual")
87
+ - "text": Summarize text content, regardless of source format. For images
88
+ and scanned PDFs, uses LLM-based OCR to extract text first.
89
+ - "visual": Summarize visual features of images/rendered PDFs.
90
+ input_type (str): File type filter. Default "auto" (auto-detect).
91
+ Options: "auto", "pdf", "image", "text", "docx"
92
+ mode (str): PDF processing mode (legacy, use input_mode instead):
83
93
  - "image" (default): Render pages as images
84
94
  - "text": Extract text only
85
95
  - "both": Send both image and extracted text
@@ -287,4 +297,6 @@ def summarize(
287
297
  max_workers=max_workers,
288
298
  parallel=parallel,
289
299
  auto_download=auto_download,
300
+ input_mode=input_mode,
301
+ input_type=input_type,
290
302
  )
@@ -176,49 +176,351 @@ _IMAGE_EXTENSIONS = {
176
176
  '.ico', '.psd', '.jfif', '.pjpeg', '.pjp', '.jpe'
177
177
  }
178
178
 
179
+ _DOCX_EXTENSIONS = {'.docx', '.doc'}
180
+
181
+
182
+ def _extract_docx_text(file_path: str) -> str:
183
+ """Extract text from a DOCX file using python-docx."""
184
+ try:
185
+ from docx import Document
186
+ except ImportError:
187
+ raise ImportError(
188
+ "The 'python-docx' package is required for DOCX support. "
189
+ "Install it with: pip install python-docx"
190
+ )
191
+ doc = Document(file_path)
192
+ return "\n\n".join(p.text for p in doc.paragraphs if p.text.strip())
193
+
194
+
195
+ def _convert_docx_to_text(input_data):
196
+ """Convert DOCX file paths to extracted text strings.
197
+
198
+ Accepts the same input formats as classify(): single path, list, Series, or directory.
199
+ Returns the same shape but with file paths replaced by extracted text.
200
+ """
201
+ def _convert_single(item):
202
+ if item is None or (isinstance(item, float) and pd.isna(item)):
203
+ return item
204
+ item_str = str(item)
205
+ ext = os.path.splitext(item_str)[1].lower()
206
+ if ext in _DOCX_EXTENSIONS and os.path.isfile(item_str):
207
+ return _extract_docx_text(item_str)
208
+ return item
209
+
210
+ # Single file path
211
+ if isinstance(input_data, (str, Path)):
212
+ input_str = str(input_data)
213
+ # Directory of DOCX files
214
+ if os.path.isdir(input_str):
215
+ texts = []
216
+ for f in sorted(os.listdir(input_str)):
217
+ f_path = os.path.join(input_str, f)
218
+ f_ext = os.path.splitext(f)[1].lower()
219
+ if f_ext in _DOCX_EXTENSIONS:
220
+ texts.append(_extract_docx_text(f_path))
221
+ return texts if texts else [input_str]
222
+ return _convert_single(input_str)
223
+
224
+ # List or Series
225
+ if isinstance(input_data, pd.Series):
226
+ return input_data.apply(_convert_single)
227
+
228
+ if hasattr(input_data, '__iter__'):
229
+ return [_convert_single(item) for item in input_data]
230
+
231
+ return input_data
232
+
233
+
234
+ def _ocr_extract_text(
235
+ cfg: dict,
236
+ image_data: dict = None,
237
+ page_data: dict = None,
238
+ max_retries: int = 3,
239
+ ) -> tuple:
240
+ """
241
+ Use an LLM to extract (OCR) text from an image or PDF page.
242
+
243
+ Sends a multimodal message asking the model to return only the raw text
244
+ visible in the document. No JSON schema is used — the response is plain text.
245
+
246
+ Args:
247
+ cfg: Model configuration dict (from prepare_model_configs)
248
+ image_data: Dict from _prepare_image_data (for images)
249
+ page_data: Dict from _prepare_page_data (for PDF pages)
250
+ max_retries: Max retry attempts
251
+
252
+ Returns:
253
+ (extracted_text, error) — error is None on success
254
+ """
255
+ ocr_prompt = (
256
+ "Extract all visible text from this document. "
257
+ "Return only the raw extracted text, preserving paragraph breaks. "
258
+ "Do not add any commentary, labels, or formatting."
259
+ )
260
+
261
+ provider = cfg["provider"]
262
+
263
+ # Build multimodal message based on source type
264
+ if image_data is not None:
265
+ encoded = image_data.get("encoded_image", "")
266
+ ext = image_data.get("extension", "png")
267
+
268
+ if provider == "anthropic":
269
+ content = [
270
+ {"type": "text", "text": ocr_prompt},
271
+ {
272
+ "type": "image",
273
+ "source": {
274
+ "type": "base64",
275
+ "media_type": f"image/{ext}",
276
+ "data": encoded,
277
+ },
278
+ },
279
+ ]
280
+ elif provider == "google":
281
+ content = [
282
+ {"type": "text", "text": ocr_prompt},
283
+ {"type": "inline_data", "mime_type": f"image/{ext}", "data": encoded},
284
+ ]
285
+ else:
286
+ encoded_url = f"data:image/{ext};base64,{encoded}"
287
+ content = [
288
+ {"type": "text", "text": ocr_prompt},
289
+ {"type": "image_url", "image_url": {"url": encoded_url, "detail": "high"}},
290
+ ]
291
+
292
+ elif page_data is not None:
293
+ # PDF page — prefer image_bytes, fall back to pdf_bytes
294
+ if page_data.get("image_bytes"):
295
+ img_b64 = _encode_bytes_to_base64(page_data["image_bytes"])
296
+ if provider == "anthropic":
297
+ content = [
298
+ {"type": "text", "text": ocr_prompt},
299
+ {
300
+ "type": "image",
301
+ "source": {
302
+ "type": "base64",
303
+ "media_type": "image/png",
304
+ "data": img_b64,
305
+ },
306
+ },
307
+ ]
308
+ elif provider == "google":
309
+ content = [
310
+ {"type": "text", "text": ocr_prompt},
311
+ {"type": "inline_data", "mime_type": "image/png", "data": img_b64},
312
+ ]
313
+ else:
314
+ encoded_url = f"data:image/png;base64,{img_b64}"
315
+ content = [
316
+ {"type": "text", "text": ocr_prompt},
317
+ {"type": "image_url", "image_url": {"url": encoded_url, "detail": "high"}},
318
+ ]
319
+ elif page_data.get("pdf_bytes"):
320
+ pdf_b64 = _encode_bytes_to_base64(page_data["pdf_bytes"])
321
+ if provider == "anthropic":
322
+ content = [
323
+ {"type": "text", "text": ocr_prompt},
324
+ {
325
+ "type": "document",
326
+ "source": {
327
+ "type": "base64",
328
+ "media_type": "application/pdf",
329
+ "data": pdf_b64,
330
+ },
331
+ },
332
+ ]
333
+ elif provider == "google":
334
+ content = [
335
+ {"type": "text", "text": ocr_prompt},
336
+ {"type": "inline_data", "mime_type": "application/pdf", "data": pdf_b64},
337
+ ]
338
+ else:
339
+ return ("", "Provider does not support native PDF for OCR; render as image first")
340
+ else:
341
+ return ("", "No image or PDF bytes available for OCR")
342
+ else:
343
+ return ("", "No image_data or page_data provided for OCR")
344
+
345
+ messages = [{"role": "user", "content": content}]
346
+
347
+ try:
348
+ if provider == "google":
349
+ # Google multimodal needs direct API call (handled later in classify_ensemble
350
+ # via _call_google_multimodal). Here we use the same approach but without
351
+ # JSON schema so we get plain text back.
352
+ import requests
353
+
354
+ model_name = cfg["model"]
355
+ url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent"
356
+ headers = {
357
+ "x-goog-api-key": cfg["api_key"],
358
+ "Content-Type": "application/json",
359
+ }
360
+
361
+ parts = []
362
+ for part in content:
363
+ if part.get("type") == "text":
364
+ parts.append({"text": part["text"]})
365
+ elif part.get("type") == "inline_data":
366
+ parts.append({
367
+ "inline_data": {
368
+ "mime_type": part["mime_type"],
369
+ "data": part["data"],
370
+ }
371
+ })
372
+
373
+ payload = {"contents": [{"parts": parts}]}
374
+
375
+ for attempt in range(max_retries):
376
+ try:
377
+ response = requests.post(url, headers=headers, json=payload, timeout=120)
378
+ response.raise_for_status()
379
+ result = response.json()
380
+ if "candidates" in result and result["candidates"]:
381
+ text = result["candidates"][0]["content"]["parts"][0]["text"]
382
+ return (text.strip(), None)
383
+ return ("", "No response from Google OCR")
384
+ except requests.exceptions.HTTPError as e:
385
+ if e.response.status_code in (429, 500, 502, 503, 504) and attempt < max_retries - 1:
386
+ time.sleep(2 * (2 ** attempt))
387
+ else:
388
+ return ("", f"Google OCR HTTP error: {e}")
389
+ except Exception as e:
390
+ if attempt < max_retries - 1:
391
+ time.sleep(2 * (2 ** attempt))
392
+ else:
393
+ return ("", f"Google OCR error: {e}")
394
+ return ("", "Google OCR max retries exceeded")
395
+
396
+ else:
397
+ client = UnifiedLLMClient(
398
+ provider=provider,
399
+ api_key=cfg["api_key"],
400
+ model=cfg["model"],
401
+ )
402
+ reply, error = client.complete(
403
+ messages=messages,
404
+ json_schema=None,
405
+ force_json=False,
406
+ max_retries=max_retries,
407
+ )
408
+ if error:
409
+ return ("", f"OCR error: {error}")
410
+ return (reply.strip() if reply else "", None)
411
+
412
+ except Exception as e:
413
+ return ("", f"OCR exception: {e}")
414
+
415
+
416
+ def _resolve_input_params(
417
+ input_mode,
418
+ input_type,
419
+ old_mode,
420
+ input_data,
421
+ ) -> tuple:
422
+ """
423
+ Resolve the new input_mode/input_type params, handling backward compat.
424
+
425
+ Args:
426
+ input_mode: "text", "visual", or None (auto based on detected type)
427
+ input_type: "auto", "pdf", "image", "docx", "text" — file type filter
428
+ old_mode: Legacy mode param ("image", "text", "both")
429
+ input_data: The raw input data (for auto-detection)
430
+
431
+ Returns:
432
+ (resolved_mode, file_type, warnings) where:
433
+ resolved_mode: "text" or "visual"
434
+ file_type: "text", "pdf", "image", "docx"
435
+ warnings: list of deprecation/info warning strings
436
+ """
437
+ warnings_list = []
438
+
439
+ # Step 1: Detect the file type
440
+ if input_type == "auto":
441
+ file_type = _detect_input_type(input_data)
442
+ else:
443
+ file_type = input_type.lower().rstrip("s")
444
+
445
+ # Step 2: Resolve input_mode
446
+ if input_mode is not None:
447
+ resolved_mode = input_mode.lower()
448
+ if resolved_mode not in ("text", "visual"):
449
+ raise ValueError(
450
+ f"input_mode must be 'text' or 'visual', got '{input_mode}'"
451
+ )
452
+ # Validate: visual mode on text/docx is an error
453
+ if resolved_mode == "visual" and file_type in ("text", "docx"):
454
+ raise ValueError(
455
+ f"input_mode='visual' is not compatible with {file_type} input. "
456
+ f"Visual mode requires image or PDF files."
457
+ )
458
+ # Warn if old mode is also set
459
+ if old_mode and old_mode != "image":
460
+ warnings_list.append(
461
+ f"[CatStack] Both input_mode='{input_mode}' and mode='{old_mode}' "
462
+ f"are set. Using input_mode='{input_mode}' (mode is deprecated)."
463
+ )
464
+ else:
465
+ # input_mode is None — backward compat defaults
466
+ if file_type in ("text", "docx"):
467
+ resolved_mode = "text"
468
+ elif file_type == "image":
469
+ resolved_mode = "visual" # preserve current behavior
470
+ elif file_type == "pdf":
471
+ # Map old mode param to new system
472
+ if old_mode == "text":
473
+ resolved_mode = "text"
474
+ else:
475
+ resolved_mode = "visual" # "image" and "both" → visual
476
+ else:
477
+ resolved_mode = "text"
478
+
479
+ return (resolved_mode, file_type, warnings_list)
480
+
179
481
 
180
482
  def _detect_input_type(input_data) -> str:
181
483
  """
182
- Detect if input is text strings, PDF files, or image files.
484
+ Detect if input is text strings, PDF files, image files, or DOCX files.
183
485
 
184
486
  Auto-detection logic:
185
487
  - If input ends in .pdf → PDF mode
488
+ - If input ends in .docx/.doc → DOCX mode (converted to text)
186
489
  - If input ends in image extension (.png, .jpg, etc.) → Image mode
187
- - If input is a directory → Check first file to determine PDF or Image mode
490
+ - If input is a directory → Check first file to determine mode
188
491
  - Otherwise → Text mode
189
492
 
190
493
  Args:
191
- input_data: Text strings, PDF paths, image paths, or directory path
494
+ input_data: Text strings, PDF paths, image paths, DOCX paths, or directory path
192
495
 
193
496
  Returns:
194
- 'text', 'pdf', or 'image'
497
+ 'text', 'pdf', 'image', or 'docx'
195
498
  """
196
499
  # Handle single string input
197
500
  if isinstance(input_data, (str, Path)):
198
501
  survey_str = str(input_data)
199
502
  ext = os.path.splitext(survey_str)[1].lower()
200
503
 
201
- # Check for PDF
202
504
  if ext == '.pdf':
203
505
  return 'pdf'
204
-
205
- # Check for image
506
+ if ext in _DOCX_EXTENSIONS:
507
+ return 'docx'
206
508
  if ext in _IMAGE_EXTENSIONS:
207
509
  return 'image'
208
510
 
209
- # Check if it's a directory (could contain PDFs or images)
511
+ # Check if it's a directory (could contain PDFs, images, or DOCX)
210
512
  if os.path.isdir(survey_str):
211
- # Check first file to determine type
212
513
  try:
213
514
  for f in sorted(os.listdir(survey_str)):
214
515
  f_ext = os.path.splitext(f)[1].lower()
215
516
  if f_ext == '.pdf':
216
517
  return 'pdf'
518
+ if f_ext in _DOCX_EXTENSIONS:
519
+ return 'docx'
217
520
  if f_ext in _IMAGE_EXTENSIONS:
218
521
  return 'image'
219
522
  except OSError:
220
523
  pass
221
- # Default to PDF for directories (backward compatibility)
222
524
  return 'pdf'
223
525
 
224
526
  return 'text'
@@ -231,9 +533,10 @@ def _detect_input_type(input_data) -> str:
231
533
  ext = os.path.splitext(item_str)[1].lower()
232
534
  if ext == '.pdf':
233
535
  return 'pdf'
536
+ if ext in _DOCX_EXTENSIONS:
537
+ return 'docx'
234
538
  if ext in _IMAGE_EXTENSIONS:
235
539
  return 'image'
236
- # First non-null item is text
237
540
  return 'text'
238
541
 
239
542
  return 'text'
@@ -1766,6 +2069,9 @@ def classify_ensemble(
1766
2069
  categories_per_call: int = None,
1767
2070
  # Embedding tiebreaker
1768
2071
  embedding_tiebreaker_state: dict = None,
2072
+ # New input_mode / input_type parameters
2073
+ input_mode: str = None,
2074
+ input_type: str = "auto",
1769
2075
  ):
1770
2076
  """
1771
2077
  Multi-class classification with support for text AND PDF inputs, single or multiple LLM models.
@@ -1972,15 +2278,38 @@ def classify_ensemble(
1972
2278
  print(f" - {cfg['model']} ({cfg['provider']}) -> column suffix: {cfg['sanitized_name']}")
1973
2279
 
1974
2280
  # =============================================================================
1975
- # DETECT INPUT TYPE: Text vs PDF vs Image
2281
+ # RESOLVE INPUT MODE AND FILE TYPE
1976
2282
  # =============================================================================
1977
- input_type = _detect_input_type(input_data)
1978
- print(f"\nInput type detected: {input_type.upper()}")
2283
+ resolved_mode, file_type, resolve_warnings = _resolve_input_params(
2284
+ input_mode=input_mode,
2285
+ input_type=input_type,
2286
+ old_mode=pdf_mode,
2287
+ input_data=input_data,
2288
+ )
2289
+ for w in resolve_warnings:
2290
+ print(w)
2291
+
2292
+ print(f"\nFile type detected: {file_type.upper()}")
2293
+ print(f"Input mode: {resolved_mode}")
2294
+
2295
+ is_visual_mode = (resolved_mode == "visual")
2296
+ needs_ocr = (resolved_mode == "text" and file_type in ("image", "pdf"))
2297
+
2298
+ # DOCX pre-processing: convert to text, then proceed as text mode
2299
+ if file_type == 'docx':
2300
+ print("Converting DOCX files to text...")
2301
+ input_data = _convert_docx_to_text(input_data)
2302
+ file_type = 'text'
2303
+ needs_ocr = False
2304
+ print(f"Converted to {len(input_data) if hasattr(input_data, '__len__') else 1} text item(s)")
2305
+
2306
+ # Guard: no OCR in batch mode (batch APIs don't support the two-call pattern)
2307
+ # (This is checked at the classify() level too, but guard here as well)
1979
2308
 
1980
2309
  # Initialize processing variables
1981
2310
  items_to_process = []
1982
- is_pdf_mode = (input_type == 'pdf')
1983
- is_image_mode = (input_type == 'image')
2311
+ is_pdf_mode = (file_type == 'pdf')
2312
+ is_image_mode = (file_type == 'image')
1984
2313
 
1985
2314
  # Build example JSON for visual modes (PDF/image)
1986
2315
  category_dict = {str(i+1): "0" for i in range(len(categories))}
@@ -2040,6 +2369,20 @@ def classify_ensemble(
2040
2369
  # =================================================================
2041
2370
  items_to_process = input_data
2042
2371
 
2372
+ # Select OCR model config: first model that supports multimodal (skip Ollama)
2373
+ _ocr_cfg = None
2374
+ if needs_ocr:
2375
+ _text_only_providers = {"ollama"}
2376
+ for cfg in model_configs:
2377
+ if cfg["provider"] not in _text_only_providers:
2378
+ _ocr_cfg = cfg
2379
+ break
2380
+ if _ocr_cfg is None:
2381
+ raise ValueError(
2382
+ "input_mode='text' on image/PDF input requires OCR, but no multimodal-capable "
2383
+ "model is available in the ensemble. Add a cloud provider model (OpenAI, Anthropic, Google, etc.)."
2384
+ )
2385
+
2043
2386
  # Auto-resolve parallel mode: sequential for all-local (Ollama), parallel otherwise
2044
2387
  if parallel is None:
2045
2388
  all_local = all(cfg["provider"] == "ollama" for cfg in model_configs)
@@ -2513,7 +2856,9 @@ Categorize text responses {cove_categorize}:
2513
2856
  total_calls = len(items_to_process) * len(model_configs)
2514
2857
 
2515
2858
  # Set progress description based on mode
2516
- if is_image_mode:
2859
+ if needs_ocr:
2860
+ progress_desc = "OCR + Classifying" + (" images" if is_image_mode else " PDF pages")
2861
+ elif is_image_mode:
2517
2862
  progress_desc = "Classifying images"
2518
2863
  elif is_pdf_mode:
2519
2864
  progress_desc = "Classifying PDF pages"
@@ -2541,6 +2886,61 @@ Categorize text responses {cove_categorize}:
2541
2886
  pdf_metadata = None
2542
2887
  image_metadata = None
2543
2888
 
2889
+ # =================================================================
2890
+ # OCR PRE-PROCESSING: Extract text from images/PDFs before classifying
2891
+ # =================================================================
2892
+ if needs_ocr and not (not is_pdf_mode and not is_image_mode):
2893
+ ocr_text = None
2894
+
2895
+ if is_image_mode and isinstance(item, tuple) and len(item) == 2:
2896
+ img_path, img_label = item
2897
+ img_data = _prepare_image_data(img_path, img_label)
2898
+ if img_data.get("error"):
2899
+ ocr_text = ""
2900
+ else:
2901
+ ocr_text, ocr_err = _ocr_extract_text(
2902
+ cfg=_ocr_cfg, image_data=img_data, max_retries=max_retries
2903
+ )
2904
+ if ocr_err:
2905
+ import sys
2906
+ sys.stderr.write(f"[CatStack] OCR failed for {img_label}: {ocr_err}\n")
2907
+ ocr_text = ""
2908
+
2909
+ elif is_pdf_mode and isinstance(item, tuple) and len(item) == 3:
2910
+ p_path, p_idx, p_label = item
2911
+ # Try PyMuPDF text extraction first
2912
+ text_content, text_valid, text_err = _extract_page_text(p_path, p_idx)
2913
+ if text_valid and text_content and len(text_content.strip()) > 20:
2914
+ ocr_text = text_content
2915
+ else:
2916
+ # No extractable text — render as image and OCR
2917
+ print(f"[CatStack] Page {p_label} has no extractable text. Using LLM-based OCR.")
2918
+ pg_data = _prepare_page_data(
2919
+ pdf_path=p_path,
2920
+ page_index=p_idx,
2921
+ page_label=p_label,
2922
+ pdf_mode="image",
2923
+ provider=_ocr_cfg["provider"],
2924
+ pdf_dpi=pdf_dpi,
2925
+ )
2926
+ if pg_data.get("error"):
2927
+ ocr_text = ""
2928
+ else:
2929
+ ocr_text, ocr_err = _ocr_extract_text(
2930
+ cfg=_ocr_cfg, page_data=pg_data, max_retries=max_retries
2931
+ )
2932
+ if ocr_err:
2933
+ import sys
2934
+ sys.stderr.write(f"[CatStack] OCR failed for {p_label}: {ocr_err}\n")
2935
+ ocr_text = ""
2936
+
2937
+ # Replace the item with OCR-extracted text for classification
2938
+ if ocr_text is not None:
2939
+ _pre_ocr_item = item # preserve original tuple for retry
2940
+ item = ocr_text
2941
+ # After OCR, this item is a plain text string so classify_single
2942
+ # will naturally route to the text classification path
2943
+
2544
2944
  # Check for NaN (text mode only)
2545
2945
  if not is_pdf_mode and not is_image_mode and pd.isna(item):
2546
2946
  # Handle NaN - mark as skipped, bypass classification entirely
@@ -3087,6 +3487,9 @@ def summarize_ensemble(
3087
3487
  max_workers: int = None,
3088
3488
  parallel: bool = None,
3089
3489
  auto_download: bool = False,
3490
+ # New input_mode / input_type parameters
3491
+ input_mode: str = None,
3492
+ input_type: str = "auto",
3090
3493
  ) -> pd.DataFrame:
3091
3494
  """
3092
3495
  Summarize text or PDF inputs using LLMs with optional multi-model ensemble.
@@ -3171,9 +3574,30 @@ def summarize_ensemble(
3171
3574
  if safety and filename is None:
3172
3575
  raise TypeError("filename is required when using safety=True.")
3173
3576
 
3174
- # Detect input type: Text vs PDF
3175
- input_type = _detect_input_type(input_data)
3176
- is_pdf_mode = (input_type == 'pdf')
3577
+ # Resolve input mode and file type
3578
+ resolved_mode, file_type, resolve_warnings = _resolve_input_params(
3579
+ input_mode=input_mode,
3580
+ input_type=input_type,
3581
+ old_mode=pdf_mode,
3582
+ input_data=input_data,
3583
+ )
3584
+ for w in resolve_warnings:
3585
+ print(w)
3586
+
3587
+ print(f"\nFile type detected: {file_type.upper()}")
3588
+ print(f"Input mode: {resolved_mode}")
3589
+
3590
+ needs_ocr = (resolved_mode == "text" and file_type in ("image", "pdf"))
3591
+ is_pdf_mode = (file_type == 'pdf')
3592
+
3593
+ # DOCX pre-processing
3594
+ if file_type == 'docx':
3595
+ print("Converting DOCX files to text...")
3596
+ input_data = _convert_docx_to_text(input_data)
3597
+ file_type = 'text'
3598
+ is_pdf_mode = False
3599
+ needs_ocr = False
3600
+ print(f"Converted to {len(input_data) if hasattr(input_data, '__len__') else 1} text item(s)")
3177
3601
 
3178
3602
  if is_pdf_mode:
3179
3603
  # Validate pdf_mode parameter
@@ -3181,7 +3605,6 @@ def summarize_ensemble(
3181
3605
  if pdf_mode not in {"image", "text", "both"}:
3182
3606
  raise ValueError(f"pdf_mode must be 'image', 'text', or 'both', got: {pdf_mode}")
3183
3607
 
3184
- print(f"\nInput type detected: PDF")
3185
3608
  print(f"PDF processing mode: {pdf_mode}")
3186
3609
 
3187
3610
  # Load PDF files
@@ -3198,6 +3621,17 @@ def summarize_ensemble(
3198
3621
 
3199
3622
  items_to_process = all_pages
3200
3623
  print(f"Total PDF pages to summarize: {len(items_to_process)}")
3624
+ elif file_type == 'image':
3625
+ # IMAGE MODE: Load images
3626
+ print(f"Loading images...")
3627
+ image_files = _load_image_files(input_data)
3628
+ if not image_files:
3629
+ raise ValueError("No images found in the provided input.")
3630
+ items_to_process = [
3631
+ (img_path, os.path.splitext(os.path.basename(img_path))[0])
3632
+ for img_path in image_files
3633
+ ]
3634
+ print(f"Total images to summarize: {len(items_to_process)}")
3201
3635
  else:
3202
3636
  # TEXT MODE: Normalize input to list
3203
3637
  print(f"\nInput type detected: TEXT")
@@ -3381,8 +3815,30 @@ def summarize_ensemble(
3381
3815
  error_msg = str(e)
3382
3816
  return (model_name, '{"summary": ""}', error_msg)
3383
3817
 
3818
+ # Select OCR model config if needed
3819
+ _ocr_cfg = None
3820
+ if needs_ocr:
3821
+ _text_only_providers = {"ollama"}
3822
+ for cfg in model_configs:
3823
+ if cfg["provider"] not in _text_only_providers:
3824
+ _ocr_cfg = cfg
3825
+ break
3826
+ if _ocr_cfg is None:
3827
+ raise ValueError(
3828
+ "input_mode='text' on image/PDF input requires OCR, but no multimodal-capable "
3829
+ "model is available. Add a cloud provider model (OpenAI, Anthropic, Google, etc.)."
3830
+ )
3831
+
3384
3832
  # Process all items
3385
- progress_desc = "Summarizing PDF pages" if is_pdf_mode else "Summarizing texts"
3833
+ is_image_mode = (file_type == 'image')
3834
+ if needs_ocr:
3835
+ progress_desc = "OCR + Summarizing" + (" images" if is_image_mode else " PDF pages")
3836
+ elif is_pdf_mode:
3837
+ progress_desc = "Summarizing PDF pages"
3838
+ elif is_image_mode:
3839
+ progress_desc = "Summarizing images"
3840
+ else:
3841
+ progress_desc = "Summarizing texts"
3386
3842
  print(f"\n{progress_desc}...")
3387
3843
 
3388
3844
  # Auto-resolve parallel mode: sequential for all-local (Ollama), parallel otherwise
@@ -3397,6 +3853,55 @@ def summarize_ensemble(
3397
3853
  total_items = len(items_to_process)
3398
3854
 
3399
3855
  for idx, item in enumerate(tqdm(items_to_process, desc=progress_desc)):
3856
+ original_item = item # preserve for metadata extraction
3857
+
3858
+ # OCR pre-processing: extract text from images/PDFs before summarizing
3859
+ if needs_ocr:
3860
+ ocr_text = None
3861
+
3862
+ if is_image_mode and isinstance(item, tuple) and len(item) == 2:
3863
+ img_path, img_label = item
3864
+ img_data = _prepare_image_data(img_path, img_label)
3865
+ if img_data.get("error"):
3866
+ ocr_text = ""
3867
+ else:
3868
+ ocr_text, ocr_err = _ocr_extract_text(
3869
+ cfg=_ocr_cfg, image_data=img_data, max_retries=max_retries
3870
+ )
3871
+ if ocr_err:
3872
+ import sys
3873
+ sys.stderr.write(f"[CatStack] OCR failed for {img_label}: {ocr_err}\n")
3874
+ ocr_text = ""
3875
+
3876
+ elif is_pdf_mode and isinstance(item, tuple) and len(item) == 3:
3877
+ p_path, p_idx, p_label = item
3878
+ text_content, text_valid, text_err = _extract_page_text(p_path, p_idx)
3879
+ if text_valid and text_content and len(text_content.strip()) > 20:
3880
+ ocr_text = text_content
3881
+ else:
3882
+ print(f"[CatStack] Page {p_label} has no extractable text. Using LLM-based OCR.")
3883
+ pg_data = _prepare_page_data(
3884
+ pdf_path=p_path,
3885
+ page_index=p_idx,
3886
+ page_label=p_label,
3887
+ pdf_mode="image",
3888
+ provider=_ocr_cfg["provider"],
3889
+ pdf_dpi=pdf_dpi,
3890
+ )
3891
+ if pg_data.get("error"):
3892
+ ocr_text = ""
3893
+ else:
3894
+ ocr_text, ocr_err = _ocr_extract_text(
3895
+ cfg=_ocr_cfg, page_data=pg_data, max_retries=max_retries
3896
+ )
3897
+ if ocr_err:
3898
+ import sys
3899
+ sys.stderr.write(f"[CatStack] OCR failed for {p_label}: {ocr_err}\n")
3900
+ ocr_text = ""
3901
+
3902
+ if ocr_text is not None:
3903
+ item = ocr_text
3904
+
3400
3905
  item_results = {}
3401
3906
  item_errors = {}
3402
3907
 
@@ -3422,10 +3927,10 @@ def summarize_ensemble(
3422
3927
  item_errors[model_name] = error
3423
3928
  failed_pairs.append((idx, model_name))
3424
3929
 
3425
- # Store results for this item
3930
+ # Store results for this item (use original_item to preserve metadata tuples)
3426
3931
  result_entry = {
3427
3932
  "idx": idx,
3428
- "input_data": item,
3933
+ "input_data": original_item,
3429
3934
  "model_results": item_results,
3430
3935
  "errors": item_errors,
3431
3936
  }
File without changes
File without changes
File without changes