natural-pdf 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/categorizing-documents/index.md +168 -0
- docs/data-extraction/index.md +87 -0
- docs/element-selection/index.ipynb +218 -164
- docs/element-selection/index.md +20 -0
- docs/finetuning/index.md +176 -0
- docs/index.md +19 -0
- docs/ocr/index.md +63 -16
- docs/tutorials/01-loading-and-extraction.ipynb +411 -248
- docs/tutorials/02-finding-elements.ipynb +123 -46
- docs/tutorials/03-extracting-blocks.ipynb +24 -19
- docs/tutorials/04-table-extraction.ipynb +17 -12
- docs/tutorials/05-excluding-content.ipynb +37 -32
- docs/tutorials/06-document-qa.ipynb +36 -31
- docs/tutorials/07-layout-analysis.ipynb +45 -40
- docs/tutorials/07-working-with-regions.ipynb +61 -60
- docs/tutorials/08-spatial-navigation.ipynb +76 -71
- docs/tutorials/09-section-extraction.ipynb +160 -155
- docs/tutorials/10-form-field-extraction.ipynb +71 -66
- docs/tutorials/11-enhanced-table-processing.ipynb +11 -6
- docs/tutorials/12-ocr-integration.ipynb +3420 -312
- docs/tutorials/12-ocr-integration.md +68 -106
- docs/tutorials/13-semantic-search.ipynb +641 -251
- natural_pdf/__init__.py +3 -0
- natural_pdf/analyzers/layout/gemini.py +63 -47
- natural_pdf/classification/manager.py +343 -0
- natural_pdf/classification/mixin.py +149 -0
- natural_pdf/classification/results.py +62 -0
- natural_pdf/collections/mixins.py +63 -0
- natural_pdf/collections/pdf_collection.py +326 -17
- natural_pdf/core/element_manager.py +73 -4
- natural_pdf/core/page.py +255 -83
- natural_pdf/core/pdf.py +385 -367
- natural_pdf/elements/base.py +1 -3
- natural_pdf/elements/collections.py +279 -49
- natural_pdf/elements/region.py +106 -21
- natural_pdf/elements/text.py +5 -2
- natural_pdf/exporters/__init__.py +4 -0
- natural_pdf/exporters/base.py +61 -0
- natural_pdf/exporters/paddleocr.py +345 -0
- natural_pdf/extraction/manager.py +134 -0
- natural_pdf/extraction/mixin.py +246 -0
- natural_pdf/extraction/result.py +37 -0
- natural_pdf/ocr/__init__.py +16 -8
- natural_pdf/ocr/engine.py +46 -30
- natural_pdf/ocr/engine_easyocr.py +86 -42
- natural_pdf/ocr/engine_paddle.py +39 -28
- natural_pdf/ocr/engine_surya.py +32 -16
- natural_pdf/ocr/ocr_factory.py +34 -23
- natural_pdf/ocr/ocr_manager.py +98 -34
- natural_pdf/ocr/ocr_options.py +38 -10
- natural_pdf/ocr/utils.py +59 -33
- natural_pdf/qa/document_qa.py +0 -4
- natural_pdf/selectors/parser.py +363 -238
- natural_pdf/templates/finetune/fine_tune_paddleocr.md +420 -0
- natural_pdf/utils/debug.py +4 -2
- natural_pdf/utils/identifiers.py +9 -5
- natural_pdf/utils/locks.py +8 -0
- natural_pdf/utils/packaging.py +172 -105
- natural_pdf/utils/text_extraction.py +96 -65
- natural_pdf/utils/tqdm_utils.py +43 -0
- natural_pdf/utils/visualization.py +1 -1
- {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/METADATA +10 -3
- {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/RECORD +66 -51
- {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/WHEEL +1 -1
- {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,149 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
3
|
+
|
4
|
+
# Assuming PIL is installed as it's needed for vision
|
5
|
+
try:
|
6
|
+
from PIL import Image
|
7
|
+
except ImportError:
|
8
|
+
Image = None # type: ignore
|
9
|
+
|
10
|
+
# Import result classes
|
11
|
+
from .results import ClassificationResult # Assuming results.py is in the same dir
|
12
|
+
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
# Avoid runtime import cycle
|
15
|
+
from natural_pdf.core.page import Page
|
16
|
+
from natural_pdf.elements.region import Region
|
17
|
+
from .manager import ClassificationManager
|
18
|
+
|
19
|
+
logger = logging.getLogger(__name__)
|
20
|
+
|
21
|
+
class ClassificationMixin:
    """
    Mixin class providing classification capabilities to Page and Region objects.

    Relies on a ClassificationManager being accessible, typically via the parent PDF.

    Host classes must implement `_get_classification_manager()` and
    `_get_classification_content()`. Results are stored in the host's
    `analyses` dict, which `classify()` creates if it is missing.
    """

    # --- Abstract methods/properties required by the host class --- #
    # These must be implemented by classes using this mixin (Page, Region)

    def _get_classification_manager(self) -> "ClassificationManager":
        """Should return the ClassificationManager instance."""
        raise NotImplementedError

    def _get_classification_content(self, model_type: str, **kwargs) -> Union[str, "Image.Image"]:
        """Should return the text content (str) or image (PIL.Image.Image) for classification."""
        raise NotImplementedError

    # Host class needs 'analyses' attribute initialized as Dict[str, Any]
    # analyses: Dict[str, Any]

    # --- End Abstract --- #

    def classify(
        self,
        categories: List[str],
        model: Optional[str] = None,  # Default handled by manager
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        analysis_key: str = 'classification',  # Default key
        multi_label: bool = False,
        **kwargs
    ) -> "ClassificationMixin":  # Return self for chaining
        """
        Classifies this item (Page or Region) using the configured manager.

        Stores the result in self.analyses[analysis_key]. If analysis_key is not
        provided, it defaults to 'classification' and overwrites any previous
        result under that key.

        Failures other than NotImplementedError are logged and swallowed: the
        method still returns self and whatever was previously stored under
        `analysis_key` is left untouched.

        Args:
            categories: A list of string category names.
            model: Model identifier (e.g., 'text', 'vision', HF ID). Defaults handled by manager.
            using: Optional processing mode ('text' or 'vision'). If None, inferred by manager.
            min_confidence: Minimum confidence threshold for results (0.0-1.0).
            analysis_key: Key under which to store the result in `self.analyses`.
                          Defaults to 'classification'.
            multi_label: Whether to allow multiple labels (passed to HF pipeline).
            **kwargs: Additional arguments passed both to
                      `_get_classification_content` and to the ClassificationManager.

        Returns:
            Self for method chaining.

        Raises:
            NotImplementedError: If the host class does not implement the
                required abstract hooks.
        """
        # Ensure analyses dict exists
        if getattr(self, 'analyses', None) is None:
            logger.warning("'analyses' attribute not found or is None. Initializing as empty dict.")
            self.analyses = {}

        try:
            manager = self._get_classification_manager()

            # Determine the effective model ID and engine type
            effective_model_id = model
            inferred_using = manager.infer_using(model if model else manager.DEFAULT_TEXT_MODEL, using)

            # If model was not provided, use the manager's default for the inferred engine type
            if effective_model_id is None:
                effective_model_id = manager.DEFAULT_TEXT_MODEL if inferred_using == 'text' else manager.DEFAULT_VISION_MODEL
                logger.debug(f"No model provided, using default for mode '{inferred_using}': '{effective_model_id}'")

            # Get content based on the *final* determined engine type
            content = self._get_classification_content(model_type=inferred_using, **kwargs)

            # Manager returns a ClassificationResult object
            result_obj: ClassificationResult = manager.classify_item(
                item_content=content,
                categories=categories,
                model_id=effective_model_id,  # Pass the resolved model ID
                using=inferred_using,
                min_confidence=min_confidence,
                multi_label=multi_label,
                **kwargs
            )

            # Store the structured result object under the specified key
            self.analyses[analysis_key] = result_obj
            logger.debug(f"Stored classification result under key '{analysis_key}': {result_obj}")

        except NotImplementedError as nie:
            # A missing hook is a programming error in the host class: surface it.
            logger.error(f"Classification cannot proceed: {nie}")
            raise
        except Exception as e:
            # Best-effort: log and fall through so chained calls keep working.
            logger.error(f"Classification failed: {e}", exc_info=True)

        return self

    @property
    def classification_results(self) -> Optional[ClassificationResult]:
        """Returns the ClassificationResult from the *default* ('classification') key, or None."""
        # Delegate so the default key gets the same type check as any other key;
        # this keeps `category` / `category_confidence` from failing on a
        # foreign object stored under 'classification'.
        return self.get_classification_result('classification')

    @property
    def category(self) -> Optional[str]:
        """Returns the top category label from the *default* ('classification') key, or None."""
        result_obj = self.classification_results
        return result_obj.top_category if result_obj else None

    @property
    def category_confidence(self) -> Optional[float]:
        """Returns the top category confidence from the *default* ('classification') key, or None."""
        result_obj = self.classification_results
        return result_obj.top_confidence if result_obj else None

    def get_classification_result(self, analysis_key: str = 'classification') -> Optional[ClassificationResult]:
        """
        Gets a classification result object stored under a specific key.

        Args:
            analysis_key: Key to look up in `self.analyses`.

        Returns:
            The stored ClassificationResult, or None when `analyses` is missing,
            the key is absent, or the stored value is not a ClassificationResult
            (a warning is logged in that last case).
        """
        analyses = getattr(self, 'analyses', None)
        if analyses is None:
            return None
        result = analyses.get(analysis_key)
        if result is not None and not isinstance(result, ClassificationResult):
            logger.warning(f"Item found under key '{analysis_key}' is not a ClassificationResult (type: {type(result)}). Returning None.")
            return None
        return result
|
@@ -0,0 +1,62 @@
|
|
1
|
+
# natural_pdf/classification/results.py
|
2
|
+
from typing import List, Optional, Dict, Any
|
3
|
+
from datetime import datetime
|
4
|
+
import logging
|
5
|
+
|
6
|
+
logger = logging.getLogger(__name__)
|
7
|
+
|
8
|
+
class CategoryScore:
    """Represents the score for a single category."""

    label: str
    confidence: float  # Score between 0.0 and 1.0

    def __init__(self, label: str, confidence: float):
        """
        Build a score, sanitising rather than rejecting bad input.

        Args:
            label: Category name. An empty or non-string label is logged and
                stored as ``str(label)``.
            confidence: Score expected in [0.0, 1.0]. Out-of-range values are
                clamped; NaN or non-numeric values are logged and stored as 0.0.
        """
        # Basic validation: allow but log (callers rely on construction not raising).
        if not isinstance(label, str) or not label:
            logger.warning(f"Initializing CategoryScore with invalid label: {label}")

        try:
            value = float(confidence)
        except (TypeError, ValueError):
            # None, non-numeric strings, arbitrary objects: treat as undefined.
            value = float("nan")

        if value != value:  # NaN is the only float that differs from itself
            # min()/max() clamping would silently turn NaN into 1.0, promoting
            # an undefined score to the top of the ranking.
            logger.warning(f"Initializing CategoryScore with invalid confidence: {confidence} for label '{label}'. Using 0.0.")
            value = 0.0
        elif not isinstance(confidence, (float, int)) or not (0.0 <= value <= 1.0):
            logger.warning(f"Initializing CategoryScore with invalid confidence: {confidence} for label '{label}'. Clamping to [0, 1].")
            value = max(0.0, min(1.0, value))

        self.label = str(label)
        self.confidence = value

    def __repr__(self):
        return f"<CategoryScore label='{self.label}' confidence={self.confidence:.3f}>"
|
29
|
+
|
30
|
+
class ClassificationResult:
    """Structured outcome of one classification run: metadata plus ranked scores."""

    model_id: str
    using: str  # Processing mode ('text' or 'vision')
    timestamp: datetime
    parameters: Dict[str, Any]  # e.g., {'categories': [...], 'min_confidence': 0.1}
    scores: List[CategoryScore]  # Kept sorted by confidence, highest first

    def __init__(self, model_id: str, using: str, timestamp: datetime, parameters: Dict[str, Any], scores: List[CategoryScore]):
        """
        Args:
            model_id: Identifier of the model that produced the scores (stored as str).
            using: Processing mode (stored as str).
            timestamp: When the classification was produced.
            parameters: Parameters of the run; None is replaced by an empty dict.
            scores: CategoryScore objects; re-sorted here by descending confidence.

        Raises:
            TypeError: If `scores` is not a list made only of CategoryScore objects.
        """
        scores_are_valid = isinstance(scores, list) and all(
            isinstance(entry, CategoryScore) for entry in scores
        )
        if not scores_are_valid:
            raise TypeError("Scores must be a list of CategoryScore objects.")

        self.model_id = str(model_id)
        self.using = str(using)
        self.timestamp = timestamp
        self.parameters = {} if parameters is None else parameters

        # Work on a copy so the caller's list keeps its order.
        ranked = list(scores)
        ranked.sort(key=lambda entry: entry.confidence, reverse=True)
        self.scores = ranked

    @property
    def top_category(self) -> Optional[str]:
        """Label of the highest-confidence category, or None when there are no scores."""
        if not self.scores:
            return None
        return self.scores[0].label

    @property
    def top_confidence(self) -> Optional[float]:
        """Confidence of the highest-confidence category, or None when there are no scores."""
        if not self.scores:
            return None
        return self.scores[0].confidence

    def __repr__(self):
        score_count = len(self.scores)
        top_summary = ""
        if self.scores:
            top_summary = f" top='{self.top_category}' ({self.top_confidence:.2f})"
        return f"<ClassificationResult model='{self.model_id}' using='{self.using}' scores={score_count}{top_summary}>"
|
@@ -0,0 +1,63 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Callable, Iterable, Any, TypeVar
|
3
|
+
from tqdm.auto import tqdm
|
4
|
+
|
5
|
+
logger = logging.getLogger(__name__)
|
6
|
+
|
7
|
+
T = TypeVar("T") # Generic type for items in the collection
|
8
|
+
|
9
|
+
class ApplyMixin:
    """
    Mixin class providing an `.apply()` method for collections.

    Assumes the inheriting class implements `__iter__` and `__len__` appropriately
    for the items to be processed by `apply`.
    """

    def _get_items_for_apply(self) -> Iterable[Any]:
        """
        Returns the iterable of items to apply the function to.

        Defaults to iterating over `self`. Subclasses should override this
        if the default iteration is not suitable for the apply operation.
        """
        return iter(self)

    def apply(self: Any, func: Callable[..., Any], *args, **kwargs) -> None:
        """
        Applies a function to each item in the collection.

        Errors raised by `func` for an item are logged and the loop moves on
        to the next item; nothing is returned or collected.

        Args:
            func: The callable to apply to each item. The item itself
                  will be passed as the first argument. Any callable works,
                  including `functools.partial` objects and callable instances.
            *args: Additional positional arguments to pass to func.
            **kwargs: Additional keyword arguments to pass to func.
                      A special keyword argument 'show_progress' (bool, default=False)
                      can be used to display a progress bar.
        """
        show_progress = kwargs.pop('show_progress', False)

        # partial objects and callable instances have no __name__; fall back to
        # repr() so labelling/logging can never raise on its own.
        func_name = getattr(func, '__name__', None) or repr(func)

        # Derive unit name from class name
        unit_name = self.__class__.__name__.lower()
        items_iterable = self._get_items_for_apply()

        if show_progress:
            # The total is only needed for the progress bar, so it is only
            # looked up (and only complained about) when a bar was requested.
            total_items = 0
            try:
                total_items = len(self)
            except TypeError:  # __len__ might not be defined on self
                logger.warning("Could not determine collection length for progress bar.")

            if total_items > 0:
                items_iterable = tqdm(items_iterable, total=total_items, desc=f"Applying {func_name}", unit=unit_name)
            else:
                logger.info(f"Applying {func_name} (progress bar disabled for zero/unknown length).")

        for item in items_iterable:
            try:
                # Apply the function with the item and any extra args/kwargs
                func(item, *args, **kwargs)
            except Exception as e:
                # Log and continue for batch operations
                logger.error(f"Error applying {func_name} to {item}: {e}", exc_info=True)

        # Returns None, primarily used for side effects.
|
@@ -4,12 +4,24 @@ import logging
|
|
4
4
|
import os
|
5
5
|
import re # Added for safe path generation
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Type, Union
|
7
|
+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Type, Union, Callable
|
8
|
+
import concurrent.futures # Import concurrent.futures
|
9
|
+
import time # Import time for logging timestamps
|
10
|
+
import threading # Import threading for logging thread information
|
8
11
|
|
9
12
|
from PIL import Image
|
10
13
|
from tqdm import tqdm
|
14
|
+
from tqdm.auto import tqdm as auto_tqdm
|
15
|
+
from tqdm.notebook import tqdm as notebook_tqdm
|
16
|
+
|
17
|
+
from natural_pdf.utils.tqdm_utils import get_tqdm
|
18
|
+
|
19
|
+
# Get the appropriate tqdm class once
|
20
|
+
tqdm = get_tqdm()
|
11
21
|
|
12
22
|
# Set up logger early
|
23
|
+
# Configure logging to include thread information
|
24
|
+
# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(threadName)s - %(name)s - %(levelname)s - %(message)s')
|
13
25
|
logger = logging.getLogger(__name__)
|
14
26
|
|
15
27
|
from natural_pdf.core.pdf import PDF
|
@@ -36,9 +48,11 @@ except ImportError as e:
|
|
36
48
|
SearchServiceProtocol, SearchOptions, Indexable = object, object, object
|
37
49
|
|
38
50
|
from natural_pdf.search.searchable_mixin import SearchableMixin # Import the new mixin
|
51
|
+
# Import the ApplyMixin
|
52
|
+
from natural_pdf.collections.mixins import ApplyMixin
|
39
53
|
|
40
54
|
|
41
|
-
class PDFCollection(SearchableMixin): # Inherit from
|
55
|
+
class PDFCollection(SearchableMixin, ApplyMixin): # Inherit from ApplyMixin
|
42
56
|
def __init__(
|
43
57
|
self,
|
44
58
|
source: Union[str, Iterable[Union[str, "PDF"]]],
|
@@ -237,30 +251,214 @@ class PDFCollection(SearchableMixin): # Inherit from the mixin
|
|
237
251
|
|
238
252
|
def __repr__(self) -> str:
    """Return a short debug representation showing how many PDFs the collection holds."""
    # Removed search status
    # Counts the internal list directly rather than going through len(self).
    return f"<PDFCollection(count={len(self._pdfs)})>"
|
241
255
|
|
242
256
|
@property
def pdfs(self) -> List["PDF"]:
    """Returns the list of PDF objects held by the collection.

    The internal list itself is returned (not a copy), so mutating it
    changes the collection.
    """
    return self._pdfs
|
246
260
|
|
247
|
-
def
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
261
|
+
def find_all(
    self,
    selector: str,
    apply_exclusions: bool = True,
    regex: bool = False,
    case: bool = True,
    **kwargs
) -> "ElementCollection":
    """
    Run a selector against every PDF in the collection and pool the matches.

    The returned ElementCollection can span multiple PDFs; some
    ElementCollection methods have limitations in that situation.
    A PDF whose search raises is logged and skipped, so the result may be
    partial.

    Args:
        selector: CSS-like selector string to query elements
        apply_exclusions: Whether to exclude elements in exclusion regions (default: True)
        regex: Whether to use regex for text search in :contains (default: False)
        case: Whether to do case-sensitive text search (default: True)
        **kwargs: Additional keyword arguments passed to the find_all method of each PDF

    Returns:
        ElementCollection containing all matching elements across all PDFs
    """
    from natural_pdf.elements.collections import ElementCollection

    # The same options are forwarded, explicitly, to every per-PDF query.
    query_options = dict(
        apply_exclusions=apply_exclusions,
        regex=regex,
        case=case,
        **kwargs
    )

    matches = []
    for document in self._pdfs:
        try:
            found = document.find_all(selector, **query_options)
            matches.extend(found.elements)
        except Exception as e:
            logger.error(f"Error finding elements in {document.path}: {e}", exc_info=True)

    return ElementCollection(matches)
|
304
|
+
|
305
|
+
def apply_ocr(
    self,
    engine: Optional[str] = None,
    languages: Optional[List[str]] = None,
    min_confidence: Optional[float] = None,
    device: Optional[str] = None,
    resolution: Optional[int] = None,
    apply_exclusions: bool = True,
    detect_only: bool = False,
    replace: bool = True,
    options: Optional[Any] = None,
    pages: Optional[Union[slice, List[int]]] = None,
    max_workers: Optional[int] = None,
) -> "PDFCollection":
    """
    Run OCR on every PDF of the collection, one PDF per task.

    A PDF whose OCR raises is logged and does not stop the others; no
    exception from a single PDF is propagated to the caller.

    Args:
        engine: OCR engine to use (e.g., 'easyocr', 'paddleocr', 'surya')
        languages: List of language codes for OCR
        min_confidence: Minimum confidence threshold for text detection
        device: Device to use for OCR (e.g., 'cpu', 'cuda')
        resolution: DPI resolution for page rendering
        apply_exclusions: Whether to apply exclusion regions
        detect_only: If True, only detect text regions without extracting text
        replace: If True, replace existing OCR elements
        options: Engine-specific options
        pages: Specific pages to process (None for all pages)
        max_workers: Maximum number of threads to process PDFs concurrently.
                     If None or 1, processing is sequential. (default: None)

    Returns:
        Self for method chaining
    """
    PDF = self._get_pdf_class()
    logger.info(f"Applying OCR to {len(self._pdfs)} PDFs in collection (max_workers={max_workers})...")

    # Every PDF receives exactly the same OCR settings.
    ocr_kwargs = dict(
        pages=pages,
        engine=engine,
        languages=languages,
        min_confidence=min_confidence,
        device=device,
        resolution=resolution,
        apply_exclusions=apply_exclusions,
        detect_only=detect_only,
        replace=replace,
        options=options,
    )

    def _process_pdf(pdf: PDF):
        """OCR one PDF; return (path, None) on success or (path, exception) on failure."""
        worker_name = threading.current_thread().name
        source_path = pdf.path
        logger.debug(f"[{worker_name}] Starting OCR process for: {source_path}")
        started = time.monotonic()
        try:
            pdf.apply_ocr(**ocr_kwargs)
            elapsed = time.monotonic() - started
            logger.debug(f"[{worker_name}] Finished OCR process for: {source_path} (Duration: {elapsed:.2f}s)")
            return source_path, None
        except Exception as e:
            elapsed = time.monotonic() - started
            logger.error(f"[{worker_name}] Failed OCR process for {source_path} after {elapsed:.2f}s: {e}", exc_info=False)
            return source_path, e

    use_threads = max_workers is not None and max_workers > 1

    if not use_threads:
        # Sequential processing (max_workers is None or 1); same progress bar class for consistency
        logger.info("Applying OCR sequentially...")
        for pdf in tqdm(self._pdfs, desc="Applying OCR (Sequential)", unit="pdf"):
            _process_pdf(pdf)
    else:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="OCRWorker") as executor:
            futures = [executor.submit(_process_pdf, pdf) for pdf in self._pdfs]

            # The bar advances as each future completes, in completion order.
            progress_bar = tqdm(
                concurrent.futures.as_completed(futures),
                total=len(self._pdfs),
                desc="Applying OCR (Parallel)",
                unit="pdf"
            )

            for future in progress_bar:
                # The worker never raises; failures come back as the second tuple item.
                pdf_path, error = future.result()
                if error:
                    progress_bar.set_postfix_str(f"Error: {pdf_path}", refresh=True)

    logger.info("Finished applying OCR across the collection.")
    return self
|
261
403
|
|
262
|
-
|
263
|
-
|
404
|
+
def correct_ocr(
    self,
    correction_callback: Callable[[Any], Optional[str]],
    max_workers: Optional[int] = None,
    progress_callback: Optional[Callable[[], None]] = None,
) -> "PDFCollection":
    """
    Apply OCR correction to all relevant elements across all pages and PDFs
    in the collection using a single progress bar.

    A PDF whose correction raises is logged and skipped; the remaining PDFs
    are still processed.

    Args:
        correction_callback: Function to apply to each OCR element.
                             It receives the element and should return
                             the corrected text (str) or None.
        max_workers: Max threads to use for parallel execution within each page.
        progress_callback: Optional callback function to call after processing each element.
                           It is invoked with no arguments, right after the
                           collection-wide progress bar is advanced.
                           NOTE(review): when `max_workers` enables threading the
                           call presumably happens on worker threads — confirm
                           against PDF.correct_ocr.

    Returns:
        Self for method chaining.

    Raises:
        TypeError: If `correction_callback` is not callable.
    """
    self._get_pdf_class()  # Ensure PDF class is available
    if not callable(correction_callback):
        raise TypeError("`correction_callback` must be a callable function.")

    logger.info(f"Gathering OCR elements from {len(self._pdfs)} PDFs for correction...")

    # 1. Gather all target elements using the collection's find_all
    #    Crucially, set apply_exclusions=False to include elements in headers/footers etc.
    all_ocr_elements = self.find_all("text[source=ocr]", apply_exclusions=False).elements

    if not all_ocr_elements:
        logger.info("No OCR elements found in the collection to correct.")
        return self

    total_elements = len(all_ocr_elements)
    logger.info(f"Found {total_elements} OCR elements across the collection. Starting correction process...")

    # 2. Initialize the progress bar
    progress_bar = tqdm(total=total_elements, desc="Correcting OCR Elements", unit="element")

    # Non-callable values were silently ignored before; keep tolerating them.
    user_callback = progress_callback if callable(progress_callback) else None

    def _advance(*update_args, **update_kwargs):
        """Advance the shared bar, then notify the caller-supplied callback."""
        progress_bar.update(*update_args, **update_kwargs)
        if user_callback is not None:
            user_callback()

    # 3. Iterate through PDFs and delegate to PDF.correct_ocr, which receives
    #    the progress hook. The bar is closed even if something other than an
    #    Exception (e.g. KeyboardInterrupt) interrupts the loop.
    try:
        for pdf in self._pdfs:
            if not pdf.pages:
                continue
            try:
                pdf.correct_ocr(
                    correction_callback=correction_callback,
                    max_workers=max_workers,
                    progress_callback=_advance,
                )
            except Exception as e:
                logger.error(f"Error occurred during correction process for PDF {pdf.path}: {e}", exc_info=True)
                # Keep going with the remaining PDFs.
    finally:
        progress_bar.close()

    return self
|
264
462
|
|
265
463
|
def categorize(self, categories: List[str], **kwargs):
|
266
464
|
"""Categorizes PDFs in the collection based on content or features."""
|
@@ -279,14 +477,17 @@ class PDFCollection(SearchableMixin): # Inherit from the mixin
|
|
279
477
|
"""
|
280
478
|
try:
|
281
479
|
from natural_pdf.utils.packaging import create_correction_task_package
|
480
|
+
|
282
481
|
# Pass the collection itself (self) as the source
|
283
482
|
create_correction_task_package(source=self, output_zip_path=output_zip_path, **kwargs)
|
284
483
|
except ImportError:
|
285
|
-
logger.error(
|
484
|
+
logger.error(
|
485
|
+
"Failed to import 'create_correction_task_package'. Packaging utility might be missing."
|
486
|
+
)
|
286
487
|
# Or raise
|
287
488
|
except Exception as e:
|
288
489
|
logger.error(f"Failed to export correction task for collection: {e}", exc_info=True)
|
289
|
-
raise
|
490
|
+
raise # Re-raise the exception from the utility function
|
290
491
|
|
291
492
|
# --- Mixin Required Implementation ---
|
292
493
|
def get_indexable_items(self) -> Iterable[Indexable]:
|
@@ -306,3 +507,111 @@ class PDFCollection(SearchableMixin): # Inherit from the mixin
|
|
306
507
|
# logger.debug(f"Skipping empty page {page.page_number} from PDF '{pdf.path}'.")
|
307
508
|
# continue
|
308
509
|
yield page
|
510
|
+
|
511
|
+
# --- Classification Method --- #
|
512
|
+
def classify_all(
    self,
    categories: List[str],
    model: str = "text",
    max_workers: Optional[int] = None,
    **kwargs,
) -> "PDFCollection":
    """
    Classify all pages across all PDFs in the collection, potentially in parallel.

    Page classification is delegated to each PDF's `classify_pages` method.
    A single progress bar tracks individual pages across the whole collection.

    Args:
        categories: A list of string category names.
        model: Model identifier ('text', 'vision', or specific HF ID).
        max_workers: Maximum number of threads to process PDFs concurrently.
                     If None or 1, processing is sequential.
        **kwargs: Additional arguments passed down to `pdf.classify_pages` and
                  subsequently to `page.classify` (e.g., device,
                  confidence_threshold, resolution).

    Returns:
        Self for method chaining.

    Raises:
        ValueError: If categories list is empty.
        ClassificationError: If classification fails for any page (will stop processing).
        ImportError: If classification dependencies are missing.

    In parallel mode "stop processing" means: PDFs not yet started are
    cancelled, PDFs already running are allowed to finish, then the first
    observed error is re-raised.
    """
    PDF = self._get_pdf_class()
    if not categories:
        raise ValueError("Categories list cannot be empty.")

    logger.info(f"Starting classification for {len(self._pdfs)} PDFs in collection (model: '{model}')...")

    # Calculate total pages for the progress bar
    total_pages = sum(len(pdf.pages) for pdf in self._pdfs if pdf.pages)
    if total_pages == 0:
        logger.warning("No pages found in the PDF collection to classify.")
        return self

    progress_bar = tqdm(
        total=total_pages,
        desc=f"Classifying Pages (model: {model})",
        unit="page"
    )

    # Worker function
    def _process_pdf_classification(pdf: PDF):
        """Classify one PDF's pages; re-raise on failure so the whole run stops."""
        thread_id = threading.current_thread().name
        pdf_path = pdf.path
        logger.debug(f"[{thread_id}] Starting classification process for: {pdf_path}")
        start_time = time.monotonic()
        try:
            # Call classify_pages on the PDF, passing the progress callback
            pdf.classify_pages(
                categories=categories,
                model=model,
                progress_callback=progress_bar.update,
                **kwargs
            )
            end_time = time.monotonic()
            logger.debug(f"[{thread_id}] Finished classification for: {pdf_path} (Duration: {end_time - start_time:.2f}s)")
            return pdf_path, None  # Return path and no error
        except Exception as e:
            end_time = time.monotonic()
            # Summary only; the traceback travels with the re-raised exception.
            logger.error(f"[{thread_id}] Failed classification process for {pdf_path} after {end_time - start_time:.2f}s: {e}", exc_info=False)
            # Close progress bar immediately on error to avoid hanging
            progress_bar.close()
            # Re-raise the exception to stop the entire collection processing
            raise

    try:
        if max_workers is not None and max_workers > 1:
            logger.info(f"Classifying PDFs in parallel with {max_workers} workers.")
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="ClassifyWorker") as executor:
                futures = [
                    executor.submit(_process_pdf_classification, pdf)
                    for pdf in self._pdfs
                ]
                try:
                    # Progress is updated by the callback; result() re-raises
                    # whatever the worker raised.
                    for future in concurrent.futures.as_completed(futures):
                        future.result()
                except BaseException:
                    # Leaving the `with` block waits for every submitted task,
                    # so drop the ones that have not started yet; otherwise the
                    # whole collection would still be processed after a failure.
                    for pending in futures:
                        pending.cancel()
                    raise
        else:  # Sequential processing
            logger.info("Classifying PDFs sequentially.")
            for pdf in self._pdfs:
                _process_pdf_classification(pdf)

        logger.info("Finished classification across the collection.")

    finally:
        # close() is a no-op on an already closed (disabled) bar, so no state
        # checks are needed here.
        progress_bar.close()

    return self
|
616
|
+
|
617
|
+
# --- End Classification Method --- #
|