natural-pdf 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
natural_pdf/__init__.py CHANGED
@@ -37,72 +37,56 @@ def configure_logging(level=logging.INFO, handler=None):
37
37
  logger.propagate = False
38
38
 
39
39
 
40
+ # Version
41
+ __version__ = "0.1.1"
42
+
43
+ # Core imports
44
+ from natural_pdf.collections.pdf_collection import PDFCollection
40
45
  from natural_pdf.core.page import Page
41
46
  from natural_pdf.core.pdf import PDF
42
47
  from natural_pdf.elements.collections import ElementCollection
43
48
  from natural_pdf.elements.region import Region
44
49
 
45
- # Import QA module if available
46
- try:
47
- from natural_pdf.qa import DocumentQA, get_qa_engine
48
-
49
- HAS_QA = True
50
- except ImportError:
51
- HAS_QA = False
52
-
53
- __version__ = "0.1.1"
54
-
55
- __all__ = [
56
- "PDF",
57
- "PDFCollection",
58
- "Page",
59
- "Region",
60
- "ElementCollection",
61
- "TextSearchOptions",
62
- "MultiModalSearchOptions",
63
- "BaseSearchOptions",
64
- "configure_logging",
65
- ]
66
-
67
- if HAS_QA:
68
- __all__.extend(["DocumentQA", "get_qa_engine"])
69
-
70
-
71
- from .collections.pdf_collection import PDFCollection
72
-
73
- # Core classes
74
- from .core.pdf import PDF
75
- from .elements.region import Region
50
+ ElementCollection = None
76
51
 
77
52
  # Search options (if extras installed)
78
53
  try:
79
- from .search.search_options import BaseSearchOptions, MultiModalSearchOptions, TextSearchOptions
54
+ from natural_pdf.search.search_options import BaseSearchOptions, MultiModalSearchOptions, TextSearchOptions
80
55
  except ImportError:
81
56
  # Define dummy classes if extras not installed, so imports don't break
82
57
  # but using them will raise the ImportError from check_haystack_availability
83
- class TextSearchOptions:
58
+ class BaseSearchOptions:
84
59
  def __init__(self, *args, **kwargs):
85
60
  pass
86
61
 
87
- class MultiModalSearchOptions:
62
+ class TextSearchOptions:
88
63
  def __init__(self, *args, **kwargs):
89
64
  pass
90
65
 
91
- class BaseSearchOptions:
66
+ class MultiModalSearchOptions:
92
67
  def __init__(self, *args, **kwargs):
93
68
  pass
94
69
 
95
-
96
- # Expose logging setup? (Optional)
97
- # from . import logging_config
98
- # logging_config.setup_logging()
70
+ # Import QA module if available
71
+ try:
72
+ from natural_pdf.qa import DocumentQA, get_qa_engine
73
+ HAS_QA = True
74
+ except ImportError:
75
+ HAS_QA = False
99
76
 
100
77
  # Explicitly define what gets imported with 'from natural_pdf import *'
101
78
  __all__ = [
102
79
  "PDF",
103
80
  "PDFCollection",
81
+ "Page",
104
82
  "Region",
105
- "TextSearchOptions", # Include search options
83
+ "ElementCollection",
84
+ "TextSearchOptions",
106
85
  "MultiModalSearchOptions",
107
86
  "BaseSearchOptions",
87
+ "configure_logging",
108
88
  ]
89
+
90
+ # Add QA components to __all__ if available
91
+ if HAS_QA:
92
+ __all__.extend(["DocumentQA", "get_qa_engine"])
@@ -161,7 +161,7 @@ class ClassificationManager:
161
161
  def classify_item(
162
162
  self,
163
163
  item_content: Union[str, Image.Image],
164
- categories: List[str],
164
+ labels: List[str],
165
165
  model_id: Optional[str] = None,
166
166
  using: Optional[str] = None,
167
167
  min_confidence: float = 0.0,
@@ -193,13 +193,13 @@ class ClassificationManager:
193
193
  else self.DEFAULT_VISION_MODEL
194
194
  )
195
195
 
