natural-pdf 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. docs/categorizing-documents/index.md +168 -0
  2. docs/data-extraction/index.md +87 -0
  3. docs/element-selection/index.ipynb +218 -164
  4. docs/element-selection/index.md +20 -0
  5. docs/finetuning/index.md +176 -0
  6. docs/index.md +19 -0
  7. docs/ocr/index.md +63 -16
  8. docs/tutorials/01-loading-and-extraction.ipynb +411 -248
  9. docs/tutorials/02-finding-elements.ipynb +123 -46
  10. docs/tutorials/03-extracting-blocks.ipynb +24 -19
  11. docs/tutorials/04-table-extraction.ipynb +17 -12
  12. docs/tutorials/05-excluding-content.ipynb +37 -32
  13. docs/tutorials/06-document-qa.ipynb +36 -31
  14. docs/tutorials/07-layout-analysis.ipynb +45 -40
  15. docs/tutorials/07-working-with-regions.ipynb +61 -60
  16. docs/tutorials/08-spatial-navigation.ipynb +76 -71
  17. docs/tutorials/09-section-extraction.ipynb +160 -155
  18. docs/tutorials/10-form-field-extraction.ipynb +71 -66
  19. docs/tutorials/11-enhanced-table-processing.ipynb +11 -6
  20. docs/tutorials/12-ocr-integration.ipynb +3420 -312
  21. docs/tutorials/12-ocr-integration.md +68 -106
  22. docs/tutorials/13-semantic-search.ipynb +641 -251
  23. natural_pdf/__init__.py +3 -0
  24. natural_pdf/analyzers/layout/gemini.py +63 -47
  25. natural_pdf/classification/manager.py +343 -0
  26. natural_pdf/classification/mixin.py +149 -0
  27. natural_pdf/classification/results.py +62 -0
  28. natural_pdf/collections/mixins.py +63 -0
  29. natural_pdf/collections/pdf_collection.py +326 -17
  30. natural_pdf/core/element_manager.py +73 -4
  31. natural_pdf/core/page.py +255 -83
  32. natural_pdf/core/pdf.py +385 -367
  33. natural_pdf/elements/base.py +1 -3
  34. natural_pdf/elements/collections.py +279 -49
  35. natural_pdf/elements/region.py +106 -21
  36. natural_pdf/elements/text.py +5 -2
  37. natural_pdf/exporters/__init__.py +4 -0
  38. natural_pdf/exporters/base.py +61 -0
  39. natural_pdf/exporters/paddleocr.py +345 -0
  40. natural_pdf/extraction/manager.py +134 -0
  41. natural_pdf/extraction/mixin.py +246 -0
  42. natural_pdf/extraction/result.py +37 -0
  43. natural_pdf/ocr/__init__.py +16 -8
  44. natural_pdf/ocr/engine.py +46 -30
  45. natural_pdf/ocr/engine_easyocr.py +86 -42
  46. natural_pdf/ocr/engine_paddle.py +39 -28
  47. natural_pdf/ocr/engine_surya.py +32 -16
  48. natural_pdf/ocr/ocr_factory.py +34 -23
  49. natural_pdf/ocr/ocr_manager.py +98 -34
  50. natural_pdf/ocr/ocr_options.py +38 -10
  51. natural_pdf/ocr/utils.py +59 -33
  52. natural_pdf/qa/document_qa.py +0 -4
  53. natural_pdf/selectors/parser.py +363 -238
  54. natural_pdf/templates/finetune/fine_tune_paddleocr.md +420 -0
  55. natural_pdf/utils/debug.py +4 -2
  56. natural_pdf/utils/identifiers.py +9 -5
  57. natural_pdf/utils/locks.py +8 -0
  58. natural_pdf/utils/packaging.py +172 -105
  59. natural_pdf/utils/text_extraction.py +96 -65
  60. natural_pdf/utils/tqdm_utils.py +43 -0
  61. natural_pdf/utils/visualization.py +1 -1
  62. {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/METADATA +10 -3
  63. {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/RECORD +66 -51
  64. {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/WHEEL +1 -1
  65. {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/licenses/LICENSE +0 -0
  66. {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,149 @@
1
+ import logging
2
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ # Assuming PIL is installed as it's needed for vision
5
+ try:
6
+ from PIL import Image
7
+ except ImportError:
8
+ Image = None # type: ignore
9
+
10
+ # Import result classes
11
+ from .results import ClassificationResult # Assuming results.py is in the same dir
12
+
13
+ if TYPE_CHECKING:
14
+ # Avoid runtime import cycle
15
+ from natural_pdf.core.page import Page
16
+ from natural_pdf.elements.region import Region
17
+ from .manager import ClassificationManager
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
class ClassificationMixin:
    """
    Mixin class providing classification capabilities to Page and Region objects.

    Relies on a ClassificationManager being accessible, typically via the parent PDF.
    Host classes must implement ``_get_classification_manager`` and
    ``_get_classification_content``. Results are stored in a dict attribute
    named ``analyses``, which is created on demand if missing or None.
    """

    # --- Abstract methods/properties required by the host class --- #
    # These must be implemented by classes using this mixin (Page, Region)

    def _get_classification_manager(self) -> "ClassificationManager":
        """Return the ClassificationManager instance. Must be overridden by the host class."""
        raise NotImplementedError

    def _get_classification_content(self, model_type: str, **kwargs) -> Union[str, "Image.Image"]:
        """
        Return the content to classify. Must be overridden by the host class.

        Should return text (str) or an image (PIL.Image.Image) suitable for the
        given ``model_type`` ('text' or 'vision').
        """
        raise NotImplementedError

    # Host class needs 'analyses' attribute initialized as Dict[str, Any]
    # analyses: Dict[str, Any]

    # --- End Abstract --- #

    def classify(
        self,
        categories: List[str],
        model: Optional[str] = None,  # Default handled by manager
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        analysis_key: str = 'classification',
        multi_label: bool = False,
        **kwargs
    ) -> "ClassificationMixin":
        """
        Classify this item (Page or Region) using the configured manager.

        The result is stored in ``self.analyses[analysis_key]``, overwriting any
        previous result under that key (default key: 'classification').

        Args:
            categories: A list of string category names.
            model: Model identifier (e.g., 'text', 'vision', HF ID). When None, the
                manager's default text or vision model is chosen based on the mode.
            using: Optional processing mode ('text' or 'vision'). If None, inferred by manager.
            min_confidence: Minimum confidence threshold for results (0.0-1.0).
            analysis_key: Key under which to store the result in `self.analyses`.
            multi_label: Whether to allow multiple labels (passed to the manager).
            **kwargs: Passed both to ``_get_classification_content`` and to
                ``ClassificationManager.classify_item``.

        Returns:
            Self for method chaining.

        Raises:
            NotImplementedError: If the host class does not implement the
                required hooks. Any other exception is logged and swallowed,
                leaving ``self.analyses`` without a new entry.
        """
        # Ensure analyses dict exists
        if getattr(self, 'analyses', None) is None:
            logger.warning("'analyses' attribute not found or is None. Initializing as empty dict.")
            self.analyses = {}

        try:
            manager = self._get_classification_manager()

            # Infer the mode from the explicit model, or from the text default when
            # no model was given. An explicit `using` is handed to the manager too
            # (presumably it takes precedence — see ClassificationManager.infer_using).
            effective_model_id = model
            inferred_using = manager.infer_using(model if model else manager.DEFAULT_TEXT_MODEL, using)

            # If model was not provided, use the manager's default for the inferred mode
            if effective_model_id is None:
                effective_model_id = manager.DEFAULT_TEXT_MODEL if inferred_using == 'text' else manager.DEFAULT_VISION_MODEL
                logger.debug(f"No model provided, using default for mode '{inferred_using}': '{effective_model_id}'")

            # Content type depends on the *final* resolved mode
            content = self._get_classification_content(model_type=inferred_using, **kwargs)

            result_obj: ClassificationResult = manager.classify_item(
                item_content=content,
                categories=categories,
                model_id=effective_model_id,
                using=inferred_using,
                min_confidence=min_confidence,
                multi_label=multi_label,
                **kwargs
            )

            self.analyses[analysis_key] = result_obj
            logger.debug(f"Stored classification result under key '{analysis_key}': {result_obj}")

        except NotImplementedError as nie:
            logger.error(f"Classification cannot proceed: {nie}")
            raise
        except Exception as e:
            # Deliberately best-effort: log and return self so batch callers keep going.
            logger.error(f"Classification failed: {e}", exc_info=True)

        return self

    @property
    def classification_results(self) -> Optional[ClassificationResult]:
        """
        Return the ClassificationResult stored under the default ('classification') key.

        Returns None if nothing is stored there or if the stored object is not a
        ClassificationResult. The type check keeps ``category`` and
        ``category_confidence`` from failing on foreign objects.
        """
        return self.get_classification_result('classification')

    @property
    def category(self) -> Optional[str]:
        """Return the top category label from the default ('classification') key, or None."""
        result_obj = self.classification_results
        return result_obj.top_category if result_obj else None

    @property
    def category_confidence(self) -> Optional[float]:
        """Return the top category confidence from the default ('classification') key, or None."""
        result_obj = self.classification_results
        return result_obj.top_confidence if result_obj else None

    def get_classification_result(self, analysis_key: str = 'classification') -> Optional[ClassificationResult]:
        """
        Return the ClassificationResult stored under ``analysis_key``.

        Returns None when ``analyses`` is missing or None, when the key is absent,
        or when the stored object is not a ClassificationResult (a warning is logged).
        """
        analyses = getattr(self, 'analyses', None)
        if analyses is None:
            return None
        result = analyses.get(analysis_key)
        if result is not None and not isinstance(result, ClassificationResult):
            logger.warning(f"Item found under key '{analysis_key}' is not a ClassificationResult (type: {type(result)}). Returning None.")
            return None
        return result
@@ -0,0 +1,62 @@
1
+ # natural_pdf/classification/results.py
2
+ from typing import List, Optional, Dict, Any
3
+ from datetime import datetime
4
+ import logging
5
+
6
logger = logging.getLogger(__name__)


class CategoryScore:
    """
    The confidence score assigned to a single category label.

    Invalid input is tolerated rather than rejected:
    - A bad label is logged and kept (converted with ``str``).
    - A confidence outside [0, 1], or one that cannot be converted to a float,
      is logged and coerced into that range.
    - Non-numeric values (e.g. None) and NaN become 0.0.
    """

    label: str
    confidence: float  # Score between 0.0 and 1.0

    def __init__(self, label: str, confidence: float):
        import math  # local import: only needed for the NaN check below

        if not isinstance(label, str) or not label:
            # Deliberately lenient: keep the (stringified) label instead of raising.
            logger.warning(f"Initializing CategoryScore with invalid label: {label}")

        # float() accepts Python numbers, numpy/torch scalars and numeric strings.
        # Anything else (e.g. None) used to crash here; treat it as NaN and coerce below.
        try:
            value = float(confidence)
        except (TypeError, ValueError):
            value = math.nan

        # Strings are converted but still flagged. NaN fails the range check.
        if isinstance(confidence, (str, bytes)) or not (0.0 <= value <= 1.0):
            logger.warning(
                f"Initializing CategoryScore with invalid confidence: {confidence} for label '{label}'. Clamping to [0, 1]."
            )
            # Plain max/min would map NaN to 1.0; an unusable score should be 0.0.
            value = 0.0 if math.isnan(value) else max(0.0, min(1.0, value))

        self.label = str(label)
        self.confidence = value

    def __repr__(self):
        return f"<CategoryScore label='{self.label}' confidence={self.confidence:.3f}>"
29
+
30
class ClassificationResult:
    """
    Structured outcome of a single classification call.

    Attributes:
        model_id: Identifier of the model that produced the scores (stored as str).
        using: Processing mode the scores came from, e.g. 'text' or 'vision' (stored as str).
        timestamp: Creation time as supplied by the caller.
        parameters: Parameters recorded with the result (e.g. categories,
            min_confidence); an empty dict when None is passed.
        scores: CategoryScore objects ordered from highest to lowest confidence.
    """

    model_id: str
    using: str
    timestamp: datetime
    parameters: Dict[str, Any]
    scores: List[CategoryScore]

    def __init__(self, model_id: str, using: str, timestamp: datetime, parameters: Dict[str, Any], scores: List[CategoryScore]):
        scores_are_valid = isinstance(scores, list) and all(
            isinstance(entry, CategoryScore) for entry in scores
        )
        if not scores_are_valid:
            raise TypeError("Scores must be a list of CategoryScore objects.")

        self.model_id = str(model_id)
        self.using = str(using)
        self.timestamp = timestamp
        self.parameters = {} if parameters is None else parameters

        # Copy before sorting so the caller's list is left untouched; the sort is stable.
        ranked = list(scores)
        ranked.sort(key=lambda entry: entry.confidence, reverse=True)
        self.scores = ranked

    @property
    def top_category(self) -> Optional[str]:
        """Label of the highest-confidence category, or None when there are no scores."""
        if not self.scores:
            return None
        return self.scores[0].label

    @property
    def top_confidence(self) -> Optional[float]:
        """Confidence of the highest-scoring category, or None when there are no scores."""
        if not self.scores:
            return None
        return self.scores[0].confidence

    def __repr__(self):
        summary = ""
        if self.scores:
            summary = f" top='{self.top_category}' ({self.top_confidence:.2f})"
        return (
            f"<ClassificationResult model='{self.model_id}' using='{self.using}' "
            f"scores={len(self.scores)}{summary}>"
        )
@@ -0,0 +1,63 @@
1
+ import logging
2
+ from typing import Callable, Iterable, Any, TypeVar
3
+ from tqdm.auto import tqdm
4
+
5
logger = logging.getLogger(__name__)

T = TypeVar("T")  # Generic type for items in the collection


class ApplyMixin:
    """
    Mixin class providing an `.apply()` method for collections.

    Assumes the inheriting class implements `__iter__` (and, if a progress bar
    is wanted, `__len__`) appropriately for the items processed by `apply`.
    """

    def _get_items_for_apply(self) -> Iterable[Any]:
        """
        Return the iterable of items to apply the function to.

        Defaults to iterating over `self`. Subclasses should override this
        if the default iteration is not suitable for the apply operation.
        """
        return iter(self)

    def apply(self, func: Callable[..., Any], *args, **kwargs) -> None:
        """
        Apply a function to each item in the collection, for its side effects.

        Errors raised by ``func`` are logged and processing continues with the
        next item.

        Args:
            func: The function to apply. Each item is passed as its first
                argument. Any callable works, including ``functools.partial``
                objects and callable instances without a ``__name__``.
            *args: Additional positional arguments to pass to func.
            **kwargs: Additional keyword arguments to pass to func. The special
                keyword 'show_progress' (bool, default False) is consumed here
                and displays a progress bar when the collection length is known.

        Returns:
            None.
        """
        show_progress = kwargs.pop('show_progress', False)
        # partial objects / callable instances have no __name__; the original code
        # crashed with AttributeError inside its own error handler in that case.
        func_name = getattr(func, "__name__", repr(func))
        items_iterable = self._get_items_for_apply()

        # The length is only needed for the progress bar, so don't query it
        # (or warn about it) unless a bar was requested.
        if show_progress:
            total_items = 0
            try:
                total_items = len(self)
            except TypeError:  # __len__ not defined on self
                logger.warning("Could not determine collection length for progress bar.")

            if total_items > 0:
                unit_name = self.__class__.__name__.lower()
                items_iterable = tqdm(
                    items_iterable, total=total_items, desc=f"Applying {func_name}", unit=unit_name
                )
            else:
                logger.info(f"Applying {func_name} (progress bar disabled for zero/unknown length).")

        for item in items_iterable:
            try:
                func(item, *args, **kwargs)
            except Exception as e:
                # Log and continue for batch operations
                logger.error(f"Error applying {func_name} to {item}: {e}", exc_info=True)
@@ -4,12 +4,24 @@ import logging
4
4
  import os
5
5
  import re # Added for safe path generation
6
6
  from pathlib import Path
7
- from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Type, Union
7
+ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Type, Union, Callable
8
+ import concurrent.futures # Import concurrent.futures
9
+ import time # Import time for logging timestamps
10
+ import threading # Import threading for logging thread information
8
11
 
9
12
  from PIL import Image
10
13
  from tqdm import tqdm
14
+ from tqdm.auto import tqdm as auto_tqdm
15
+ from tqdm.notebook import tqdm as notebook_tqdm
16
+
17
+ from natural_pdf.utils.tqdm_utils import get_tqdm
18
+
19
+ # Get the appropriate tqdm class once
20
+ tqdm = get_tqdm()
11
21
 
12
22
  # Set up logger early
23
+ # Configure logging to include thread information
24
+ # logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(threadName)s - %(name)s - %(levelname)s - %(message)s')
13
25
  logger = logging.getLogger(__name__)
14
26
 
15
27
  from natural_pdf.core.pdf import PDF
@@ -36,9 +48,11 @@ except ImportError as e:
36
48
  SearchServiceProtocol, SearchOptions, Indexable = object, object, object
37
49
 
38
50
  from natural_pdf.search.searchable_mixin import SearchableMixin # Import the new mixin
51
+ # Import the ApplyMixin
52
+ from natural_pdf.collections.mixins import ApplyMixin
39
53
 
40
54
 
41
- class PDFCollection(SearchableMixin): # Inherit from the mixin
55
+ class PDFCollection(SearchableMixin, ApplyMixin): # Inherit from ApplyMixin
42
56
  def __init__(
43
57
  self,
44
58
  source: Union[str, Iterable[Union[str, "PDF"]]],
@@ -237,30 +251,214 @@ class PDFCollection(SearchableMixin): # Inherit from the mixin
237
251
 
238
252
  def __repr__(self) -> str:
239
253
  # Removed search status
240
- return f"<PDFCollection(count={len(self)})>"
254
+ return f"<PDFCollection(count={len(self._pdfs)})>"
241
255
 
242
256
  @property
243
257
  def pdfs(self) -> List["PDF"]:
244
258
  """Returns the list of PDF objects held by the collection."""
245
259
  return self._pdfs
246
260
 
247
- def apply_ocr(self, *args, **kwargs):
248
- PDF = self._get_pdf_class()
249
- # Delegate to individual PDF objects
250
- logger.info("Applying OCR to relevant PDFs in collection...")
251
- results = []
261
+ def find_all(
262
+ self,
263
+ selector: str,
264
+ apply_exclusions: bool = True, # Added explicit parameter
265
+ regex: bool = False, # Added explicit parameter
266
+ case: bool = True, # Added explicit parameter
267
+ **kwargs
268
+ ) -> "ElementCollection":
269
+ """
270
+ Find all elements matching the selector across all PDFs in the collection.
271
+
272
+ This creates an ElementCollection that can span multiple PDFs. Note that
273
+ some ElementCollection methods have limitations when spanning PDFs.
274
+
275
+ Args:
276
+ selector: CSS-like selector string to query elements
277
+ apply_exclusions: Whether to exclude elements in exclusion regions (default: True)
278
+ regex: Whether to use regex for text search in :contains (default: False)
279
+ case: Whether to do case-sensitive text search (default: True)
280
+ **kwargs: Additional keyword arguments passed to the find_all method of each PDF
281
+
282
+ Returns:
283
+ ElementCollection containing all matching elements across all PDFs
284
+ """
285
+ from natural_pdf.elements.collections import ElementCollection
286
+
287
+ # Collect elements from all PDFs
288
+ all_elements = []
252
289
  for pdf in self._pdfs:
253
- # We need to figure out which pages belong to which PDF if batching here
254
- # For now, simpler to call on each PDF
255
290
  try:
256
- # Assume apply_ocr exists on PDF and accepts similar args
257
- pdf.apply_ocr(*args, **kwargs)
291
+ # Explicitly pass the relevant arguments down
292
+ elements = pdf.find_all(
293
+ selector,
294
+ apply_exclusions=apply_exclusions,
295
+ regex=regex,
296
+ case=case,
297
+ **kwargs
298
+ )
299
+ all_elements.extend(elements.elements)
258
300
  except Exception as e:
259
- logger.error(f"Failed applying OCR to {pdf.path}: {e}", exc_info=True)
301
+ logger.error(f"Error finding elements in {pdf.path}: {e}", exc_info=True)
302
+
303
+ return ElementCollection(all_elements)
304
+
305
+ def apply_ocr(
306
+ self,
307
+ engine: Optional[str] = None,
308
+ languages: Optional[List[str]] = None,
309
+ min_confidence: Optional[float] = None,
310
+ device: Optional[str] = None,
311
+ resolution: Optional[int] = None,
312
+ apply_exclusions: bool = True,
313
+ detect_only: bool = False,
314
+ replace: bool = True,
315
+ options: Optional[Any] = None,
316
+ pages: Optional[Union[slice, List[int]]] = None,
317
+ max_workers: Optional[int] = None,
318
+ ) -> "PDFCollection":
319
+ """
320
+ Apply OCR to all PDFs in the collection, potentially in parallel.
321
+
322
+ Args:
323
+ engine: OCR engine to use (e.g., 'easyocr', 'paddleocr', 'surya')
324
+ languages: List of language codes for OCR
325
+ min_confidence: Minimum confidence threshold for text detection
326
+ device: Device to use for OCR (e.g., 'cpu', 'cuda')
327
+ resolution: DPI resolution for page rendering
328
+ apply_exclusions: Whether to apply exclusion regions
329
+ detect_only: If True, only detect text regions without extracting text
330
+ replace: If True, replace existing OCR elements
331
+ options: Engine-specific options
332
+ pages: Specific pages to process (None for all pages)
333
+ max_workers: Maximum number of threads to process PDFs concurrently.
334
+ If None or 1, processing is sequential. (default: None)
335
+
336
+ Returns:
337
+ Self for method chaining
338
+ """
339
+ PDF = self._get_pdf_class()
340
+ logger.info(f"Applying OCR to {len(self._pdfs)} PDFs in collection (max_workers={max_workers})...")
341
+
342
+ # Worker function takes PDF object again
343
+ def _process_pdf(pdf: PDF):
344
+ """Helper function to apply OCR to a single PDF, handling errors."""
345
+ thread_id = threading.current_thread().name # Get thread name for logging
346
+ pdf_path = pdf.path # Get path for logging
347
+ logger.debug(f"[{thread_id}] Starting OCR process for: {pdf_path}")
348
+ start_time = time.monotonic()
349
+ try:
350
+ pdf.apply_ocr( # Call apply_ocr on the original PDF object
351
+ pages=pages,
352
+ engine=engine,
353
+ languages=languages,
354
+ min_confidence=min_confidence,
355
+ device=device,
356
+ resolution=resolution,
357
+ apply_exclusions=apply_exclusions,
358
+ detect_only=detect_only,
359
+ replace=replace,
360
+ options=options,
361
+ # Note: We might want a max_workers here too for page rendering?
362
+ # For now, PDF.apply_ocr doesn't have it.
363
+ )
364
+ end_time = time.monotonic()
365
+ logger.debug(f"[{thread_id}] Finished OCR process for: {pdf_path} (Duration: {end_time - start_time:.2f}s)")
366
+ return pdf_path, None
367
+ except Exception as e:
368
+ end_time = time.monotonic()
369
+ logger.error(f"[{thread_id}] Failed OCR process for {pdf_path} after {end_time - start_time:.2f}s: {e}", exc_info=False)
370
+ return pdf_path, e # Return path and error
371
+
372
+ # Use ThreadPoolExecutor for parallel processing if max_workers > 1
373
+ if max_workers is not None and max_workers > 1:
374
+ futures = []
375
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="OCRWorker") as executor:
376
+ for pdf in self._pdfs:
377
+ # Submit the PDF object to the worker function
378
+ futures.append(executor.submit(_process_pdf, pdf))
379
+
380
+ # Use the selected tqdm class with as_completed for progress tracking
381
+ progress_bar = tqdm(
382
+ concurrent.futures.as_completed(futures),
383
+ total=len(self._pdfs),
384
+ desc="Applying OCR (Parallel)",
385
+ unit="pdf"
386
+ )
387
+
388
+ for future in progress_bar:
389
+ pdf_path, error = future.result() # Get result (or exception)
390
+ if error:
391
+ progress_bar.set_postfix_str(f"Error: {pdf_path}", refresh=True)
392
+ # Progress is updated automatically by tqdm
393
+
394
+ else: # Sequential processing (max_workers is None or 1)
395
+ logger.info("Applying OCR sequentially...")
396
+ # Use the selected tqdm class for sequential too for consistency
397
+ # Iterate over PDF objects directly for sequential
398
+ for pdf in tqdm(self._pdfs, desc="Applying OCR (Sequential)", unit="pdf"):
399
+ _process_pdf(pdf) # Call helper directly with PDF object
400
+
401
+ logger.info("Finished applying OCR across the collection.")
260
402
  return self
261
403
 
262
- # --- Advanced Method Placeholders ---
263
- # Placeholder for categorize removed as find_relevant is now implemented
404
+ def correct_ocr(
405
+ self,
406
+ correction_callback: Callable[[Any], Optional[str]],
407
+ max_workers: Optional[int] = None,
408
+ progress_callback: Optional[Callable[[], None]] = None,
409
+ ) -> "PDFCollection":
410
+ """
411
+ Apply OCR correction to all relevant elements across all pages and PDFs
412
+ in the collection using a single progress bar.
413
+
414
+ Args:
415
+ correction_callback: Function to apply to each OCR element.
416
+ It receives the element and should return
417
+ the corrected text (str) or None.
418
+ max_workers: Max threads to use for parallel execution within each page.
419
+ progress_callback: Optional callback function to call after processing each element.
420
+
421
+ Returns:
422
+ Self for method chaining.
423
+ """
424
+ PDF = self._get_pdf_class() # Ensure PDF class is available
425
+ if not callable(correction_callback):
426
+ raise TypeError("`correction_callback` must be a callable function.")
427
+
428
+ logger.info(f"Gathering OCR elements from {len(self._pdfs)} PDFs for correction...")
429
+
430
+ # 1. Gather all target elements using the collection's find_all
431
+ # Crucially, set apply_exclusions=False to include elements in headers/footers etc.
432
+ all_ocr_elements = self.find_all("text[source=ocr]", apply_exclusions=False).elements
433
+
434
+ if not all_ocr_elements:
435
+ logger.info("No OCR elements found in the collection to correct.")
436
+ return self
437
+
438
+ total_elements = len(all_ocr_elements)
439
+ logger.info(f"Found {total_elements} OCR elements across the collection. Starting correction process...")
440
+
441
+ # 2. Initialize the progress bar
442
+ progress_bar = tqdm(total=total_elements, desc="Correcting OCR Elements", unit="element")
443
+
444
+ # 3. Iterate through PDFs and delegate to PDF.correct_ocr
445
+ # PDF.correct_ocr handles page iteration and passing the progress callback down.
446
+ for pdf in self._pdfs:
447
+ if not pdf.pages:
448
+ continue
449
+ try:
450
+ pdf.correct_ocr(
451
+ correction_callback=correction_callback,
452
+ max_workers=max_workers,
453
+ progress_callback=progress_bar.update # Pass the bar's update method
454
+ )
455
+ except Exception as e:
456
+ logger.error(f"Error occurred during correction process for PDF {pdf.path}: {e}", exc_info=True)
457
+ # Decide if we should stop or continue? For now, continue.
458
+
459
+ progress_bar.close()
460
+
461
+ return self
264
462
 
265
463
  def categorize(self, categories: List[str], **kwargs):
266
464
  """Categorizes PDFs in the collection based on content or features."""
@@ -279,14 +477,17 @@ class PDFCollection(SearchableMixin): # Inherit from the mixin
279
477
  """
280
478
  try:
281
479
  from natural_pdf.utils.packaging import create_correction_task_package
480
+
282
481
  # Pass the collection itself (self) as the source
283
482
  create_correction_task_package(source=self, output_zip_path=output_zip_path, **kwargs)
284
483
  except ImportError:
285
- logger.error("Failed to import 'create_correction_task_package'. Packaging utility might be missing.")
484
+ logger.error(
485
+ "Failed to import 'create_correction_task_package'. Packaging utility might be missing."
486
+ )
286
487
  # Or raise
287
488
  except Exception as e:
288
489
  logger.error(f"Failed to export correction task for collection: {e}", exc_info=True)
289
- raise # Re-raise the exception from the utility function
490
+ raise # Re-raise the exception from the utility function
290
491
 
291
492
  # --- Mixin Required Implementation ---
292
493
  def get_indexable_items(self) -> Iterable[Indexable]:
@@ -306,3 +507,111 @@ class PDFCollection(SearchableMixin): # Inherit from the mixin
306
507
  # logger.debug(f"Skipping empty page {page.page_number} from PDF '{pdf.path}'.")
307
508
  # continue
308
509
  yield page
510
+
511
+ # --- Classification Method --- #
512
+ def classify_all(
513
+ self,
514
+ categories: List[str],
515
+ model: str = "text",
516
+ max_workers: Optional[int] = None,
517
+ **kwargs,
518
+ ) -> "PDFCollection":
519
+ """
520
+ Classify all pages across all PDFs in the collection, potentially in parallel.
521
+
522
+ This method uses the unified `classify_all` approach, delegating page
523
+ classification to each PDF's `classify_pages` method.
524
+ It displays a progress bar tracking individual pages.
525
+
526
+ Args:
527
+ categories: A list of string category names.
528
+ model: Model identifier ('text', 'vision', or specific HF ID).
529
+ max_workers: Maximum number of threads to process PDFs concurrently.
530
+ If None or 1, processing is sequential.
531
+ **kwargs: Additional arguments passed down to `pdf.classify_pages` and
532
+ subsequently to `page.classify` (e.g., device,
533
+ confidence_threshold, resolution).
534
+
535
+ Returns:
536
+ Self for method chaining.
537
+
538
+ Raises:
539
+ ValueError: If categories list is empty.
540
+ ClassificationError: If classification fails for any page (will stop processing).
541
+ ImportError: If classification dependencies are missing.
542
+ """
543
+ PDF = self._get_pdf_class()
544
+ if not categories:
545
+ raise ValueError("Categories list cannot be empty.")
546
+
547
+ logger.info(f"Starting classification for {len(self._pdfs)} PDFs in collection (model: '{model}')...")
548
+
549
+ # Calculate total pages for the progress bar
550
+ total_pages = sum(len(pdf.pages) for pdf in self._pdfs if pdf.pages)
551
+ if total_pages == 0:
552
+ logger.warning("No pages found in the PDF collection to classify.")
553
+ return self
554
+
555
+ progress_bar = tqdm(
556
+ total=total_pages,
557
+ desc=f"Classifying Pages (model: {model})",
558
+ unit="page"
559
+ )
560
+
561
+ # Worker function
562
+ def _process_pdf_classification(pdf: PDF):
563
+ thread_id = threading.current_thread().name
564
+ pdf_path = pdf.path
565
+ logger.debug(f"[{thread_id}] Starting classification process for: {pdf_path}")
566
+ start_time = time.monotonic()
567
+ try:
568
+ # Call classify_pages on the PDF, passing the progress callback
569
+ pdf.classify_pages(
570
+ categories=categories,
571
+ model=model,
572
+ progress_callback=progress_bar.update,
573
+ **kwargs
574
+ )
575
+ end_time = time.monotonic()
576
+ logger.debug(f"[{thread_id}] Finished classification for: {pdf_path} (Duration: {end_time - start_time:.2f}s)")
577
+ return pdf_path, None # Return path and no error
578
+ except Exception as e:
579
+ end_time = time.monotonic()
580
+ # Error is logged within classify_pages, but log summary here
581
+ logger.error(f"[{thread_id}] Failed classification process for {pdf_path} after {end_time - start_time:.2f}s: {e}", exc_info=False)
582
+ # Close progress bar immediately on error to avoid hanging
583
+ progress_bar.close()
584
+ # Re-raise the exception to stop the entire collection processing
585
+ raise
586
+
587
+ # Use ThreadPoolExecutor for parallel processing if max_workers > 1
588
+ try:
589
+ if max_workers is not None and max_workers > 1:
590
+ logger.info(f"Classifying PDFs in parallel with {max_workers} workers.")
591
+ futures = []
592
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="ClassifyWorker") as executor:
593
+ for pdf in self._pdfs:
594
+ futures.append(executor.submit(_process_pdf_classification, pdf))
595
+
596
+ # Wait for all futures to complete (progress updated by callback)
597
+ # Exceptions are raised by future.result() if worker failed
598
+ for future in concurrent.futures.as_completed(futures):
599
+ future.result() # Raise exception if worker failed
600
+
601
+ else: # Sequential processing
602
+ logger.info("Classifying PDFs sequentially.")
603
+ for pdf in self._pdfs:
604
+ _process_pdf_classification(pdf)
605
+
606
+ logger.info("Finished classification across the collection.")
607
+
608
+ finally:
609
+ # Ensure progress bar is closed even if errors occurred elsewhere
610
+ if not progress_bar.disable and progress_bar.n < progress_bar.total:
611
+ progress_bar.close()
612
+ elif progress_bar.disable is False:
613
+ progress_bar.close()
614
+
615
+ return self
616
+
617
+ # --- End Classification Method --- #