natural-pdf 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- natural_pdf/__init__.py +24 -40
- natural_pdf/classification/manager.py +26 -22
- natural_pdf/classification/mixin.py +7 -7
- natural_pdf/classification/results.py +17 -9
- natural_pdf/collections/mixins.py +17 -0
- natural_pdf/collections/pdf_collection.py +78 -46
- natural_pdf/core/page.py +17 -17
- natural_pdf/core/pdf.py +192 -18
- natural_pdf/elements/collections.py +307 -3
- natural_pdf/elements/region.py +2 -3
- natural_pdf/exporters/hocr.py +540 -0
- natural_pdf/exporters/hocr_font.py +142 -0
- natural_pdf/exporters/original_pdf.py +130 -0
- natural_pdf/exporters/searchable_pdf.py +3 -3
- natural_pdf/ocr/engine_surya.py +1 -1
- {natural_pdf-0.1.9.dist-info → natural_pdf-0.1.11.dist-info}/METADATA +1 -2
- {natural_pdf-0.1.9.dist-info → natural_pdf-0.1.11.dist-info}/RECORD +20 -17
- {natural_pdf-0.1.9.dist-info → natural_pdf-0.1.11.dist-info}/WHEEL +1 -1
- {natural_pdf-0.1.9.dist-info → natural_pdf-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {natural_pdf-0.1.9.dist-info → natural_pdf-0.1.11.dist-info}/top_level.txt +0 -0
natural_pdf/__init__.py
CHANGED
@@ -37,72 +37,56 @@ def configure_logging(level=logging.INFO, handler=None):
|
|
37
37
|
logger.propagate = False
|
38
38
|
|
39
39
|
|
40
|
+
# Version
|
41
|
+
__version__ = "0.1.1"
|
42
|
+
|
43
|
+
# Core imports
|
44
|
+
from natural_pdf.collections.pdf_collection import PDFCollection
|
40
45
|
from natural_pdf.core.page import Page
|
41
46
|
from natural_pdf.core.pdf import PDF
|
42
47
|
from natural_pdf.elements.collections import ElementCollection
|
43
48
|
from natural_pdf.elements.region import Region
|
44
49
|
|
45
|
-
|
46
|
-
try:
|
47
|
-
from natural_pdf.qa import DocumentQA, get_qa_engine
|
48
|
-
|
49
|
-
HAS_QA = True
|
50
|
-
except ImportError:
|
51
|
-
HAS_QA = False
|
52
|
-
|
53
|
-
__version__ = "0.1.1"
|
54
|
-
|
55
|
-
__all__ = [
|
56
|
-
"PDF",
|
57
|
-
"PDFCollection",
|
58
|
-
"Page",
|
59
|
-
"Region",
|
60
|
-
"ElementCollection",
|
61
|
-
"TextSearchOptions",
|
62
|
-
"MultiModalSearchOptions",
|
63
|
-
"BaseSearchOptions",
|
64
|
-
"configure_logging",
|
65
|
-
]
|
66
|
-
|
67
|
-
if HAS_QA:
|
68
|
-
__all__.extend(["DocumentQA", "get_qa_engine"])
|
69
|
-
|
70
|
-
|
71
|
-
from .collections.pdf_collection import PDFCollection
|
72
|
-
|
73
|
-
# Core classes
|
74
|
-
from .core.pdf import PDF
|
75
|
-
from .elements.region import Region
|
50
|
+
ElementCollection = None
|
76
51
|
|
77
52
|
# Search options (if extras installed)
|
78
53
|
try:
|
79
|
-
from .search.search_options import BaseSearchOptions, MultiModalSearchOptions, TextSearchOptions
|
54
|
+
from natural_pdf.search.search_options import BaseSearchOptions, MultiModalSearchOptions, TextSearchOptions
|
80
55
|
except ImportError:
|
81
56
|
# Define dummy classes if extras not installed, so imports don't break
|
82
57
|
# but using them will raise the ImportError from check_haystack_availability
|
83
|
-
class
|
58
|
+
class BaseSearchOptions:
|
84
59
|
def __init__(self, *args, **kwargs):
|
85
60
|
pass
|
86
61
|
|
87
|
-
class
|
62
|
+
class TextSearchOptions:
|
88
63
|
def __init__(self, *args, **kwargs):
|
89
64
|
pass
|
90
65
|
|
91
|
-
class
|
66
|
+
class MultiModalSearchOptions:
|
92
67
|
def __init__(self, *args, **kwargs):
|
93
68
|
pass
|
94
69
|
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
70
|
+
# Import QA module if available
|
71
|
+
try:
|
72
|
+
from natural_pdf.qa import DocumentQA, get_qa_engine
|
73
|
+
HAS_QA = True
|
74
|
+
except ImportError:
|
75
|
+
HAS_QA = False
|
99
76
|
|
100
77
|
# Explicitly define what gets imported with 'from natural_pdf import *'
|
101
78
|
__all__ = [
|
102
79
|
"PDF",
|
103
80
|
"PDFCollection",
|
81
|
+
"Page",
|
104
82
|
"Region",
|
105
|
-
"
|
83
|
+
"ElementCollection",
|
84
|
+
"TextSearchOptions",
|
106
85
|
"MultiModalSearchOptions",
|
107
86
|
"BaseSearchOptions",
|
87
|
+
"configure_logging",
|
108
88
|
]
|
89
|
+
|
90
|
+
# Add QA components to __all__ if available
|
91
|
+
if HAS_QA:
|
92
|
+
__all__.extend(["DocumentQA", "get_qa_engine"])
|
@@ -161,7 +161,7 @@ class ClassificationManager:
|
|
161
161
|
def classify_item(
|
162
162
|
self,
|
163
163
|
item_content: Union[str, Image.Image],
|
164
|
-
|
164
|
+
labels: List[str],
|
165
165
|
model_id: Optional[str] = None,
|
166
166
|
using: Optional[str] = None,
|
167
167
|
min_confidence: float = 0.0,
|
@@ -193,13 +193,13 @@ class ClassificationManager:
|
|
193
193
|
else self.DEFAULT_VISION_MODEL
|
194
194
|
)
|
195
195
|
|
196
|
-
if not
|
197
|
-
raise ValueError("
|
196
|
+
if not labels:
|
197
|
+
raise ValueError("Labels list cannot be empty.")
|
198
198
|
|
199
199
|
pipeline_instance = self._get_pipeline(model_id, effective_using)
|
200
200
|
timestamp = datetime.now()
|
201
201
|
parameters = { # Store parameters used for this run
|
202
|
-
"
|
202
|
+
"labels": labels,
|
203
203
|
"model_id": model_id,
|
204
204
|
"using": effective_using,
|
205
205
|
"min_confidence": min_confidence,
|
@@ -214,7 +214,7 @@ class ClassificationManager:
|
|
214
214
|
# Handle potential kwargs for specific pipelines if needed
|
215
215
|
# The zero-shot pipelines expect `candidate_labels`
|
216
216
|
result_raw = pipeline_instance(
|
217
|
-
item_content, candidate_labels=
|
217
|
+
item_content, candidate_labels=labels, multi_label=multi_label, **kwargs
|
218
218
|
)
|
219
219
|
logger.debug(f"Raw pipeline result: {result_raw}")
|
220
220
|
|
@@ -226,7 +226,7 @@ class ClassificationManager:
|
|
226
226
|
for label, score_val in zip(result_raw["labels"], result_raw["scores"]):
|
227
227
|
if score_val >= min_confidence:
|
228
228
|
try:
|
229
|
-
scores_list.append(CategoryScore(label
|
229
|
+
scores_list.append(CategoryScore(label, score_val))
|
230
230
|
except (ValueError, TypeError) as score_err:
|
231
231
|
logger.warning(
|
232
232
|
f"Skipping invalid score from text pipeline: label='{label}', score={score_val}. Error: {score_err}"
|
@@ -241,7 +241,7 @@ class ClassificationManager:
|
|
241
241
|
label = item["label"]
|
242
242
|
if score_val >= min_confidence:
|
243
243
|
try:
|
244
|
-
scores_list.append(CategoryScore(label
|
244
|
+
scores_list.append(CategoryScore(label, score_val))
|
245
245
|
except (ValueError, TypeError) as score_err:
|
246
246
|
logger.warning(
|
247
247
|
f"Skipping invalid score from vision pipeline: label='{label}', score={score_val}. Error: {score_err}"
|
@@ -253,13 +253,15 @@ class ClassificationManager:
|
|
253
253
|
# Return empty result?
|
254
254
|
# scores_list = []
|
255
255
|
|
256
|
-
|
256
|
+
# ClassificationResult now calculates top score/category internally
|
257
|
+
result_obj = ClassificationResult(
|
258
|
+
scores=scores_list, # Pass the filtered list
|
257
259
|
model_id=model_id,
|
258
260
|
using=effective_using,
|
259
|
-
timestamp=timestamp,
|
260
261
|
parameters=parameters,
|
261
|
-
|
262
|
+
timestamp=timestamp,
|
262
263
|
)
|
264
|
+
return result_obj
|
263
265
|
# --- End Processing --- #
|
264
266
|
|
265
267
|
except Exception as e:
|
@@ -273,7 +275,7 @@ class ClassificationManager:
|
|
273
275
|
def classify_batch(
|
274
276
|
self,
|
275
277
|
item_contents: List[Union[str, Image.Image]],
|
276
|
-
|
278
|
+
labels: List[str],
|
277
279
|
model_id: Optional[str] = None,
|
278
280
|
using: Optional[str] = None,
|
279
281
|
min_confidence: float = 0.0,
|
@@ -307,13 +309,13 @@ class ClassificationManager:
|
|
307
309
|
else self.DEFAULT_VISION_MODEL
|
308
310
|
)
|
309
311
|
|
310
|
-
if not
|
311
|
-
raise ValueError("
|
312
|
+
if not labels:
|
313
|
+
raise ValueError("Labels list cannot be empty.")
|
312
314
|
|
313
315
|
pipeline_instance = self._get_pipeline(model_id, effective_using)
|
314
316
|
timestamp = datetime.now() # Single timestamp for the batch run
|
315
317
|
parameters = { # Parameters for the whole batch
|
316
|
-
"
|
318
|
+
"labels": labels,
|
317
319
|
"model_id": model_id,
|
318
320
|
"using": effective_using,
|
319
321
|
"min_confidence": min_confidence,
|
@@ -331,7 +333,7 @@ class ClassificationManager:
|
|
331
333
|
# Use pipeline directly for batching
|
332
334
|
results_iterator = pipeline_instance(
|
333
335
|
item_contents,
|
334
|
-
candidate_labels=
|
336
|
+
candidate_labels=labels,
|
335
337
|
multi_label=multi_label,
|
336
338
|
batch_size=batch_size,
|
337
339
|
**kwargs,
|
@@ -362,9 +364,7 @@ class ClassificationManager:
|
|
362
364
|
for label, score_val in zip(raw_result["labels"], raw_result["scores"]):
|
363
365
|
if score_val >= min_confidence:
|
364
366
|
try:
|
365
|
-
scores_list.append(
|
366
|
-
CategoryScore(label=label, confidence=score_val)
|
367
|
-
)
|
367
|
+
scores_list.append(CategoryScore(label, score_val))
|
368
368
|
except (ValueError, TypeError) as score_err:
|
369
369
|
logger.warning(
|
370
370
|
f"Skipping invalid score from text pipeline batch: label='{label}', score={score_val}. Error: {score_err}"
|
@@ -376,9 +376,7 @@ class ClassificationManager:
|
|
376
376
|
score_val = item["score"]
|
377
377
|
label = item["label"]
|
378
378
|
if score_val >= min_confidence:
|
379
|
-
scores_list.append(
|
380
|
-
CategoryScore(label=label, confidence=score_val)
|
381
|
-
)
|
379
|
+
scores_list.append(CategoryScore(label, score_val))
|
382
380
|
except (KeyError, ValueError, TypeError) as item_err:
|
383
381
|
logger.warning(
|
384
382
|
f"Skipping invalid item in vision result list from batch: {item}. Error: {item_err}"
|
@@ -394,14 +392,20 @@ class ClassificationManager:
|
|
394
392
|
)
|
395
393
|
# scores_list remains empty for this item
|
396
394
|
|
395
|
+
# --- Determine top category and score ---
|
396
|
+
scores_list.sort(key=lambda s: s.score, reverse=True)
|
397
|
+
top_category = scores_list[0].label
|
398
|
+
top_score = scores_list[0].score
|
399
|
+
# --- End Determine top category ---
|
400
|
+
|
397
401
|
# Append result object for this item
|
398
402
|
batch_results_list.append(
|
399
403
|
ClassificationResult(
|
404
|
+
scores=scores_list, # Pass the full list, init will sort/filter
|
400
405
|
model_id=model_id,
|
401
406
|
using=effective_using,
|
402
407
|
timestamp=timestamp, # Use same timestamp for batch
|
403
408
|
parameters=parameters, # Use same params for batch
|
404
|
-
scores=scores_list,
|
405
409
|
)
|
406
410
|
)
|
407
411
|
# --- End Processing --- #
|
@@ -44,9 +44,9 @@ class ClassificationMixin:
|
|
44
44
|
|
45
45
|
def classify(
|
46
46
|
self,
|
47
|
-
|
48
|
-
model: Optional[str] = None,
|
49
|
-
using: Optional[str] = None,
|
47
|
+
labels: List[str],
|
48
|
+
model: Optional[str] = None,
|
49
|
+
using: Optional[str] = None,
|
50
50
|
min_confidence: float = 0.0,
|
51
51
|
analysis_key: str = "classification", # Default key
|
52
52
|
multi_label: bool = False,
|
@@ -60,7 +60,7 @@ class ClassificationMixin:
|
|
60
60
|
result under that key.
|
61
61
|
|
62
62
|
Args:
|
63
|
-
|
63
|
+
labels: A list of string category names.
|
64
64
|
model: Model identifier (e.g., 'text', 'vision', HF ID). Defaults handled by manager.
|
65
65
|
using: Optional processing mode ('text' or 'vision'). If None, inferred by manager.
|
66
66
|
min_confidence: Minimum confidence threshold for results (0.0-1.0).
|
@@ -103,9 +103,9 @@ class ClassificationMixin:
|
|
103
103
|
# Manager now returns a ClassificationResult object
|
104
104
|
result_obj: ClassificationResult = manager.classify_item(
|
105
105
|
item_content=content,
|
106
|
-
|
107
|
-
model_id=effective_model_id,
|
108
|
-
using=inferred_using,
|
106
|
+
labels=labels,
|
107
|
+
model_id=effective_model_id,
|
108
|
+
using=inferred_using,
|
109
109
|
min_confidence=min_confidence,
|
110
110
|
multi_label=multi_label,
|
111
111
|
**kwargs,
|
@@ -11,19 +11,19 @@ logger = logging.getLogger(__name__)
|
|
11
11
|
class CategoryScore:
|
12
12
|
"""Represents a category and its confidence score from classification."""
|
13
13
|
|
14
|
-
|
14
|
+
label: str
|
15
15
|
score: float
|
16
16
|
|
17
17
|
def to_dict(self) -> Dict[str, Any]:
|
18
18
|
"""Convert to dictionary for serialization."""
|
19
|
-
return {"category": self.
|
19
|
+
return {"category": self.label, "score": self.score}
|
20
20
|
|
21
21
|
|
22
22
|
@dataclass
|
23
23
|
class ClassificationResult:
|
24
24
|
"""Results from a classification operation."""
|
25
25
|
|
26
|
-
category: str
|
26
|
+
category: Optional[str] # Can be None if scores are empty
|
27
27
|
score: float
|
28
28
|
scores: List[CategoryScore]
|
29
29
|
model_id: str
|
@@ -33,17 +33,25 @@ class ClassificationResult:
|
|
33
33
|
|
34
34
|
def __init__(
|
35
35
|
self,
|
36
|
-
|
37
|
-
score: float,
|
38
|
-
scores: List[CategoryScore],
|
36
|
+
scores: List[CategoryScore], # Now the primary source
|
39
37
|
model_id: str,
|
40
38
|
using: str,
|
41
39
|
parameters: Optional[Dict[str, Any]] = None,
|
42
40
|
timestamp: Optional[datetime] = None,
|
43
41
|
):
|
44
|
-
|
45
|
-
|
46
|
-
|
42
|
+
# Determine top category and score from the scores list
|
43
|
+
if scores:
|
44
|
+
# Sort scores descending by score to find the top one
|
45
|
+
sorted_scores = sorted(scores, key=lambda s: s.score, reverse=True)
|
46
|
+
self.category = sorted_scores[0].label
|
47
|
+
self.score = sorted_scores[0].score
|
48
|
+
self.scores = sorted_scores # Store the sorted list
|
49
|
+
else:
|
50
|
+
# Handle empty scores list
|
51
|
+
self.category = None
|
52
|
+
self.score = 0.0
|
53
|
+
self.scores = [] # Store empty list
|
54
|
+
|
47
55
|
self.model_id = model_id
|
48
56
|
self.using = using
|
49
57
|
self.parameters = parameters or {}
|
@@ -109,3 +109,20 @@ class ApplyMixin:
|
|
109
109
|
return PageCollection(results)
|
110
110
|
|
111
111
|
return results
|
112
|
+
|
113
|
+
def filter(self: Any, predicate: Callable[[Any], bool]) -> Any:
|
114
|
+
"""
|
115
|
+
Filters the collection based on a predicate function.
|
116
|
+
|
117
|
+
Args:
|
118
|
+
predicate: A function that takes an item and returns True if the item
|
119
|
+
should be included in the result, False otherwise.
|
120
|
+
|
121
|
+
Returns:
|
122
|
+
A new collection of the same type containing only the items
|
123
|
+
for which the predicate returned True.
|
124
|
+
"""
|
125
|
+
items_iterable = self._get_items_for_apply()
|
126
|
+
filtered_items = [item for item in items_iterable if predicate(item)]
|
127
|
+
|
128
|
+
return type(self)(filtered_items)
|
@@ -519,7 +519,7 @@ class PDFCollection(SearchableMixin, ApplyMixin, ExportMixin): # Add ExportMixi
|
|
519
519
|
|
520
520
|
return self
|
521
521
|
|
522
|
-
def categorize(self,
|
522
|
+
def categorize(self, labels: List[str], **kwargs):
|
523
523
|
"""Categorizes PDFs in the collection based on content or features."""
|
524
524
|
# Implementation requires integrating with classification models or logic
|
525
525
|
raise NotImplementedError("categorize requires classification implementation.")
|
@@ -570,85 +570,101 @@ class PDFCollection(SearchableMixin, ApplyMixin, ExportMixin): # Add ExportMixi
|
|
570
570
|
# --- Classification Method --- #
|
571
571
|
def classify_all(
|
572
572
|
self,
|
573
|
-
|
574
|
-
|
573
|
+
labels: List[str],
|
574
|
+
using: Optional[str] = None, # Default handled by PDF.classify -> manager
|
575
|
+
model: Optional[str] = None, # Optional model ID
|
575
576
|
max_workers: Optional[int] = None,
|
577
|
+
analysis_key: str = "classification", # Key for storing result in PDF.analyses
|
576
578
|
**kwargs,
|
577
579
|
) -> "PDFCollection":
|
578
580
|
"""
|
579
|
-
Classify
|
581
|
+
Classify each PDF document in the collection, potentially in parallel.
|
580
582
|
|
581
|
-
This method
|
582
|
-
|
583
|
-
|
583
|
+
This method delegates classification to each PDF object's `classify` method.
|
584
|
+
By default, uses the full extracted text of the PDF.
|
585
|
+
If `using='vision'`, it classifies the first page's image, but ONLY if
|
586
|
+
the PDF has a single page (raises ValueError otherwise).
|
584
587
|
|
585
588
|
Args:
|
586
|
-
|
587
|
-
|
589
|
+
labels: A list of string category names.
|
590
|
+
using: Processing mode ('text', 'vision'). If None, manager infers (defaulting to text).
|
591
|
+
model: Optional specific model identifier (e.g., HF ID). If None, manager uses default for 'using' mode.
|
588
592
|
max_workers: Maximum number of threads to process PDFs concurrently.
|
589
593
|
If None or 1, processing is sequential.
|
590
|
-
|
591
|
-
|
592
|
-
|
594
|
+
analysis_key: Key under which to store the ClassificationResult in each PDF's `analyses` dict.
|
595
|
+
**kwargs: Additional arguments passed down to `pdf.classify` (e.g., device,
|
596
|
+
min_confidence, multi_label, text extraction options).
|
593
597
|
|
594
598
|
Returns:
|
595
599
|
Self for method chaining.
|
596
600
|
|
597
601
|
Raises:
|
598
|
-
ValueError: If
|
599
|
-
ClassificationError: If classification fails for any
|
602
|
+
ValueError: If labels list is empty, or if using='vision' on a multi-page PDF.
|
603
|
+
ClassificationError: If classification fails for any PDF (will stop processing).
|
600
604
|
ImportError: If classification dependencies are missing.
|
601
605
|
"""
|
602
606
|
PDF = self._get_pdf_class()
|
603
|
-
if not
|
604
|
-
raise ValueError("
|
607
|
+
if not labels:
|
608
|
+
raise ValueError("Labels list cannot be empty.")
|
605
609
|
|
610
|
+
if not self._pdfs:
|
611
|
+
logger.warning("PDFCollection is empty, skipping classification.")
|
612
|
+
return self
|
613
|
+
|
614
|
+
mode_desc = f"using='{using}'" if using else f"model='{model}'" if model else "default text"
|
606
615
|
logger.info(
|
607
|
-
f"Starting classification for {len(self._pdfs)} PDFs in collection (
|
616
|
+
f"Starting classification for {len(self._pdfs)} PDFs in collection ({mode_desc})..."
|
608
617
|
)
|
609
618
|
|
610
|
-
# Calculate total pages for the progress bar
|
611
|
-
total_pages = sum(len(pdf.pages) for pdf in self._pdfs if pdf.pages)
|
612
|
-
if total_pages == 0:
|
613
|
-
logger.warning("No pages found in the PDF collection to classify.")
|
614
|
-
return self
|
615
|
-
|
616
619
|
progress_bar = tqdm(
|
617
|
-
total=
|
620
|
+
total=len(self._pdfs), desc=f"Classifying PDFs ({mode_desc})", unit="pdf"
|
618
621
|
)
|
619
622
|
|
620
623
|
# Worker function
|
621
624
|
def _process_pdf_classification(pdf: PDF):
|
622
625
|
thread_id = threading.current_thread().name
|
623
626
|
pdf_path = pdf.path
|
624
|
-
logger.debug(f"[{thread_id}] Starting classification process for: {pdf_path}")
|
627
|
+
logger.debug(f"[{thread_id}] Starting classification process for PDF: {pdf_path}")
|
625
628
|
start_time = time.monotonic()
|
626
629
|
try:
|
627
|
-
# Call
|
628
|
-
pdf.
|
629
|
-
|
630
|
+
# Call classify directly on the PDF object
|
631
|
+
pdf.classify(
|
632
|
+
labels=labels,
|
633
|
+
using=using,
|
630
634
|
model=model,
|
631
|
-
|
632
|
-
**kwargs,
|
635
|
+
analysis_key=analysis_key,
|
636
|
+
**kwargs, # Pass other relevant args like min_confidence, multi_label
|
633
637
|
)
|
634
638
|
end_time = time.monotonic()
|
635
639
|
logger.debug(
|
636
|
-
f"[{thread_id}] Finished classification for: {pdf_path} (Duration: {end_time - start_time:.2f}s)"
|
640
|
+
f"[{thread_id}] Finished classification for PDF: {pdf_path} (Duration: {end_time - start_time:.2f}s)"
|
637
641
|
)
|
642
|
+
progress_bar.update(1) # Update progress bar upon success
|
638
643
|
return pdf_path, None # Return path and no error
|
639
|
-
except
|
644
|
+
except ValueError as ve:
|
645
|
+
# Catch specific error for vision on multi-page PDF
|
640
646
|
end_time = time.monotonic()
|
641
|
-
# Error is logged within classify_pages, but log summary here
|
642
647
|
logger.error(
|
643
|
-
f"[{thread_id}]
|
648
|
+
f"[{thread_id}] Skipped classification for {pdf_path} after {end_time - start_time:.2f}s: {ve}",
|
644
649
|
exc_info=False,
|
645
650
|
)
|
646
|
-
#
|
647
|
-
|
651
|
+
progress_bar.update(1) # Still update progress bar
|
652
|
+
return pdf_path, ve # Return the specific ValueError
|
653
|
+
except Exception as e:
|
654
|
+
end_time = time.monotonic()
|
655
|
+
logger.error(
|
656
|
+
f"[{thread_id}] Failed classification process for PDF {pdf_path} after {end_time - start_time:.2f}s: {e}",
|
657
|
+
exc_info=True, # Log full traceback for unexpected errors
|
658
|
+
)
|
659
|
+
# Close progress bar immediately on critical error to avoid hanging
|
660
|
+
if not progress_bar.disable:
|
661
|
+
progress_bar.close()
|
648
662
|
# Re-raise the exception to stop the entire collection processing
|
649
|
-
raise
|
663
|
+
raise ClassificationError(f"Classification failed for {pdf_path}: {e}") from e
|
650
664
|
|
651
665
|
# Use ThreadPoolExecutor for parallel processing if max_workers > 1
|
666
|
+
processed_count = 0
|
667
|
+
skipped_count = 0
|
652
668
|
try:
|
653
669
|
if max_workers is not None and max_workers > 1:
|
654
670
|
logger.info(f"Classifying PDFs in parallel with {max_workers} workers.")
|
@@ -659,23 +675,39 @@ class PDFCollection(SearchableMixin, ApplyMixin, ExportMixin): # Add ExportMixi
|
|
659
675
|
for pdf in self._pdfs:
|
660
676
|
futures.append(executor.submit(_process_pdf_classification, pdf))
|
661
677
|
|
662
|
-
# Wait for all futures to complete
|
663
|
-
#
|
678
|
+
# Wait for all futures to complete
|
679
|
+
# Progress updated within worker
|
664
680
|
for future in concurrent.futures.as_completed(futures):
|
665
|
-
|
681
|
+
processed_count += 1
|
682
|
+
pdf_path, error = (
|
683
|
+
future.result()
|
684
|
+
) # Raise ClassificationError if worker failed critically
|
685
|
+
if isinstance(error, ValueError):
|
686
|
+
# Logged in worker, just count as skipped
|
687
|
+
skipped_count += 1
|
666
688
|
|
667
689
|
else: # Sequential processing
|
668
690
|
logger.info("Classifying PDFs sequentially.")
|
669
691
|
for pdf in self._pdfs:
|
670
|
-
|
671
|
-
|
672
|
-
|
692
|
+
processed_count += 1
|
693
|
+
pdf_path, error = _process_pdf_classification(
|
694
|
+
pdf
|
695
|
+
) # Raise ClassificationError if worker failed critically
|
696
|
+
if isinstance(error, ValueError):
|
697
|
+
skipped_count += 1
|
698
|
+
|
699
|
+
final_message = (
|
700
|
+
f"Finished classification across the collection. Processed: {processed_count}"
|
701
|
+
)
|
702
|
+
if skipped_count > 0:
|
703
|
+
final_message += f", Skipped (e.g., vision on multi-page): {skipped_count}"
|
704
|
+
logger.info(final_message + ".")
|
673
705
|
|
674
706
|
finally:
|
675
|
-
# Ensure progress bar is closed
|
707
|
+
# Ensure progress bar is closed properly
|
676
708
|
if not progress_bar.disable and progress_bar.n < progress_bar.total:
|
677
|
-
progress_bar.
|
678
|
-
|
709
|
+
progress_bar.n = progress_bar.total # Ensure it reaches 100%
|
710
|
+
if not progress_bar.disable:
|
679
711
|
progress_bar.close()
|
680
712
|
|
681
713
|
return self
|