196
- if not categories:
197
- raise ValueError("Categories list cannot be empty.")
196
+ if not labels:
197
+ raise ValueError("Labels list cannot be empty.")
198
198
 
199
199
  pipeline_instance = self._get_pipeline(model_id, effective_using)
200
200
  timestamp = datetime.now()
201
201
  parameters = { # Store parameters used for this run
202
- "categories": categories,
202
+ "labels": labels,
203
203
  "model_id": model_id,
204
204
  "using": effective_using,
205
205
  "min_confidence": min_confidence,
@@ -214,7 +214,7 @@ class ClassificationManager:
214
214
  # Handle potential kwargs for specific pipelines if needed
215
215
  # The zero-shot pipelines expect `candidate_labels`
216
216
  result_raw = pipeline_instance(
217
- item_content, candidate_labels=categories, multi_label=multi_label, **kwargs
217
+ item_content, candidate_labels=labels, multi_label=multi_label, **kwargs
218
218
  )
219
219
  logger.debug(f"Raw pipeline result: {result_raw}")
220
220
 
@@ -226,7 +226,7 @@ class ClassificationManager:
226
226
  for label, score_val in zip(result_raw["labels"], result_raw["scores"]):
227
227
  if score_val >= min_confidence:
228
228
  try:
229
- scores_list.append(CategoryScore(label=label, confidence=score_val))
229
+ scores_list.append(CategoryScore(label, score_val))
230
230
  except (ValueError, TypeError) as score_err:
231
231
  logger.warning(
232
232
  f"Skipping invalid score from text pipeline: label='{label}', score={score_val}. Error: {score_err}"
@@ -241,7 +241,7 @@ class ClassificationManager:
241
241
  label = item["label"]
242
242
  if score_val >= min_confidence:
243
243
  try:
244
- scores_list.append(CategoryScore(label=label, confidence=score_val))
244
+ scores_list.append(CategoryScore(label, score_val))
245
245
  except (ValueError, TypeError) as score_err:
246
246
  logger.warning(
247
247
  f"Skipping invalid score from vision pipeline: label='{label}', score={score_val}. Error: {score_err}"
@@ -253,13 +253,15 @@ class ClassificationManager:
253
253
  # Return empty result?
254
254
  # scores_list = []
255
255
 
256
- return ClassificationResult(
256
+ # ClassificationResult now calculates top score/category internally
257
+ result_obj = ClassificationResult(
258
+ scores=scores_list, # Pass the filtered list
257
259
  model_id=model_id,
258
260
  using=effective_using,
259
- timestamp=timestamp,
260
261
  parameters=parameters,
261
- scores=scores_list,
262
+ timestamp=timestamp,
262
263
  )
264
+ return result_obj
263
265
  # --- End Processing --- #
264
266
 
265
267
  except Exception as e:
@@ -273,7 +275,7 @@ class ClassificationManager:
273
275
  def classify_batch(
274
276
  self,
275
277
  item_contents: List[Union[str, Image.Image]],
276
- categories: List[str],
278
+ labels: List[str],
277
279
  model_id: Optional[str] = None,
278
280
  using: Optional[str] = None,
279
281
  min_confidence: float = 0.0,
@@ -307,13 +309,13 @@ class ClassificationManager:
307
309
  else self.DEFAULT_VISION_MODEL
308
310
  )
309
311
 
310
- if not categories:
311
- raise ValueError("Categories list cannot be empty.")
312
+ if not labels:
313
+ raise ValueError("Labels list cannot be empty.")
312
314
 
313
315
  pipeline_instance = self._get_pipeline(model_id, effective_using)
314
316
  timestamp = datetime.now() # Single timestamp for the batch run
315
317
  parameters = { # Parameters for the whole batch
316
- "categories": categories,
318
+ "labels": labels,
317
319
  "model_id": model_id,
318
320
  "using": effective_using,
319
321
  "min_confidence": min_confidence,
@@ -331,7 +333,7 @@ class ClassificationManager:
331
333
  # Use pipeline directly for batching
332
334
  results_iterator = pipeline_instance(
333
335
  item_contents,
334
- candidate_labels=categories,
336
+ candidate_labels=labels,
335
337
  multi_label=multi_label,
336
338
  batch_size=batch_size,
337
339
  **kwargs,
@@ -362,9 +364,7 @@ class ClassificationManager:
362
364
  for label, score_val in zip(raw_result["labels"], raw_result["scores"]):
363
365
  if score_val >= min_confidence:
364
366
  try:
365
- scores_list.append(
366
- CategoryScore(label=label, confidence=score_val)
367
- )
367
+ scores_list.append(CategoryScore(label, score_val))
368
368
  except (ValueError, TypeError) as score_err:
369
369
  logger.warning(
370
370
  f"Skipping invalid score from text pipeline batch: label='{label}', score={score_val}. Error: {score_err}"
@@ -376,9 +376,7 @@ class ClassificationManager:
376
376
  score_val = item["score"]
377
377
  label = item["label"]
378
378
  if score_val >= min_confidence:
379
- scores_list.append(
380
- CategoryScore(label=label, confidence=score_val)
381
- )
379
+ scores_list.append(CategoryScore(label, score_val))
382
380
  except (KeyError, ValueError, TypeError) as item_err:
383
381
  logger.warning(
384
382
  f"Skipping invalid item in vision result list from batch: {item}. Error: {item_err}"
@@ -394,14 +392,20 @@ class ClassificationManager:
394
392
  )
395
393
  # scores_list remains empty for this item
396
394
 
395
+ # --- Determine top category and score ---
396
+ scores_list.sort(key=lambda s: s.score, reverse=True)
397
+ top_category = scores_list[0].label
398
+ top_score = scores_list[0].score
399
+ # --- End Determine top category ---
400
+
397
401
  # Append result object for this item
398
402
  batch_results_list.append(
399
403
  ClassificationResult(
404
+ scores=scores_list, # Pass the full list, init will sort/filter
400
405
  model_id=model_id,
401
406
  using=effective_using,
402
407
  timestamp=timestamp, # Use same timestamp for batch
403
408
  parameters=parameters, # Use same params for batch
404
- scores=scores_list,
405
409
  )
406
410
  )
407
411
  # --- End Processing --- #
@@ -44,9 +44,9 @@ class ClassificationMixin:
44
44
 
45
45
  def classify(
46
46
  self,
47
- categories: List[str],
48
- model: Optional[str] = None, # Default handled by manager
49
- using: Optional[str] = None, # Renamed parameter
47
+ labels: List[str],
48
+ model: Optional[str] = None,
49
+ using: Optional[str] = None,
50
50
  min_confidence: float = 0.0,
51
51
  analysis_key: str = "classification", # Default key
52
52
  multi_label: bool = False,
@@ -60,7 +60,7 @@ class ClassificationMixin:
60
60
  result under that key.
61
61
 
62
62
  Args:
63
- categories: A list of string category names.
63
+ labels: A list of string category names.
64
64
  model: Model identifier (e.g., 'text', 'vision', HF ID). Defaults handled by manager.
65
65
  using: Optional processing mode ('text' or 'vision'). If None, inferred by manager.
66
66
  min_confidence: Minimum confidence threshold for results (0.0-1.0).
@@ -103,9 +103,9 @@ class ClassificationMixin:
103
103
  # Manager now returns a ClassificationResult object
104
104
  result_obj: ClassificationResult = manager.classify_item(
105
105
  item_content=content,
106
- categories=categories,
107
- model_id=effective_model_id, # Pass the resolved model ID
108
- using=inferred_using, # Pass renamed argument
106
+ labels=labels,
107
+ model_id=effective_model_id,
108
+ using=inferred_using,
109
109
  min_confidence=min_confidence,
110
110
  multi_label=multi_label,
111
111
  **kwargs,
@@ -11,19 +11,19 @@ logger = logging.getLogger(__name__)
11
11
  class CategoryScore:
12
12
  """Represents a category and its confidence score from classification."""
13
13
 
14
- category: str
14
+ label: str
15
15
  score: float
16
16
 
17
17
  def to_dict(self) -> Dict[str, Any]:
18
18
  """Convert to dictionary for serialization."""
19
- return {"category": self.category, "score": self.score}
19
+ return {"category": self.label, "score": self.score}
20
20
 
21
21
 
22
22
  @dataclass
23
23
  class ClassificationResult:
24
24
  """Results from a classification operation."""
25
25
 
26
- category: str
26
+ category: Optional[str] # Can be None if scores are empty
27
27
  score: float
28
28
  scores: List[CategoryScore]
29
29
  model_id: str
@@ -33,17 +33,25 @@ class ClassificationResult:
33
33
 
34
34
  def __init__(
35
35
  self,
36
- category: str,
37
- score: float,
38
- scores: List[CategoryScore],
36
+ scores: List[CategoryScore], # Now the primary source
39
37
  model_id: str,
40
38
  using: str,
41
39
  parameters: Optional[Dict[str, Any]] = None,
42
40
  timestamp: Optional[datetime] = None,
43
41
  ):
44
- self.category = category
45
- self.score = score
46
- self.scores = scores
42
+ # Determine top category and score from the scores list
43
+ if scores:
44
+ # Sort scores descending by score to find the top one
45
+ sorted_scores = sorted(scores, key=lambda s: s.score, reverse=True)
46
+ self.category = sorted_scores[0].label
47
+ self.score = sorted_scores[0].score
48
+ self.scores = sorted_scores # Store the sorted list
49
+ else:
50
+ # Handle empty scores list
51
+ self.category = None
52
+ self.score = 0.0
53
+ self.scores = [] # Store empty list
54
+
47
55
  self.model_id = model_id
48
56
  self.using = using
49
57
  self.parameters = parameters or {}
@@ -109,3 +109,20 @@ class ApplyMixin:
109
109
  return PageCollection(results)
110
110
 
111
111
  return results
112
+
113
+ def filter(self: Any, predicate: Callable[[Any], bool]) -> Any:
114
+ """
115
+ Filters the collection based on a predicate function.
116
+
117
+ Args:
118
+ predicate: A function that takes an item and returns True if the item
119
+ should be included in the result, False otherwise.
120
+
121
+ Returns:
122
+ A new collection of the same type containing only the items
123
+ for which the predicate returned True.
124
+ """
125
+ items_iterable = self._get_items_for_apply()
126
+ filtered_items = [item for item in items_iterable if predicate(item)]
127
+
128
+ return type(self)(filtered_items)
@@ -519,7 +519,7 @@ class PDFCollection(SearchableMixin, ApplyMixin, ExportMixin): # Add ExportMixi
519
519
 
520
520
  return self
521
521
 
522
- def categorize(self, categories: List[str], **kwargs):
522
+ def categorize(self, labels: List[str], **kwargs):
523
523
  """Categorizes PDFs in the collection based on content or features."""
524
524
  # Implementation requires integrating with classification models or logic
525
525
  raise NotImplementedError("categorize requires classification implementation.")
@@ -570,85 +570,101 @@ class PDFCollection(SearchableMixin, ApplyMixin, ExportMixin): # Add ExportMixi
570
570
  # --- Classification Method --- #
571
571
  def classify_all(
572
572
  self,
573
- categories: List[str],
574
- model: str = "text",
573
+ labels: List[str],
574
+ using: Optional[str] = None, # Default handled by PDF.classify -> manager
575
+ model: Optional[str] = None, # Optional model ID
575
576
  max_workers: Optional[int] = None,
577
+ analysis_key: str = "classification", # Key for storing result in PDF.analyses
576
578
  **kwargs,
577
579
  ) -> "PDFCollection":
578
580
  """
579
- Classify all pages across all PDFs in the collection, potentially in parallel.
581
+ Classify each PDF document in the collection, potentially in parallel.
580
582
 
581
- This method uses the unified `classify_all` approach, delegating page
582
- classification to each PDF's `classify_pages` method.
583
- It displays a progress bar tracking individual pages.
583
+ This method delegates classification to each PDF object's `classify` method.
584
+ By default, uses the full extracted text of the PDF.
585
+ If `using='vision'`, it classifies the first page's image, but ONLY if
586
+ the PDF has a single page (raises ValueError otherwise).
584
587
 
585
588
  Args:
586
- categories: A list of string category names.
587
- model: Model identifier ('text', 'vision', or specific HF ID).
589
+ labels: A list of string category names.
590
+ using: Processing mode ('text', 'vision'). If None, manager infers (defaulting to text).
591
+ model: Optional specific model identifier (e.g., HF ID). If None, manager uses default for 'using' mode.
588
592
  max_workers: Maximum number of threads to process PDFs concurrently.
589
593
  If None or 1, processing is sequential.
590
- **kwargs: Additional arguments passed down to `pdf.classify_pages` and
591
- subsequently to `page.classify` (e.g., device,
592
- confidence_threshold, resolution).
594
+ analysis_key: Key under which to store the ClassificationResult in each PDF's `analyses` dict.
595
+ **kwargs: Additional arguments passed down to `pdf.classify` (e.g., device,
596
+ min_confidence, multi_label, text extraction options).
593
597
 
594
598
  Returns:
595
599
  Self for method chaining.
596
600
 
597
601
  Raises:
598
- ValueError: If categories list is empty.
599
- ClassificationError: If classification fails for any page (will stop processing).
602
+ ValueError: If labels list is empty, or if using='vision' on a multi-page PDF.
603
+ ClassificationError: If classification fails for any PDF (will stop processing).
600
604
  ImportError: If classification dependencies are missing.
601
605
  """
602
606
  PDF = self._get_pdf_class()
603
- if not categories:
604
- raise ValueError("Categories list cannot be empty.")
607
+ if not labels:
608
+ raise ValueError("Labels list cannot be empty.")
605
609
 
610
+ if not self._pdfs:
611
+ logger.warning("PDFCollection is empty, skipping classification.")
612
+ return self
613
+
614
+ mode_desc = f"using='{using}'" if using else f"model='{model}'" if model else "default text"
606
615
  logger.info(
607
- f"Starting classification for {len(self._pdfs)} PDFs in collection (model: '{model}')..."
616
+ f"Starting classification for {len(self._pdfs)} PDFs in collection ({mode_desc})..."
608
617
  )
609
618
 
610
- # Calculate total pages for the progress bar
611
- total_pages = sum(len(pdf.pages) for pdf in self._pdfs if pdf.pages)
612
- if total_pages == 0:
613
- logger.warning("No pages found in the PDF collection to classify.")
614
- return self
615
-
616
619
  progress_bar = tqdm(
617
- total=total_pages, desc=f"Classifying Pages (model: {model})", unit="page"
620
+ total=len(self._pdfs), desc=f"Classifying PDFs ({mode_desc})", unit="pdf"
618
621
  )
619
622
 
620
623
  # Worker function
621
624
  def _process_pdf_classification(pdf: PDF):
622
625
  thread_id = threading.current_thread().name
623
626
  pdf_path = pdf.path
624
- logger.debug(f"[{thread_id}] Starting classification process for: {pdf_path}")
627
+ logger.debug(f"[{thread_id}] Starting classification process for PDF: {pdf_path}")
625
628
  start_time = time.monotonic()
626
629
  try:
627
- # Call classify_pages on the PDF, passing the progress callback
628
- pdf.classify_pages(
629
- categories=categories,
630
+ # Call classify directly on the PDF object
631
+ pdf.classify(
632
+ labels=labels,
633
+ using=using,
630
634
  model=model,
631
- progress_callback=progress_bar.update,
632
- **kwargs,
635
+ analysis_key=analysis_key,
636
+ **kwargs, # Pass other relevant args like min_confidence, multi_label
633
637
  )
634
638
  end_time = time.monotonic()
635
639
  logger.debug(
636
- f"[{thread_id}] Finished classification for: {pdf_path} (Duration: {end_time - start_time:.2f}s)"
640
+ f"[{thread_id}] Finished classification for PDF: {pdf_path} (Duration: {end_time - start_time:.2f}s)"
637
641
  )
642
+ progress_bar.update(1) # Update progress bar upon success
638
643
  return pdf_path, None # Return path and no error
639
- except Exception as e:
644
+ except ValueError as ve:
645
+ # Catch specific error for vision on multi-page PDF
640
646
  end_time = time.monotonic()
641
- # Error is logged within classify_pages, but log summary here
642
647
  logger.error(
643
- f"[{thread_id}] Failed classification process for {pdf_path} after {end_time - start_time:.2f}s: {e}",
648
+ f"[{thread_id}] Skipped classification for {pdf_path} after {end_time - start_time:.2f}s: {ve}",
644
649
  exc_info=False,
645
650
  )
646
- # Close progress bar immediately on error to avoid hanging
647
- progress_bar.close()
651
+ progress_bar.update(1) # Still update progress bar
652
+ return pdf_path, ve # Return the specific ValueError
653
+ except Exception as e:
654
+ end_time = time.monotonic()
655
+ logger.error(
656
+ f"[{thread_id}] Failed classification process for PDF {pdf_path} after {end_time - start_time:.2f}s: {e}",
657
+ exc_info=True, # Log full traceback for unexpected errors
658
+ )
659
+ # Close progress bar immediately on critical error to avoid hanging
660
+ if not progress_bar.disable:
661
+ progress_bar.close()
648
662
  # Re-raise the exception to stop the entire collection processing
649
- raise
663
+ raise ClassificationError(f"Classification failed for {pdf_path}: {e}") from e
650
664
 
651
665
  # Use ThreadPoolExecutor for parallel processing if max_workers > 1
666
+ processed_count = 0
667
+ skipped_count = 0
652
668
  try:
653
669
  if max_workers is not None and max_workers > 1:
654
670
  logger.info(f"Classifying PDFs in parallel with {max_workers} workers.")
@@ -659,23 +675,39 @@ class PDFCollection(SearchableMixin, ApplyMixin, ExportMixin): # Add ExportMixi
659
675
  for pdf in self._pdfs:
660
676
  futures.append(executor.submit(_process_pdf_classification, pdf))
661
677
 
662
- # Wait for all futures to complete (progress updated by callback)
663
- # Exceptions are raised by future.result() if worker failed
678
+ # Wait for all futures to complete
679
+ # Progress updated within worker
664
680
  for future in concurrent.futures.as_completed(futures):
665
- future.result() # Raise exception if worker failed
681
+ processed_count += 1
682
+ pdf_path, error = (
683
+ future.result()
684
+ ) # Raise ClassificationError if worker failed critically
685
+ if isinstance(error, ValueError):
686
+ # Logged in worker, just count as skipped
687
+ skipped_count += 1
666
688
 
667
689
  else: # Sequential processing
668
690
  logger.info("Classifying PDFs sequentially.")
669
691
  for pdf in self._pdfs:
670
- _process_pdf_classification(pdf)
671
-
672
- logger.info("Finished classification across the collection.")
692
+ processed_count += 1
693
+ pdf_path, error = _process_pdf_classification(
694
+ pdf
695
+ ) # Raise ClassificationError if worker failed critically
696
+ if isinstance(error, ValueError):
697
+ skipped_count += 1
698
+
699
+ final_message = (
700
+ f"Finished classification across the collection. Processed: {processed_count}"
701
+ )
702
+ if skipped_count > 0:
703
+ final_message += f", Skipped (e.g., vision on multi-page): {skipped_count}"
704
+ logger.info(final_message + ".")
673
705
 
674
706
  finally:
675
- # Ensure progress bar is closed even if errors occurred elsewhere
707
+ # Ensure progress bar is closed properly
676
708
  if not progress_bar.disable and progress_bar.n < progress_bar.total:
677
- progress_bar.close()
678
- elif progress_bar.disable is False:
709
+ progress_bar.n = progress_bar.total # Ensure it reaches 100%
710
+ if not progress_bar.disable:
679
711
  progress_bar.close()
680
712
 
681
713
  return self