natural-pdf 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. natural_pdf/__init__.py +3 -0
  2. natural_pdf/analyzers/layout/base.py +1 -5
  3. natural_pdf/analyzers/layout/gemini.py +61 -51
  4. natural_pdf/analyzers/layout/layout_analyzer.py +40 -11
  5. natural_pdf/analyzers/layout/layout_manager.py +26 -84
  6. natural_pdf/analyzers/layout/layout_options.py +7 -0
  7. natural_pdf/analyzers/layout/pdfplumber_table_finder.py +142 -0
  8. natural_pdf/analyzers/layout/surya.py +46 -123
  9. natural_pdf/analyzers/layout/tatr.py +51 -4
  10. natural_pdf/analyzers/text_structure.py +3 -5
  11. natural_pdf/analyzers/utils.py +3 -3
  12. natural_pdf/classification/manager.py +422 -0
  13. natural_pdf/classification/mixin.py +163 -0
  14. natural_pdf/classification/results.py +80 -0
  15. natural_pdf/collections/mixins.py +111 -0
  16. natural_pdf/collections/pdf_collection.py +434 -15
  17. natural_pdf/core/element_manager.py +83 -0
  18. natural_pdf/core/highlighting_service.py +13 -22
  19. natural_pdf/core/page.py +578 -93
  20. natural_pdf/core/pdf.py +912 -460
  21. natural_pdf/elements/base.py +134 -40
  22. natural_pdf/elements/collections.py +712 -109
  23. natural_pdf/elements/region.py +722 -69
  24. natural_pdf/elements/text.py +4 -1
  25. natural_pdf/export/mixin.py +137 -0
  26. natural_pdf/exporters/base.py +3 -3
  27. natural_pdf/exporters/paddleocr.py +5 -4
  28. natural_pdf/extraction/manager.py +135 -0
  29. natural_pdf/extraction/mixin.py +279 -0
  30. natural_pdf/extraction/result.py +23 -0
  31. natural_pdf/ocr/__init__.py +5 -5
  32. natural_pdf/ocr/engine_doctr.py +346 -0
  33. natural_pdf/ocr/engine_easyocr.py +6 -3
  34. natural_pdf/ocr/ocr_factory.py +24 -4
  35. natural_pdf/ocr/ocr_manager.py +122 -26
  36. natural_pdf/ocr/ocr_options.py +94 -11
  37. natural_pdf/ocr/utils.py +19 -6
  38. natural_pdf/qa/document_qa.py +0 -4
  39. natural_pdf/search/__init__.py +20 -34
  40. natural_pdf/search/haystack_search_service.py +309 -265
  41. natural_pdf/search/haystack_utils.py +99 -75
  42. natural_pdf/search/search_service_protocol.py +11 -12
  43. natural_pdf/selectors/parser.py +431 -230
  44. natural_pdf/utils/debug.py +3 -3
  45. natural_pdf/utils/identifiers.py +1 -1
  46. natural_pdf/utils/locks.py +8 -0
  47. natural_pdf/utils/packaging.py +8 -6
  48. natural_pdf/utils/text_extraction.py +60 -1
  49. natural_pdf/utils/tqdm_utils.py +51 -0
  50. natural_pdf/utils/visualization.py +18 -0
  51. natural_pdf/widgets/viewer.py +4 -25
  52. {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.9.dist-info}/METADATA +17 -3
  53. natural_pdf-0.1.9.dist-info/RECORD +80 -0
  54. {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.9.dist-info}/WHEEL +1 -1
  55. {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.9.dist-info}/top_level.txt +0 -2
  56. docs/api/index.md +0 -386
  57. docs/assets/favicon.png +0 -3
  58. docs/assets/favicon.svg +0 -3
  59. docs/assets/javascripts/custom.js +0 -17
  60. docs/assets/logo.svg +0 -3
  61. docs/assets/sample-screen.png +0 -0
  62. docs/assets/social-preview.png +0 -17
  63. docs/assets/social-preview.svg +0 -17
  64. docs/assets/stylesheets/custom.css +0 -65
  65. docs/document-qa/index.ipynb +0 -435
  66. docs/document-qa/index.md +0 -79
  67. docs/element-selection/index.ipynb +0 -915
  68. docs/element-selection/index.md +0 -229
  69. docs/finetuning/index.md +0 -176
  70. docs/index.md +0 -170
  71. docs/installation/index.md +0 -69
  72. docs/interactive-widget/index.ipynb +0 -962
  73. docs/interactive-widget/index.md +0 -12
  74. docs/layout-analysis/index.ipynb +0 -818
  75. docs/layout-analysis/index.md +0 -185
  76. docs/ocr/index.md +0 -209
  77. docs/pdf-navigation/index.ipynb +0 -314
  78. docs/pdf-navigation/index.md +0 -97
  79. docs/regions/index.ipynb +0 -816
  80. docs/regions/index.md +0 -294
  81. docs/tables/index.ipynb +0 -658
  82. docs/tables/index.md +0 -144
  83. docs/text-analysis/index.ipynb +0 -370
  84. docs/text-analysis/index.md +0 -105
  85. docs/text-extraction/index.ipynb +0 -1478
  86. docs/text-extraction/index.md +0 -292
  87. docs/tutorials/01-loading-and-extraction.ipynb +0 -194
  88. docs/tutorials/01-loading-and-extraction.md +0 -95
  89. docs/tutorials/02-finding-elements.ipynb +0 -340
  90. docs/tutorials/02-finding-elements.md +0 -149
  91. docs/tutorials/03-extracting-blocks.ipynb +0 -147
  92. docs/tutorials/03-extracting-blocks.md +0 -48
  93. docs/tutorials/04-table-extraction.ipynb +0 -114
  94. docs/tutorials/04-table-extraction.md +0 -50
  95. docs/tutorials/05-excluding-content.ipynb +0 -270
  96. docs/tutorials/05-excluding-content.md +0 -109
  97. docs/tutorials/06-document-qa.ipynb +0 -332
  98. docs/tutorials/06-document-qa.md +0 -91
  99. docs/tutorials/07-layout-analysis.ipynb +0 -288
  100. docs/tutorials/07-layout-analysis.md +0 -66
  101. docs/tutorials/07-working-with-regions.ipynb +0 -413
  102. docs/tutorials/07-working-with-regions.md +0 -151
  103. docs/tutorials/08-spatial-navigation.ipynb +0 -508
  104. docs/tutorials/08-spatial-navigation.md +0 -190
  105. docs/tutorials/09-section-extraction.ipynb +0 -2434
  106. docs/tutorials/09-section-extraction.md +0 -256
  107. docs/tutorials/10-form-field-extraction.ipynb +0 -512
  108. docs/tutorials/10-form-field-extraction.md +0 -201
  109. docs/tutorials/11-enhanced-table-processing.ipynb +0 -54
  110. docs/tutorials/11-enhanced-table-processing.md +0 -9
  111. docs/tutorials/12-ocr-integration.ipynb +0 -604
  112. docs/tutorials/12-ocr-integration.md +0 -175
  113. docs/tutorials/13-semantic-search.ipynb +0 -1328
  114. docs/tutorials/13-semantic-search.md +0 -77
  115. docs/visual-debugging/index.ipynb +0 -2970
  116. docs/visual-debugging/index.md +0 -157
  117. docs/visual-debugging/region.png +0 -0
  118. natural_pdf/templates/finetune/fine_tune_paddleocr.md +0 -415
  119. natural_pdf/templates/spa/css/style.css +0 -334
  120. natural_pdf/templates/spa/index.html +0 -31
  121. natural_pdf/templates/spa/js/app.js +0 -472
  122. natural_pdf/templates/spa/words.txt +0 -235976
  123. natural_pdf/widgets/frontend/viewer.js +0 -88
  124. natural_pdf-0.1.7.dist-info/RECORD +0 -145
  125. notebooks/Examples.ipynb +0 -1293
  126. pdfs/.gitkeep +0 -0
  127. pdfs/01-practice.pdf +0 -543
  128. pdfs/0500000US42001.pdf +0 -0
  129. pdfs/0500000US42007.pdf +0 -0
  130. pdfs/2014 Statistics.pdf +0 -0
  131. pdfs/2019 Statistics.pdf +0 -0
  132. pdfs/Atlanta_Public_Schools_GA_sample.pdf +0 -0
  133. pdfs/needs-ocr.pdf +0 -0
  134. {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.9.dist-info}/licenses/LICENSE +0 -0
natural_pdf/classification/manager.py
@@ -0,0 +1,422 @@
+import logging
+import time
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
+
+# Use try-except for robustness if dependencies are missing
+try:
+    import torch
+    from transformers import (
+        AutoModelForSequenceClassification,
+        AutoModelForZeroShotImageClassification,
+        AutoTokenizer,
+        pipeline,
+    )
+
+    _CLASSIFICATION_AVAILABLE = True
+except ImportError:
+    _CLASSIFICATION_AVAILABLE = False
+    # Define dummy types for type hinting if imports fail
+    pipeline = object
+    AutoTokenizer = object
+    AutoModelForZeroShotImageClassification = object
+    AutoModelForSequenceClassification = object
+    torch = None
+
+from natural_pdf.utils.tqdm_utils import get_tqdm
+
+# Import result classes
+from .results import CategoryScore, ClassificationResult
+
+if TYPE_CHECKING:
+    from transformers import Pipeline
+
+
+logger = logging.getLogger(__name__)
+
+# Global cache for models/pipelines
+_PIPELINE_CACHE: Dict[str, "Pipeline"] = {}
+_TOKENIZER_CACHE: Dict[str, Any] = {}
+_MODEL_CACHE: Dict[str, Any] = {}
+
+
+class ClassificationError(Exception):
+    """Custom exception for classification errors."""
+
+    pass
+
+
+class ClassificationManager:
+    """Manages classification models and execution."""
+
+    DEFAULT_TEXT_MODEL = "facebook/bart-large-mnli"
+    DEFAULT_VISION_MODEL = "openai/clip-vit-base-patch16"
+
+    def __init__(
+        self,
+        model_mapping: Optional[Dict[str, str]] = None,
+        default_device: Optional[str] = None,
+    ):
+        """
+        Initialize the ClassificationManager.
+
+        Args:
+            model_mapping: Optional dictionary mapping aliases ('text', 'vision') to model IDs.
+            default_device: Default device ('cpu', 'cuda') if not specified in classify calls.
+        """
+        if not _CLASSIFICATION_AVAILABLE:
+            raise ImportError(
+                "Classification dependencies missing. "
+                'Install with: pip install "natural-pdf[classification]"'
+            )
+
+        self.pipelines: Dict[Tuple[str, str], "Pipeline"] = (
+            {}
+        )  # Cache: (model_id, device) -> pipeline
+
+        self.device = default_device
+        logger.info(f"ClassificationManager initialized on device: {self.device}")
+
+    def is_available(self) -> bool:
+        """Check if required dependencies are installed."""
+        return _CLASSIFICATION_AVAILABLE
+
+    def _get_pipeline(self, model_id: str, using: str) -> "Pipeline":
+        """Get or create a classification pipeline."""
+        cache_key = f"{model_id}_{using}_{self.device}"
+        if cache_key not in _PIPELINE_CACHE:
+            logger.info(
+                f"Loading {using} classification pipeline for model '{model_id}' on device '{self.device}'..."
+            )
+            start_time = time.time()
+            try:
+                task = (
+                    "zero-shot-classification"
+                    if using == "text"
+                    else "zero-shot-image-classification"
+                )
+                _PIPELINE_CACHE[cache_key] = pipeline(task, model=model_id, device=self.device)
+                end_time = time.time()
+                logger.info(
+                    f"Pipeline for '{model_id}' loaded in {end_time - start_time:.2f} seconds."
+                )
+            except Exception as e:
+                logger.error(
+                    f"Failed to load pipeline for model '{model_id}' (using: {using}): {e}",
+                    exc_info=True,
+                )
+                raise ClassificationError(
+                    f"Failed to load pipeline for model '{model_id}'. Ensure the model ID is correct and supports the {task} task."
+                ) from e
+        return _PIPELINE_CACHE[cache_key]
+
+    def infer_using(self, model_id: str, using: Optional[str] = None) -> str:
+        """Infers processing mode ('text' or 'vision') if not provided."""
+        if using in ["text", "vision"]:
+            return using
+
+        # Simple inference based on common model names
+        normalized_model_id = model_id.lower()
+        if (
+            "clip" in normalized_model_id
+            or "vit" in normalized_model_id
+            or "siglip" in normalized_model_id
+        ):
+            logger.debug(f"Inferred using='vision' for model '{model_id}'")
+            return "vision"
+        if (
+            "bart" in normalized_model_id
+            or "bert" in normalized_model_id
+            or "mnli" in normalized_model_id
+            or "xnli" in normalized_model_id
+            or "deberta" in normalized_model_id
+        ):
+            logger.debug(f"Inferred using='text' for model '{model_id}'")
+            return "text"
+
+        # Fallback or raise error? Let's try loading text first, then vision.
+        logger.warning(
+            f"Could not reliably infer mode for '{model_id}'. Trying text, then vision pipeline loading."
+        )
+        try:
+            self._get_pipeline(model_id, "text")
+            logger.info(f"Successfully loaded '{model_id}' as a text model.")
+            return "text"
+        except Exception:
+            logger.warning(f"Failed to load '{model_id}' as text model. Trying vision.")
+            try:
+                self._get_pipeline(model_id, "vision")
+                logger.info(f"Successfully loaded '{model_id}' as a vision model.")
+                return "vision"
+            except Exception as e_vision:
+                logger.error(
+                    f"Failed to load '{model_id}' as either text or vision model.", exc_info=True
+                )
+                raise ClassificationError(
+                    f"Cannot determine mode for model '{model_id}'. Please specify `using='text'` or `using='vision'`. Error: {e_vision}"
+                )
+
+    def classify_item(
+        self,
+        item_content: Union[str, Image.Image],
+        categories: List[str],
+        model_id: Optional[str] = None,
+        using: Optional[str] = None,
+        min_confidence: float = 0.0,
+        multi_label: bool = False,
+        **kwargs,
+    ) -> ClassificationResult:  # Return ClassificationResult
+        """Classifies a single item (text or image)."""
+
+        # Determine model and engine type
+        effective_using = using
+        if model_id is None:
+            # Try inferring based on content type
+            if isinstance(item_content, str):
+                effective_using = "text"
+                model_id = self.DEFAULT_TEXT_MODEL
+            elif isinstance(item_content, Image.Image):
+                effective_using = "vision"
+                model_id = self.DEFAULT_VISION_MODEL
+            else:
+                raise TypeError(f"Unsupported item_content type: {type(item_content)}")
+        else:
+            # Infer engine type if not given
+            effective_using = self.infer_using(model_id, using)
+            # Set default model if needed (though should usually be provided if engine known)
+            if model_id is None:
+                model_id = (
+                    self.DEFAULT_TEXT_MODEL
+                    if effective_using == "text"
+                    else self.DEFAULT_VISION_MODEL
+                )
+
+        if not categories:
+            raise ValueError("Categories list cannot be empty.")
+
+        pipeline_instance = self._get_pipeline(model_id, effective_using)
+        timestamp = datetime.now()
+        parameters = {  # Store parameters used for this run
+            "categories": categories,
+            "model_id": model_id,
+            "using": effective_using,
+            "min_confidence": min_confidence,
+            "multi_label": multi_label,
+            **kwargs,
+        }
+
+        logger.debug(
+            f"Classifying content (type: {type(item_content).__name__}) with model '{model_id}'"
+        )
+        try:
+            # Handle potential kwargs for specific pipelines if needed
+            # The zero-shot pipelines expect `candidate_labels`
+            result_raw = pipeline_instance(
+                item_content, candidate_labels=categories, multi_label=multi_label, **kwargs
+            )
+            logger.debug(f"Raw pipeline result: {result_raw}")
+
+            # --- Process raw result into ClassificationResult --- #
+            scores_list: List[CategoryScore] = []
+
+            # Handle text pipeline format (dict with 'labels' and 'scores')
+            if isinstance(result_raw, dict) and "labels" in result_raw and "scores" in result_raw:
+                for label, score_val in zip(result_raw["labels"], result_raw["scores"]):
+                    if score_val >= min_confidence:
+                        try:
+                            scores_list.append(CategoryScore(label=label, confidence=score_val))
+                        except (ValueError, TypeError) as score_err:
+                            logger.warning(
+                                f"Skipping invalid score from text pipeline: label='{label}', score={score_val}. Error: {score_err}"
+                            )
+            # Handle vision pipeline format (list of dicts with 'label' and 'score')
+            elif isinstance(result_raw, list) and all(
+                isinstance(item, dict) and "label" in item and "score" in item
+                for item in result_raw
+            ):
+                for item in result_raw:
+                    score_val = item["score"]
+                    label = item["label"]
+                    if score_val >= min_confidence:
+                        try:
+                            scores_list.append(CategoryScore(label=label, confidence=score_val))
+                        except (ValueError, TypeError) as score_err:
+                            logger.warning(
+                                f"Skipping invalid score from vision pipeline: label='{label}', score={score_val}. Error: {score_err}"
+                            )
+            else:
+                logger.warning(
+                    f"Unexpected raw result format from pipeline for model '{model_id}': {type(result_raw)}. Cannot extract scores."
+                )
+                # Return empty result?
+                # scores_list = []
+
+            return ClassificationResult(
+                model_id=model_id,
+                using=effective_using,
+                timestamp=timestamp,
+                parameters=parameters,
+                scores=scores_list,
+            )
+            # --- End Processing --- #
+
+        except Exception as e:
+            logger.error(f"Classification failed for model '{model_id}': {e}", exc_info=True)
+            # Return an empty result object on failure?
+            # return ClassificationResult(model_id=model_id, engine_type=engine_type, timestamp=timestamp, parameters=parameters, scores=[])
+            raise ClassificationError(
+                f"Classification failed using model '{model_id}'. Error: {e}"
+            ) from e
+
+    def classify_batch(
+        self,
+        item_contents: List[Union[str, Image.Image]],
+        categories: List[str],
+        model_id: Optional[str] = None,
+        using: Optional[str] = None,
+        min_confidence: float = 0.0,
+        multi_label: bool = False,
+        batch_size: int = 8,
+        progress_bar: bool = True,
+        **kwargs,
+    ) -> List[ClassificationResult]:  # Return list of ClassificationResult
+        """Classifies a batch of items (text or image) using the pipeline's batching."""
+        if not item_contents:
+            return []
+
+        # Determine model and engine type (assuming uniform type in batch)
+        first_item = item_contents[0]
+        effective_using = using
+        if model_id is None:
+            if isinstance(first_item, str):
+                effective_using = "text"
+                model_id = self.DEFAULT_TEXT_MODEL
+            elif isinstance(first_item, Image.Image):
+                effective_using = "vision"
+                model_id = self.DEFAULT_VISION_MODEL
+            else:
+                raise TypeError(f"Unsupported item_content type in batch: {type(first_item)}")
+        else:
+            effective_using = self.infer_using(model_id, using)
+            if model_id is None:
+                model_id = (
+                    self.DEFAULT_TEXT_MODEL
+                    if effective_using == "text"
+                    else self.DEFAULT_VISION_MODEL
+                )
+
+        if not categories:
+            raise ValueError("Categories list cannot be empty.")
+
+        pipeline_instance = self._get_pipeline(model_id, effective_using)
+        timestamp = datetime.now()  # Single timestamp for the batch run
+        parameters = {  # Parameters for the whole batch
+            "categories": categories,
+            "model_id": model_id,
+            "using": effective_using,
+            "min_confidence": min_confidence,
+            "multi_label": multi_label,
+            "batch_size": batch_size,
+            **kwargs,
+        }
+
+        logger.info(
+            f"Classifying batch of {len(item_contents)} items with model '{model_id}' (batch size: {batch_size})"
+        )
+        batch_results_list: List[ClassificationResult] = []
+
+        try:
+            # Use pipeline directly for batching
+            results_iterator = pipeline_instance(
+                item_contents,
+                candidate_labels=categories,
+                multi_label=multi_label,
+                batch_size=batch_size,
+                **kwargs,
+            )
+
+            # Wrap with tqdm for progress if requested
+            total_items = len(item_contents)
+            if progress_bar:
+                # Get the appropriate tqdm class
+                tqdm_class = get_tqdm()
+                results_iterator = tqdm_class(
+                    results_iterator,
+                    total=total_items,
+                    desc=f"Classifying batch ({model_id})",
+                    leave=False,  # Don't leave progress bar hanging
+                )
+
+            for raw_result in results_iterator:
+                # --- Process each raw result (which corresponds to ONE input item) --- #
+                scores_list: List[CategoryScore] = []
+                try:
+                    # Check for text format (dict with 'labels' and 'scores')
+                    if (
+                        isinstance(raw_result, dict)
+                        and "labels" in raw_result
+                        and "scores" in raw_result
+                    ):
+                        for label, score_val in zip(raw_result["labels"], raw_result["scores"]):
+                            if score_val >= min_confidence:
+                                try:
+                                    scores_list.append(
+                                        CategoryScore(label=label, confidence=score_val)
+                                    )
                                except (ValueError, TypeError) as score_err:
+                                    logger.warning(
+                                        f"Skipping invalid score from text pipeline batch: label='{label}', score={score_val}. Error: {score_err}"
+                                    )
+                    # Check for vision format (list of dicts with 'label' and 'score')
+                    elif isinstance(raw_result, list):
+                        for item in raw_result:
+                            try:
+                                score_val = item["score"]
+                                label = item["label"]
+                                if score_val >= min_confidence:
+                                    scores_list.append(
+                                        CategoryScore(label=label, confidence=score_val)
+                                    )
+                            except (KeyError, ValueError, TypeError) as item_err:
+                                logger.warning(
+                                    f"Skipping invalid item in vision result list from batch: {item}. Error: {item_err}"
+                                )
+                    else:
+                        logger.warning(
+                            f"Unexpected raw result format in batch item from model '{model_id}': {type(raw_result)}. Cannot extract scores."
+                        )
+
+                except Exception as proc_err:
+                    logger.error(
+                        f"Error processing result item in batch: {proc_err}", exc_info=True
+                    )
+                    # scores_list remains empty for this item
+
+                # Append result object for this item
+                batch_results_list.append(
+                    ClassificationResult(
+                        model_id=model_id,
+                        using=effective_using,
+                        timestamp=timestamp,  # Use same timestamp for batch
+                        parameters=parameters,  # Use same params for batch
+                        scores=scores_list,
+                    )
+                )
+                # --- End Processing --- #
+
+            if len(batch_results_list) != total_items:
+                logger.warning(
+                    f"Batch classification returned {len(batch_results_list)} results, but expected {total_items}. Results might be incomplete or misaligned."
+                )
+
+            return batch_results_list
+
+        except Exception as e:
+            logger.error(f"Batch classification failed for model '{model_id}': {e}", exc_info=True)
+            # Return list of empty results?
+            # return [ClassificationResult(model_id=model_id, s=engine_type, timestamp=timestamp, parameters=parameters, scores=[]) for _ in item_contents]
+            raise ClassificationError(
+                f"Batch classification failed using model '{model_id}'. Error: {e}"
+            ) from e
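For orientation, here is a minimal sketch of how the new ClassificationManager is intended to be called, based only on the signatures in the hunk above; the sample strings and category names are invented, and the snippet has not been run against this exact release.

from natural_pdf.classification.manager import ClassificationManager

manager = ClassificationManager()  # falls back to DEFAULT_TEXT_MODEL / DEFAULT_VISION_MODEL

# Single item: zero-shot text classification against caller-supplied categories
result = manager.classify_item(
    item_content="Invoice #1234, total due $56.78",   # invented example text
    categories=["invoice", "report", "letter"],
    using="text",                                      # or omit and let infer_using() decide from the model id
)

# Batch: relies on the HF pipeline's own batching, with an optional tqdm progress bar
results = manager.classify_batch(
    item_contents=["first document text", "second document text"],
    categories=["invoice", "report", "letter"],
    batch_size=8,
    progress_bar=True,
)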
natural_pdf/classification/mixin.py
@@ -0,0 +1,163 @@
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+# Assuming PIL is installed as it's needed for vision
+try:
+    from PIL import Image
+except ImportError:
+    Image = None  # type: ignore
+
+# Import result classes
+from .results import ClassificationResult  # Assuming results.py is in the same dir
+
+if TYPE_CHECKING:
+    # Avoid runtime import cycle
+    from natural_pdf.core.page import Page
+    from natural_pdf.elements.region import Region
+
+    from .manager import ClassificationManager
+
+logger = logging.getLogger(__name__)
+
+
+class ClassificationMixin:
+    """
+    Mixin class providing classification capabilities to Page and Region objects.
+    Relies on a ClassificationManager being accessible, typically via the parent PDF.
+    """
+
+    # --- Abstract methods/properties required by the host class --- #
+    # These must be implemented by classes using this mixin (Page, Region)
+
+    def _get_classification_manager(self) -> "ClassificationManager":
+        """Should return the ClassificationManager instance."""
+        raise NotImplementedError
+
+    def _get_classification_content(self, model_type: str, **kwargs) -> Union[str, "Image"]:
+        """Should return the text content (str) or image (PIL.Image) for classification."""
+        raise NotImplementedError
+
+    # Host class needs 'analyses' attribute initialized as Dict[str, Any]
+    # analyses: Dict[str, Any]
+
+    # --- End Abstract --- #
+
+    def classify(
+        self,
+        categories: List[str],
+        model: Optional[str] = None,  # Default handled by manager
+        using: Optional[str] = None,  # Renamed parameter
+        min_confidence: float = 0.0,
+        analysis_key: str = "classification",  # Default key
+        multi_label: bool = False,
+        **kwargs,
+    ) -> "ClassificationMixin":  # Return self for chaining
+        """
+        Classifies this item (Page or Region) using the configured manager.
+
+        Stores the result in self.analyses[analysis_key]. If analysis_key is not
+        provided, it defaults to 'classification' and overwrites any previous
+        result under that key.
+
+        Args:
+            categories: A list of string category names.
+            model: Model identifier (e.g., 'text', 'vision', HF ID). Defaults handled by manager.
+            using: Optional processing mode ('text' or 'vision'). If None, inferred by manager.
+            min_confidence: Minimum confidence threshold for results (0.0-1.0).
+            analysis_key: Key under which to store the result in `self.analyses`.
+                Defaults to 'classification'.
+            multi_label: Whether to allow multiple labels (passed to HF pipeline).
+            **kwargs: Additional arguments passed to the ClassificationManager.
+
+        Returns:
+            Self for method chaining.
+        """
+        # Ensure analyses dict exists
+        if not hasattr(self, "analyses") or self.analyses is None:
+            logger.warning("'analyses' attribute not found or is None. Initializing as empty dict.")
+            self.analyses = {}
+
+        try:
+            manager = self._get_classification_manager()
+
+            # Determine the effective model ID and engine type
+            effective_model_id = model
+            inferred_using = manager.infer_using(
+                model if model else manager.DEFAULT_TEXT_MODEL, using
+            )
+
+            # If model was not provided, use the manager's default for the inferred engine type
+            if effective_model_id is None:
+                effective_model_id = (
+                    manager.DEFAULT_TEXT_MODEL
+                    if inferred_using == "text"
+                    else manager.DEFAULT_VISION_MODEL
+                )
+                logger.debug(
+                    f"No model provided, using default for mode '{inferred_using}': '{effective_model_id}'"
+                )
+
+            # Get content based on the *final* determined engine type
+            content = self._get_classification_content(model_type=inferred_using, **kwargs)
+
+            # Manager now returns a ClassificationResult object
+            result_obj: ClassificationResult = manager.classify_item(
+                item_content=content,
+                categories=categories,
+                model_id=effective_model_id,  # Pass the resolved model ID
+                using=inferred_using,  # Pass renamed argument
+                min_confidence=min_confidence,
+                multi_label=multi_label,
+                **kwargs,
+            )
+
+            # Store the structured result object under the specified key
+            self.analyses[analysis_key] = result_obj
+            logger.debug(f"Stored classification result under key '{analysis_key}': {result_obj}")
+
+        except NotImplementedError as nie:
+            logger.error(f"Classification cannot proceed: {nie}")
+            raise
+        except Exception as e:
+            logger.error(f"Classification failed: {e}", exc_info=True)
+            # Optionally re-raise or just log and return self
+            # raise
+
+        return self
+
+    @property
+    def classification_results(self) -> Optional[ClassificationResult]:
+        """Returns the ClassificationResult from the *default* ('classification') key, or None."""
+        if not hasattr(self, "analyses") or self.analyses is None:
+            return None
+        # Return the result object directly from the default key
+        return self.analyses.get("classification")
+
+    @property
+    def category(self) -> Optional[str]:
+        """Returns the top category label from the *default* ('classification') key, or None."""
+        result_obj = self.classification_results  # Uses the property above
+        # Access the property on the result object
+        return result_obj.top_category if result_obj else None
+
+    @property
+    def category_confidence(self) -> Optional[float]:
+        """Returns the top category confidence from the *default* ('classification') key, or None."""
+        result_obj = self.classification_results  # Uses the property above
+        # Access the property on the result object
+        return result_obj.top_confidence if result_obj else None
+
+    # Maybe add a helper to get results by specific key?
+    def get_classification_result(
+        self, analysis_key: str = "classification"
+    ) -> Optional[ClassificationResult]:
+        """Gets a classification result object stored under a specific key."""
+        if not hasattr(self, "analyses") or self.analyses is None:
+            return None
+        result = self.analyses.get(analysis_key)
+        if result is not None and not isinstance(result, ClassificationResult):
+            logger.warning(
+                f"Item found under key '{analysis_key}' is not a ClassificationResult (type: {type(result)}). Returning None."
+            )
+            return None
+        return result
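As a rough, self-contained illustration of the contract the mixin expects (natural-pdf wires this up on Page and Region itself), a host class only needs to supply the two hooks and an `analyses` dict. Everything below other than the two imports is hypothetical and has not been run against this release.

from natural_pdf.classification.manager import ClassificationManager
from natural_pdf.classification.mixin import ClassificationMixin

class MyRegion(ClassificationMixin):          # hypothetical host class
    def __init__(self, text: str, manager: ClassificationManager):
        self.analyses = {}                    # storage the mixin writes results into
        self._text = text
        self._manager = manager

    def _get_classification_manager(self) -> ClassificationManager:
        return self._manager

    def _get_classification_content(self, model_type: str, **kwargs):
        return self._text                     # a real Page/Region returns text or a PIL image

region = MyRegion("Payment due within 30 days.", ClassificationManager())
region.classify(["invoice", "contract", "memo"], using="text")
print(region.category, region.category_confidence)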
natural_pdf/classification/results.py
@@ -0,0 +1,80 @@
+# natural_pdf/classification/results.py
+import logging
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class CategoryScore:
+    """Represents a category and its confidence score from classification."""
+
+    category: str
+    score: float
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {"category": self.category, "score": self.score}
+
+
+@dataclass
+class ClassificationResult:
+    """Results from a classification operation."""
+
+    category: str
+    score: float
+    scores: List[CategoryScore]
+    model_id: str
+    timestamp: datetime
+    using: str  # 'text' or 'vision'
+    parameters: Optional[Dict[str, Any]] = None
+
+    def __init__(
+        self,
+        category: str,
+        score: float,
+        scores: List[CategoryScore],
+        model_id: str,
+        using: str,
+        parameters: Optional[Dict[str, Any]] = None,
+        timestamp: Optional[datetime] = None,
+    ):
+        self.category = category
+        self.score = score
+        self.scores = scores
+        self.model_id = model_id
+        self.using = using
+        self.parameters = parameters or {}
+        self.timestamp = timestamp or datetime.now()
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the classification result to a dictionary for serialization.
+
+        Returns:
+            Dictionary representation of the classification result
+        """
+        return {
+            "category": self.category,
+            "score": self.score,
+            "scores": [s.to_dict() for s in self.scores],
+            "model_id": self.model_id,
+            "using": self.using,
+            "parameters": self.parameters,
+            "timestamp": self.timestamp.isoformat(),
+        }
+
+    @property
+    def top_category(self) -> str:
+        """Returns the category with the highest score."""
+        return self.category
+
+    @property
+    def top_confidence(self) -> float:
+        """Returns the confidence score of the top category."""
+        return self.score
+
+    def __repr__(self) -> str:
+        return f"<ClassificationResult category='{self.category}' score={self.score:.3f} model='{self.model_id}'>"