natural-pdf 0.1.33__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff shows the changes between two publicly available versions of the package, as released to and published in one of the supported public registries. It is provided for informational purposes only.
Files changed (77)
  1. natural_pdf/analyzers/__init__.py +2 -2
  2. natural_pdf/analyzers/guides.py +670 -595
  3. natural_pdf/analyzers/layout/base.py +53 -6
  4. natural_pdf/analyzers/layout/layout_analyzer.py +3 -1
  5. natural_pdf/analyzers/layout/layout_manager.py +18 -14
  6. natural_pdf/analyzers/layout/layout_options.py +1 -0
  7. natural_pdf/analyzers/layout/paddle.py +102 -64
  8. natural_pdf/analyzers/layout/table_structure_utils.py +3 -1
  9. natural_pdf/analyzers/layout/yolo.py +2 -6
  10. natural_pdf/analyzers/shape_detection_mixin.py +15 -6
  11. natural_pdf/classification/manager.py +92 -77
  12. natural_pdf/classification/mixin.py +49 -5
  13. natural_pdf/classification/results.py +1 -1
  14. natural_pdf/cli.py +7 -3
  15. natural_pdf/collections/pdf_collection.py +96 -101
  16. natural_pdf/core/element_manager.py +131 -45
  17. natural_pdf/core/highlighting_service.py +5 -6
  18. natural_pdf/core/page.py +113 -22
  19. natural_pdf/core/pdf.py +477 -75
  20. natural_pdf/describe/__init__.py +18 -12
  21. natural_pdf/describe/base.py +179 -172
  22. natural_pdf/describe/elements.py +155 -155
  23. natural_pdf/describe/mixin.py +27 -19
  24. natural_pdf/describe/summary.py +44 -55
  25. natural_pdf/elements/base.py +134 -18
  26. natural_pdf/elements/collections.py +90 -18
  27. natural_pdf/elements/image.py +2 -1
  28. natural_pdf/elements/line.py +0 -31
  29. natural_pdf/elements/rect.py +0 -14
  30. natural_pdf/elements/region.py +222 -108
  31. natural_pdf/elements/text.py +18 -12
  32. natural_pdf/exporters/__init__.py +4 -1
  33. natural_pdf/exporters/original_pdf.py +12 -4
  34. natural_pdf/extraction/mixin.py +66 -10
  35. natural_pdf/extraction/result.py +1 -1
  36. natural_pdf/flows/flow.py +63 -4
  37. natural_pdf/flows/region.py +4 -4
  38. natural_pdf/ocr/engine.py +83 -2
  39. natural_pdf/ocr/engine_paddle.py +5 -5
  40. natural_pdf/ocr/ocr_factory.py +2 -1
  41. natural_pdf/ocr/ocr_manager.py +24 -13
  42. natural_pdf/ocr/ocr_options.py +3 -10
  43. natural_pdf/qa/document_qa.py +21 -8
  44. natural_pdf/qa/qa_result.py +3 -7
  45. natural_pdf/search/__init__.py +3 -2
  46. natural_pdf/search/lancedb_search_service.py +5 -6
  47. natural_pdf/search/numpy_search_service.py +5 -2
  48. natural_pdf/selectors/parser.py +51 -6
  49. natural_pdf/tables/__init__.py +2 -2
  50. natural_pdf/tables/result.py +7 -6
  51. natural_pdf/utils/bidi_mirror.py +2 -1
  52. natural_pdf/utils/reading_order.py +3 -2
  53. natural_pdf/utils/visualization.py +3 -3
  54. natural_pdf/widgets/viewer.py +0 -1
  55. {natural_pdf-0.1.33.dist-info → natural_pdf-0.1.34.dist-info}/METADATA +1 -1
  56. natural_pdf-0.1.34.dist-info/RECORD +121 -0
  57. optimization/memory_comparison.py +73 -58
  58. optimization/pdf_analyzer.py +141 -96
  59. optimization/performance_analysis.py +111 -110
  60. optimization/test_cleanup_methods.py +47 -36
  61. optimization/test_memory_fix.py +40 -39
  62. tools/bad_pdf_eval/__init__.py +0 -1
  63. tools/bad_pdf_eval/analyser.py +35 -18
  64. tools/bad_pdf_eval/collate_summaries.py +22 -18
  65. tools/bad_pdf_eval/compile_attempts_markdown.py +127 -0
  66. tools/bad_pdf_eval/eval_suite.py +21 -9
  67. tools/bad_pdf_eval/evaluate_quality.py +198 -0
  68. tools/bad_pdf_eval/export_enrichment_csv.py +12 -8
  69. tools/bad_pdf_eval/llm_enrich.py +71 -39
  70. tools/bad_pdf_eval/llm_enrich_with_retry.py +289 -0
  71. tools/bad_pdf_eval/reporter.py +1 -1
  72. tools/bad_pdf_eval/utils.py +7 -4
  73. natural_pdf-0.1.33.dist-info/RECORD +0 -118
  74. {natural_pdf-0.1.33.dist-info → natural_pdf-0.1.34.dist-info}/WHEEL +0 -0
  75. {natural_pdf-0.1.33.dist-info → natural_pdf-0.1.34.dist-info}/entry_points.txt +0 -0
  76. {natural_pdf-0.1.33.dist-info → natural_pdf-0.1.34.dist-info}/licenses/LICENSE +0 -0
  77. {natural_pdf-0.1.33.dist-info → natural_pdf-0.1.34.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  import logging
2
2
  import time
3
+ import threading # Add threading for locks
3
4
  from datetime import datetime
4
5
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
5
6
 
@@ -9,6 +10,7 @@ from PIL import Image
9
10
  # Use try-except for robustness if dependencies are missing
10
11
  _CLASSIFICATION_AVAILABLE = None
11
12
 
13
+
12
14
  def _check_classification_dependencies():
13
15
  """Lazy check for classification dependencies."""
14
16
  global _CLASSIFICATION_AVAILABLE
@@ -16,16 +18,20 @@ def _check_classification_dependencies():
16
18
  try:
17
19
  import torch
18
20
  import transformers
21
+
19
22
  _CLASSIFICATION_AVAILABLE = True
20
23
  except ImportError:
21
24
  _CLASSIFICATION_AVAILABLE = False
22
25
  return _CLASSIFICATION_AVAILABLE
23
26
 
27
+
24
28
  def _get_torch():
25
29
  """Lazy import for torch."""
26
30
  import torch
31
+
27
32
  return torch
28
33
 
34
+
29
35
  def _get_transformers_components():
30
36
  """Lazy import for transformers components."""
31
37
  from transformers import (
@@ -34,13 +40,15 @@ def _get_transformers_components():
34
40
  AutoTokenizer,
35
41
  pipeline,
36
42
  )
43
+
37
44
  return {
38
- 'AutoModelForSequenceClassification': AutoModelForSequenceClassification,
39
- 'AutoModelForZeroShotImageClassification': AutoModelForZeroShotImageClassification,
40
- 'AutoTokenizer': AutoTokenizer,
41
- 'pipeline': pipeline,
45
+ "AutoModelForSequenceClassification": AutoModelForSequenceClassification,
46
+ "AutoModelForZeroShotImageClassification": AutoModelForZeroShotImageClassification,
47
+ "AutoTokenizer": AutoTokenizer,
48
+ "pipeline": pipeline,
42
49
  }
43
50
 
51
+
44
52
  from tqdm.auto import tqdm
45
53
 
46
54
  # Import result classes
@@ -52,10 +60,12 @@ if TYPE_CHECKING:
52
60
 
53
61
  logger = logging.getLogger(__name__)
54
62
 
55
- # Global cache for models/pipelines
63
+ # Global cache for models/pipelines with thread safety
56
64
  _PIPELINE_CACHE: Dict[str, "Pipeline"] = {}
57
65
  _TOKENIZER_CACHE: Dict[str, Any] = {}
58
66
  _MODEL_CACHE: Dict[str, Any] = {}
67
+ _CACHE_LOCK = threading.RLock() # Reentrant lock for thread safety
68
+
59
69
 
60
70
  # Export the availability check function for external use
61
71
  def is_classification_available() -> bool:
@@ -107,34 +117,35 @@ class ClassificationManager:
107
117
  def _get_pipeline(self, model_id: str, using: str) -> "Pipeline":
108
118
  """Get or create a classification pipeline."""
109
119
  cache_key = f"{model_id}_{using}_{self.device}"
110
- if cache_key not in _PIPELINE_CACHE:
111
- logger.info(
112
- f"Loading {using} classification pipeline for model '{model_id}' on device '{self.device}'..."
113
- )
114
- start_time = time.time()
115
- try:
116
- # Lazy import transformers components
117
- transformers_components = _get_transformers_components()
118
- pipeline = transformers_components['pipeline']
119
-
120
- task = (
121
- "zero-shot-classification"
122
- if using == "text"
123
- else "zero-shot-image-classification"
124
- )
125
- _PIPELINE_CACHE[cache_key] = pipeline(task, model=model_id, device=self.device)
126
- end_time = time.time()
120
+ with _CACHE_LOCK:
121
+ if cache_key not in _PIPELINE_CACHE:
127
122
  logger.info(
128
- f"Pipeline for '{model_id}' loaded in {end_time - start_time:.2f} seconds."
129
- )
130
- except Exception as e:
131
- logger.error(
132
- f"Failed to load pipeline for model '{model_id}' (using: {using}): {e}",
133
- exc_info=True,
123
+ f"Loading {using} classification pipeline for model '{model_id}' on device '{self.device}'..."
134
124
  )
135
- raise ClassificationError(
136
- f"Failed to load pipeline for model '{model_id}'. Ensure the model ID is correct and supports the {task} task."
137
- ) from e
125
+ start_time = time.time()
126
+ try:
127
+ # Lazy import transformers components
128
+ transformers_components = _get_transformers_components()
129
+ pipeline = transformers_components["pipeline"]
130
+
131
+ task = (
132
+ "zero-shot-classification"
133
+ if using == "text"
134
+ else "zero-shot-image-classification"
135
+ )
136
+ _PIPELINE_CACHE[cache_key] = pipeline(task, model=model_id, device=self.device)
137
+ end_time = time.time()
138
+ logger.info(
139
+ f"Pipeline for '{model_id}' loaded in {end_time - start_time:.2f} seconds."
140
+ )
141
+ except Exception as e:
142
+ logger.error(
143
+ f"Failed to load pipeline for model '{model_id}' (using: {using}): {e}",
144
+ exc_info=True,
145
+ )
146
+ raise ClassificationError(
147
+ f"Failed to load pipeline for model '{model_id}'. Ensure the model ID is correct and supports the {task} task."
148
+ ) from e
138
149
  return _PIPELINE_CACHE[cache_key]
139
150
 
140
151
  def infer_using(self, model_id: str, using: Optional[str] = None) -> str:
@@ -452,66 +463,70 @@ class ClassificationManager:
452
463
  def cleanup_models(self, model_id: Optional[str] = None) -> int:
453
464
  """
454
465
  Cleanup classification models to free memory.
455
-
466
+
456
467
  Args:
457
468
  model_id: Specific model to cleanup, or None to cleanup all models
458
-
469
+
459
470
  Returns:
460
471
  Number of models cleaned up
461
472
  """
462
473
  global _PIPELINE_CACHE, _TOKENIZER_CACHE, _MODEL_CACHE
463
-
474
+
464
475
  cleaned_count = 0
465
-
476
+
466
477
  if model_id:
467
478
  # Cleanup specific model - search cache keys that contain the model_id
468
- keys_to_remove = [key for key in _PIPELINE_CACHE.keys() if model_id in key]
469
- for key in keys_to_remove:
470
- pipeline = _PIPELINE_CACHE.pop(key, None)
471
- if pipeline and hasattr(pipeline, 'model'):
472
- # Try to cleanup GPU memory if using torch
473
- try:
474
- torch = _get_torch()
475
- if hasattr(pipeline.model, 'to'):
476
- pipeline.model.to('cpu') # Move to CPU
477
- if torch.cuda.is_available():
478
- torch.cuda.empty_cache() # Clear GPU cache
479
- except Exception as e:
480
- logger.debug(f"GPU cleanup failed for model {model_id}: {e}")
481
-
482
- cleaned_count += 1
483
- logger.info(f"Cleaned up classification pipeline: {key}")
484
-
479
+ with _CACHE_LOCK:
480
+ keys_to_remove = [key for key in _PIPELINE_CACHE.keys() if model_id in key]
481
+ for key in keys_to_remove:
482
+ pipeline = _PIPELINE_CACHE.pop(key, None)
483
+ if pipeline and hasattr(pipeline, "model"):
484
+ # Try to cleanup GPU memory if using torch
485
+ try:
486
+ torch = _get_torch()
487
+ if hasattr(pipeline.model, "to"):
488
+ pipeline.model.to("cpu") # Move to CPU
489
+ if torch.cuda.is_available():
490
+ torch.cuda.empty_cache() # Clear GPU cache
491
+ except Exception as e:
492
+ logger.debug(f"GPU cleanup failed for model {model_id}: {e}")
493
+
494
+ cleaned_count += 1
495
+ logger.info(f"Cleaned up classification pipeline: {key}")
496
+
485
497
  # Also cleanup tokenizer and model caches for this model
486
- tokenizer_keys = [key for key in _TOKENIZER_CACHE.keys() if model_id in key]
487
- for key in tokenizer_keys:
488
- _TOKENIZER_CACHE.pop(key, None)
489
-
490
- model_keys = [key for key in _MODEL_CACHE.keys() if model_id in key]
491
- for key in model_keys:
492
- _MODEL_CACHE.pop(key, None)
493
-
498
+ with _CACHE_LOCK:
499
+ tokenizer_keys = [key for key in _TOKENIZER_CACHE.keys() if model_id in key]
500
+ for key in tokenizer_keys:
501
+ _TOKENIZER_CACHE.pop(key, None)
502
+
503
+ model_keys = [key for key in _MODEL_CACHE.keys() if model_id in key]
504
+ for key in model_keys:
505
+ _MODEL_CACHE.pop(key, None)
506
+
494
507
  else:
495
508
  # Cleanup all models
496
- for key, pipeline in list(_PIPELINE_CACHE.items()):
497
- if hasattr(pipeline, 'model'):
498
- try:
499
- torch = _get_torch()
500
- if hasattr(pipeline.model, 'to'):
501
- pipeline.model.to('cpu') # Move to CPU
502
- if torch.cuda.is_available():
503
- torch.cuda.empty_cache() # Clear GPU cache
504
- except Exception as e:
505
- logger.debug(f"GPU cleanup failed for pipeline {key}: {e}")
506
-
509
+ with _CACHE_LOCK:
510
+ for key, pipeline in list(_PIPELINE_CACHE.items()):
511
+ if hasattr(pipeline, "model"):
512
+ try:
513
+ torch = _get_torch()
514
+ if hasattr(pipeline.model, "to"):
515
+ pipeline.model.to("cpu") # Move to CPU
516
+ if torch.cuda.is_available():
517
+ torch.cuda.empty_cache() # Clear GPU cache
518
+ except Exception as e:
519
+ logger.debug(f"GPU cleanup failed for pipeline {key}: {e}")
520
+
507
521
  # Clear all caches
508
- pipeline_count = len(_PIPELINE_CACHE)
509
- _PIPELINE_CACHE.clear()
510
- _TOKENIZER_CACHE.clear()
511
- _MODEL_CACHE.clear()
512
-
522
+ with _CACHE_LOCK:
523
+ pipeline_count = len(_PIPELINE_CACHE)
524
+ _PIPELINE_CACHE.clear()
525
+ _TOKENIZER_CACHE.clear()
526
+ _MODEL_CACHE.clear()
527
+
513
528
  if pipeline_count > 0:
514
529
  logger.info(f"Cleaned up {pipeline_count} classification models")
515
530
  cleaned_count = pipeline_count
516
-
531
+
517
532
  return cleaned_count
@@ -1,8 +1,8 @@
1
1
  import logging
2
+ import warnings
2
3
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
3
4
 
4
5
  from PIL import Image
5
- import warnings
6
6
 
7
7
  from .results import ClassificationResult
8
8
 
@@ -16,9 +16,51 @@ logger = logging.getLogger(__name__)
16
16
 
17
17
 
18
18
  class ClassificationMixin:
19
- """
20
- Mixin class providing classification capabilities to Page and Region objects.
21
- Relies on a ClassificationManager being accessible, typically via the parent PDF.
19
+ """Mixin class providing classification capabilities to Page and Region objects.
20
+
21
+ This mixin adds AI-powered classification functionality to pages, regions, and
22
+ elements, enabling document categorization and content analysis using both
23
+ text-based and vision-based models. It integrates with the ClassificationManager
24
+ to provide a consistent interface across different model types.
25
+
26
+ The mixin supports both single-label and multi-label classification, confidence
27
+ thresholding, and various analysis storage strategies for complex workflows.
28
+ Results are stored in the host object's 'analyses' dictionary for retrieval
29
+ and further processing.
30
+
31
+ Classification modes:
32
+ - Text-based: Uses extracted text content for classification
33
+ - Vision-based: Uses rendered images for visual classification
34
+ - Automatic: Manager selects best mode based on content availability
35
+
36
+ Host class requirements:
37
+ - Must implement _get_classification_manager() -> ClassificationManager
38
+ - Must implement _get_classification_content() -> str | Image
39
+ - Must have 'analyses' attribute as Dict[str, Any]
40
+
41
+ Example:
42
+ ```python
43
+ pdf = npdf.PDF("document.pdf")
44
+ page = pdf.pages[0]
45
+
46
+ # Document type classification
47
+ page.classify(['invoice', 'contract', 'report'],
48
+ model='text', analysis_key='doc_type')
49
+
50
+ # Multi-label content analysis
51
+ region = page.find('text:contains("Summary")').below()
52
+ region.classify(['technical', 'financial', 'legal'],
53
+ multi_label=True, min_confidence=0.8)
54
+
55
+ # Access results
56
+ doc_type = page.analyses['doc_type']
57
+ content_labels = region.analyses['classification']
58
+ ```
59
+
60
+ Note:
61
+ Classification requires appropriate models to be available through the
62
+ ClassificationManager. Results include confidence scores and detailed
63
+ metadata for analysis workflows.
22
64
  """
23
65
 
24
66
  # --- Abstract methods/properties required by the host class --- #
@@ -86,7 +128,9 @@ class ClassificationMixin:
86
128
  # Try text first
87
129
  try:
88
130
  tentative_text = self._get_classification_content("text", **kwargs)
89
- if tentative_text and not (isinstance(tentative_text, str) and tentative_text.isspace()):
131
+ if tentative_text and not (
132
+ isinstance(tentative_text, str) and tentative_text.isspace()
133
+ ):
90
134
  engine = "text"
91
135
  content = tentative_text
92
136
  else:
@@ -1,9 +1,9 @@
1
1
  # natural_pdf/classification/results.py
2
2
  import logging
3
+ from collections.abc import Mapping
3
4
  from dataclasses import dataclass
4
5
  from datetime import datetime
5
6
  from typing import Any, Dict, List, Optional
6
- from collections.abc import Mapping
7
7
 
8
8
  logger = logging.getLogger(__name__)
9
9
 
natural_pdf/cli.py CHANGED
@@ -1,9 +1,11 @@
1
1
  import argparse
2
2
  import subprocess
3
3
  import sys
4
- from importlib.metadata import distribution, PackageNotFoundError, version as get_version
4
+ from importlib.metadata import PackageNotFoundError, distribution
5
+ from importlib.metadata import version as get_version
5
6
  from pathlib import Path
6
7
  from typing import Dict
8
+
7
9
  from packaging.requirements import Requirement
8
10
 
9
11
  # ---------------------------------------------------------------------------
@@ -71,7 +73,9 @@ def main():
71
73
  install_p = subparsers.add_parser(
72
74
  "install", help="Install optional dependency groups (e.g. paddle, surya)"
73
75
  )
74
- install_p.add_argument("extras", nargs="+", help="One or more extras to install (e.g. paddle surya)")
76
+ install_p.add_argument(
77
+ "extras", nargs="+", help="One or more extras to install (e.g. paddle surya)"
78
+ )
75
79
  install_p.set_defaults(func=cmd_install)
76
80
 
77
81
  # list subcommand -------------------------------------------------------
@@ -113,4 +117,4 @@ def cmd_list(args):
113
117
 
114
118
 
115
119
  if __name__ == "__main__":
116
- main()
120
+ main()
@@ -548,37 +548,31 @@ class PDFCollection(
548
548
  labels: List[str],
549
549
  using: Optional[str] = None, # Default handled by PDF.classify -> manager
550
550
  model: Optional[str] = None, # Optional model ID
551
- max_workers: Optional[int] = None,
552
551
  analysis_key: str = "classification", # Key for storing result in PDF.analyses
553
552
  **kwargs,
554
553
  ) -> "PDFCollection":
555
554
  """
556
- Classify each PDF document in the collection, potentially in parallel.
555
+ Classify each PDF document in the collection using batch processing.
557
556
 
558
- This method delegates classification to each PDF object's `classify` method.
559
- By default, uses the full extracted text of the PDF.
560
- If `using='vision'`, it classifies the first page's image, but ONLY if
561
- the PDF has a single page (raises ValueError otherwise).
557
+ This method gathers content from all PDFs and processes them in a single
558
+ batch to avoid multiprocessing resource accumulation that can occur with
559
+ sequential individual classifications.
562
560
 
563
561
  Args:
564
562
  labels: A list of string category names.
565
563
  using: Processing mode ('text', 'vision'). If None, manager infers (defaulting to text).
566
564
  model: Optional specific model identifier (e.g., HF ID). If None, manager uses default for 'using' mode.
567
- max_workers: Maximum number of threads to process PDFs concurrently.
568
- If None or 1, processing is sequential.
569
565
  analysis_key: Key under which to store the ClassificationResult in each PDF's `analyses` dict.
570
- **kwargs: Additional arguments passed down to `pdf.classify` (e.g., device,
571
- min_confidence, multi_label, text extraction options).
566
+ **kwargs: Additional arguments passed down to the ClassificationManager.
572
567
 
573
568
  Returns:
574
569
  Self for method chaining.
575
570
 
576
571
  Raises:
577
572
  ValueError: If labels list is empty, or if using='vision' on a multi-page PDF.
578
- ClassificationError: If classification fails for any PDF (will stop processing).
573
+ ClassificationError: If classification fails.
579
574
  ImportError: If classification dependencies are missing.
580
575
  """
581
- PDF = self._get_pdf_class()
582
576
  if not labels:
583
577
  raise ValueError("Labels list cannot be empty.")
584
578
 
@@ -588,102 +582,103 @@ class PDFCollection(
588
582
 
589
583
  mode_desc = f"using='{using}'" if using else f"model='{model}'" if model else "default text"
590
584
  logger.info(
591
- f"Starting classification for {len(self._pdfs)} PDFs in collection ({mode_desc})..."
585
+ f"Starting batch classification for {len(self._pdfs)} PDFs in collection ({mode_desc})..."
592
586
  )
593
587
 
594
- progress_bar = tqdm(
595
- total=len(self._pdfs), desc=f"Classifying PDFs ({mode_desc})", unit="pdf"
596
- )
597
-
598
- # Worker function
599
- def _process_pdf_classification(pdf: PDF):
600
- thread_id = threading.current_thread().name
601
- pdf_path = pdf.path
602
- logger.debug(f"[{thread_id}] Starting classification process for PDF: {pdf_path}")
603
- start_time = time.monotonic()
588
+ # Get classification manager from first PDF
589
+ try:
590
+ first_pdf = self._pdfs[0]
591
+ if not hasattr(first_pdf, 'get_manager'):
592
+ raise RuntimeError("PDFs do not support classification manager")
593
+ manager = first_pdf.get_manager('classification')
594
+ if not manager or not manager.is_available():
595
+ raise RuntimeError("ClassificationManager is not available")
596
+ except Exception as e:
597
+ from natural_pdf.classification.manager import ClassificationError
598
+ raise ClassificationError(f"Cannot access ClassificationManager: {e}") from e
599
+
600
+ # Determine processing mode early
601
+ inferred_using = manager.infer_using(model if model else manager.DEFAULT_TEXT_MODEL, using)
602
+
603
+ # Gather content from all PDFs
604
+ pdf_contents = []
605
+ valid_pdfs = []
606
+
607
+ logger.info(f"Gathering content from {len(self._pdfs)} PDFs for batch classification...")
608
+
609
+ for pdf in self._pdfs:
604
610
  try:
605
- # Call classify directly on the PDF object
606
- pdf.classify(
607
- labels=labels,
608
- using=using,
609
- model=model,
610
- analysis_key=analysis_key,
611
- **kwargs, # Pass other relevant args like min_confidence, multi_label
612
- )
613
- end_time = time.monotonic()
614
- logger.debug(
615
- f"[{thread_id}] Finished classification for PDF: {pdf_path} (Duration: {end_time - start_time:.2f}s)"
616
- )
617
- progress_bar.update(1) # Update progress bar upon success
618
- return pdf_path, None # Return path and no error
619
- except ValueError as ve:
620
- # Catch specific error for vision on multi-page PDF
621
- end_time = time.monotonic()
622
- logger.error(
623
- f"[{thread_id}] Skipped classification for {pdf_path} after {end_time - start_time:.2f}s: {ve}",
624
- exc_info=False,
625
- )
626
- progress_bar.update(1) # Still update progress bar
627
- return pdf_path, ve # Return the specific ValueError
611
+ # Get the content for classification - use the same logic as individual PDF classify
612
+ if inferred_using == "text":
613
+ # Extract text content from PDF
614
+ content = pdf.extract_text()
615
+ if not content or content.isspace():
616
+ logger.warning(f"Skipping PDF {pdf.path}: No text content found")
617
+ continue
618
+ elif inferred_using == "vision":
619
+ # For vision, we need single-page PDFs only
620
+ if len(pdf.pages) != 1:
621
+ logger.warning(f"Skipping PDF {pdf.path}: Vision classification requires single-page PDFs")
622
+ continue
623
+ # Get first page image
624
+ content = pdf.pages[0].to_image()
625
+ else:
626
+ raise ValueError(f"Unsupported using mode: {inferred_using}")
627
+
628
+ pdf_contents.append(content)
629
+ valid_pdfs.append(pdf)
630
+
628
631
  except Exception as e:
629
- end_time = time.monotonic()
630
- logger.error(
631
- f"[{thread_id}] Failed classification process for PDF {pdf_path} after {end_time - start_time:.2f}s: {e}",
632
- exc_info=True, # Log full traceback for unexpected errors
633
- )
634
- # Close progress bar immediately on critical error to avoid hanging
635
- if not progress_bar.disable:
636
- progress_bar.close()
637
- # Re-raise the exception to stop the entire collection processing
638
- raise ClassificationError(f"Classification failed for {pdf_path}: {e}") from e
632
+ logger.warning(f"Skipping PDF {pdf.path}: Error getting content - {e}")
633
+ continue
639
634
 
640
- # Use ThreadPoolExecutor for parallel processing if max_workers > 1
641
- processed_count = 0
642
- skipped_count = 0
635
+ if not pdf_contents:
636
+ logger.warning("No valid content could be gathered from PDFs for classification.")
637
+ return self
638
+
639
+ logger.info(f"Gathered content from {len(valid_pdfs)} PDFs. Running batch classification...")
640
+
641
+ # Run batch classification
643
642
  try:
644
- if max_workers is not None and max_workers > 1:
645
- logger.info(f"Classifying PDFs in parallel with {max_workers} workers.")
646
- futures = []
647
- with concurrent.futures.ThreadPoolExecutor(
648
- max_workers=max_workers, thread_name_prefix="ClassifyWorker"
649
- ) as executor:
650
- for pdf in self._pdfs:
651
- futures.append(executor.submit(_process_pdf_classification, pdf))
652
-
653
- # Wait for all futures to complete
654
- # Progress updated within worker
655
- for future in concurrent.futures.as_completed(futures):
656
- processed_count += 1
657
- pdf_path, error = (
658
- future.result()
659
- ) # Raise ClassificationError if worker failed critically
660
- if isinstance(error, ValueError):
661
- # Logged in worker, just count as skipped
662
- skipped_count += 1
663
-
664
- else: # Sequential processing
665
- logger.info("Classifying PDFs sequentially.")
666
- for pdf in self._pdfs:
667
- processed_count += 1
668
- pdf_path, error = _process_pdf_classification(
669
- pdf
670
- ) # Raise ClassificationError if worker failed critically
671
- if isinstance(error, ValueError):
672
- skipped_count += 1
673
-
674
- final_message = (
675
- f"Finished classification across the collection. Processed: {processed_count}"
643
+ batch_results = manager.classify_batch(
644
+ item_contents=pdf_contents,
645
+ labels=labels,
646
+ model_id=model,
647
+ using=inferred_using,
648
+ progress_bar=True, # Let the manager handle progress display
649
+ **kwargs,
650
+ )
651
+ except Exception as e:
652
+ logger.error(f"Batch classification failed: {e}")
653
+ from natural_pdf.classification.manager import ClassificationError
654
+ raise ClassificationError(f"Batch classification failed: {e}") from e
655
+
656
+ # Assign results back to PDFs
657
+ if len(batch_results) != len(valid_pdfs):
658
+ logger.error(
659
+ f"Batch classification result count ({len(batch_results)}) mismatch "
660
+ f"with PDFs processed ({len(valid_pdfs)}). Cannot assign results."
676
661
  )
677
- if skipped_count > 0:
678
- final_message += f", Skipped (e.g., vision on multi-page): {skipped_count}"
679
- logger.info(final_message + ".")
680
-
681
- finally:
682
- # Ensure progress bar is closed properly
683
- if not progress_bar.disable and progress_bar.n < progress_bar.total:
684
- progress_bar.n = progress_bar.total # Ensure it reaches 100%
685
- if not progress_bar.disable:
686
- progress_bar.close()
662
+ from natural_pdf.classification.manager import ClassificationError
663
+ raise ClassificationError("Batch result count mismatch with input PDFs")
664
+
665
+ logger.info(f"Assigning {len(batch_results)} results to PDFs under key '{analysis_key}'.")
666
+
667
+ processed_count = 0
668
+ for pdf, result_obj in zip(valid_pdfs, batch_results):
669
+ try:
670
+ if not hasattr(pdf, "analyses") or pdf.analyses is None:
671
+ pdf.analyses = {}
672
+ pdf.analyses[analysis_key] = result_obj
673
+ processed_count += 1
674
+ except Exception as e:
675
+ logger.warning(f"Failed to store classification result for {pdf.path}: {e}")
676
+
677
+ skipped_count = len(self._pdfs) - processed_count
678
+ final_message = f"Finished batch classification. Processed: {processed_count}"
679
+ if skipped_count > 0:
680
+ final_message += f", Skipped: {skipped_count}"
681
+ logger.info(final_message + ".")
687
682
 
688
683
  return self
689
684