natural-pdf 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. docs/categorizing-documents/index.md +168 -0
  2. docs/data-extraction/index.md +87 -0
  3. docs/element-selection/index.ipynb +218 -164
  4. docs/element-selection/index.md +20 -0
  5. docs/index.md +19 -0
  6. docs/ocr/index.md +63 -16
  7. docs/tutorials/01-loading-and-extraction.ipynb +1713 -34
  8. docs/tutorials/02-finding-elements.ipynb +123 -46
  9. docs/tutorials/03-extracting-blocks.ipynb +24 -19
  10. docs/tutorials/04-table-extraction.ipynb +17 -12
  11. docs/tutorials/05-excluding-content.ipynb +37 -32
  12. docs/tutorials/06-document-qa.ipynb +36 -31
  13. docs/tutorials/07-layout-analysis.ipynb +45 -40
  14. docs/tutorials/07-working-with-regions.ipynb +61 -60
  15. docs/tutorials/08-spatial-navigation.ipynb +76 -71
  16. docs/tutorials/09-section-extraction.ipynb +160 -155
  17. docs/tutorials/10-form-field-extraction.ipynb +71 -66
  18. docs/tutorials/11-enhanced-table-processing.ipynb +11 -6
  19. docs/tutorials/12-ocr-integration.ipynb +3420 -312
  20. docs/tutorials/12-ocr-integration.md +68 -106
  21. docs/tutorials/13-semantic-search.ipynb +641 -251
  22. natural_pdf/__init__.py +2 -0
  23. natural_pdf/classification/manager.py +343 -0
  24. natural_pdf/classification/mixin.py +149 -0
  25. natural_pdf/classification/results.py +62 -0
  26. natural_pdf/collections/mixins.py +63 -0
  27. natural_pdf/collections/pdf_collection.py +321 -15
  28. natural_pdf/core/element_manager.py +67 -0
  29. natural_pdf/core/page.py +227 -64
  30. natural_pdf/core/pdf.py +387 -378
  31. natural_pdf/elements/collections.py +272 -41
  32. natural_pdf/elements/region.py +99 -15
  33. natural_pdf/elements/text.py +5 -2
  34. natural_pdf/exporters/paddleocr.py +1 -1
  35. natural_pdf/extraction/manager.py +134 -0
  36. natural_pdf/extraction/mixin.py +246 -0
  37. natural_pdf/extraction/result.py +37 -0
  38. natural_pdf/ocr/engine_easyocr.py +6 -3
  39. natural_pdf/ocr/ocr_manager.py +85 -25
  40. natural_pdf/ocr/ocr_options.py +33 -10
  41. natural_pdf/ocr/utils.py +14 -3
  42. natural_pdf/qa/document_qa.py +0 -4
  43. natural_pdf/selectors/parser.py +363 -238
  44. natural_pdf/templates/finetune/fine_tune_paddleocr.md +10 -5
  45. natural_pdf/utils/locks.py +8 -0
  46. natural_pdf/utils/text_extraction.py +52 -1
  47. natural_pdf/utils/tqdm_utils.py +43 -0
  48. {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.8.dist-info}/METADATA +6 -1
  49. {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.8.dist-info}/RECORD +52 -41
  50. {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.8.dist-info}/WHEEL +1 -1
  51. {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.8.dist-info}/licenses/LICENSE +0 -0
  52. {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.8.dist-info}/top_level.txt +0 -0
natural_pdf/__init__.py CHANGED
@@ -3,6 +3,8 @@ Natural PDF - A more intuitive interface for working with PDFs.
3
3
  """
4
4
 
5
5
  import logging
6
+ import os
7
# Default Hugging Face `tokenizers` parallelism to off (presumably to avoid its
# fork-after-parallelism warnings), but respect a value the user already chose.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Create library logger
logger = logging.getLogger("natural_pdf")
natural_pdf/classification/manager.py ADDED
@@ -0,0 +1,343 @@
1
+ import logging
2
+ import time
3
+ from datetime import datetime
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Tuple
5
+
6
+ # Use try-except for robustness if dependencies are missing
7
+ try:
8
+ import torch
9
+ from PIL import Image
10
+ from transformers import pipeline, AutoTokenizer, AutoModelForZeroShotImageClassification, AutoModelForSequenceClassification
11
+ _CLASSIFICATION_AVAILABLE = True
12
+ except ImportError:
13
+ _CLASSIFICATION_AVAILABLE = False
14
+ # Define dummy types for type hinting if imports fail
15
+ Image = type("Image", (), {})
16
+ pipeline = object
17
+ AutoTokenizer = object
18
+ AutoModelForZeroShotImageClassification = object
19
+ AutoModelForSequenceClassification = object
20
+ torch = None
21
+
22
+ # Import result classes
23
+ from .results import ClassificationResult, CategoryScore
24
+ from natural_pdf.utils.tqdm_utils import get_tqdm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers import Pipeline
28
+
29
+
30
logger = logging.getLogger(__name__)

# Process-wide caches shared by every ClassificationManager instance.
# Within this module only _PIPELINE_CACHE is written to; the tokenizer and
# model caches are declared here but not populated by this file.
_PIPELINE_CACHE: Dict[str, "Pipeline"] = dict()
_TOKENIZER_CACHE: Dict[str, Any] = dict()
_MODEL_CACHE: Dict[str, Any] = dict()


class ClassificationError(Exception):
    """Raised when a model cannot be loaded, its mode cannot be determined, or inference fails."""
40
+
41
+
42
class ClassificationManager:
    """Manage zero-shot classification pipelines for text and images.

    Pipelines are created lazily with ``transformers.pipeline`` and stored in the
    module-level ``_PIPELINE_CACHE``, keyed by model ID, mode and device. All
    manager instances therefore reuse pipelines that were already loaded.

    Wherever a model is accepted, an alias may be given instead of a model ID.
    ``'text'`` and ``'vision'`` map to the default models, and the
    ``model_mapping`` constructor argument can override them or add new aliases.
    """

    DEFAULT_TEXT_MODEL = "facebook/bart-large-mnli"
    DEFAULT_VISION_MODEL = "openai/clip-vit-base-patch16"

    # Substrings used to guess a model's processing mode from its ID.
    _VISION_MODEL_HINTS = ("clip", "vit", "siglip")
    _TEXT_MODEL_HINTS = ("bart", "bert", "mnli", "xnli", "deberta")

    def __init__(
        self,
        model_mapping: Optional[Dict[str, str]] = None,
        default_device: Optional[str] = None,
    ):
        """
        Initialize the ClassificationManager.

        Args:
            model_mapping: Optional dictionary mapping aliases (e.g. 'text', 'vision')
                to model IDs. Entries override the built-in 'text'/'vision' defaults.
            default_device: Device passed unchanged to ``transformers.pipeline``
                (e.g. 'cpu', 'cuda').

        Raises:
            ImportError: If the classification dependencies could not be imported.
        """
        if not _CLASSIFICATION_AVAILABLE:
            raise ImportError(
                "Classification dependencies missing. "
                "Install with: pip install \"natural-pdf[classification]\""
            )

        # Retained for backward compatibility; pipelines actually live in _PIPELINE_CACHE.
        self.pipelines: Dict[Tuple[str, str], "Pipeline"] = {}

        # Alias -> model ID. The built-in aliases come first so user entries override them.
        self.model_mapping: Dict[str, str] = {
            "text": self.DEFAULT_TEXT_MODEL,
            "vision": self.DEFAULT_VISION_MODEL,
        }
        if model_mapping:
            self.model_mapping.update(model_mapping)

        self.device = default_device
        logger.info(f"ClassificationManager initialized on device: {self.device}")

    def is_available(self) -> bool:
        """Check if required dependencies are installed."""
        return _CLASSIFICATION_AVAILABLE

    # --- Model / alias resolution helpers --- #

    def _resolve_model_id(self, model_id: str) -> str:
        """Map an alias to its model ID; names that are not aliases are returned unchanged."""
        # getattr keeps instances created without __init__ (e.g. via __new__) working.
        mapping = getattr(self, "model_mapping", None) or {}
        return mapping.get(model_id, model_id)

    def _default_model_for(self, using: str) -> str:
        """Return the default model ID for a mode ('text', otherwise vision)."""
        mapping = getattr(self, "model_mapping", None) or {}
        if using == "text":
            return mapping.get("text", self.DEFAULT_TEXT_MODEL)
        return mapping.get("vision", self.DEFAULT_VISION_MODEL)

    def _resolve_model_and_mode(
        self,
        sample: Any,
        model_id: Optional[str],
        using: Optional[str],
        unsupported_msg: str,
    ) -> Tuple[str, str]:
        """Choose the concrete model ID and mode for ``sample``.

        Without a ``model_id``, the mode follows the content type (str -> text,
        PIL image -> vision) and ``using`` is ignored. Otherwise the mode comes
        from ``infer_using`` and aliases are resolved to model IDs.

        Raises:
            TypeError: If no model is given and ``sample`` is neither str nor a PIL image.
        """
        if model_id is None:
            if isinstance(sample, str):
                mode = "text"
            elif isinstance(sample, Image.Image):
                mode = "vision"
            else:
                raise TypeError(f"{unsupported_msg}: {type(sample)}")
            return self._default_model_for(mode), mode

        mode = self.infer_using(model_id, using)
        return self._resolve_model_id(model_id), mode

    def _get_pipeline(self, model_id: str, using: str) -> "Pipeline":
        """Return a cached zero-shot pipeline, creating it on first use.

        Args:
            model_id: Model ID or alias.
            using: 'text' selects zero-shot-classification; any other value selects
                zero-shot-image-classification.

        Raises:
            ClassificationError: If transformers fails to build the pipeline.
        """
        model_id = self._resolve_model_id(model_id)
        cache_key = f"{model_id}_{using}_{self.device}"
        if cache_key in _PIPELINE_CACHE:
            return _PIPELINE_CACHE[cache_key]

        task = (
            "zero-shot-classification"
            if using == "text"
            else "zero-shot-image-classification"
        )
        logger.info(f"Loading {using} classification pipeline for model '{model_id}' on device '{self.device}'...")
        start_time = time.time()
        try:
            loaded = pipeline(task, model=model_id, device=self.device)
        except Exception as e:
            logger.error(f"Failed to load pipeline for model '{model_id}' (using: {using}): {e}", exc_info=True)
            raise ClassificationError(f"Failed to load pipeline for model '{model_id}'. Ensure the model ID is correct and supports the {task} task.") from e

        _PIPELINE_CACHE[cache_key] = loaded
        logger.info(f"Pipeline for '{model_id}' loaded in {time.time() - start_time:.2f} seconds.")
        return loaded

    def infer_using(self, model_id: str, using: Optional[str] = None) -> str:
        """Infer the processing mode ('text' or 'vision') for a model.

        Args:
            model_id: Model ID or alias.
            using: Explicit mode; returned as-is when it is 'text' or 'vision'.

        Returns:
            'text' or 'vision'.

        Raises:
            ClassificationError: If the model name gives no hint and the model loads
                neither as a text pipeline nor as a vision pipeline.
        """
        if using in ("text", "vision"):
            return using
        # The bare aliases name their mode directly.
        if model_id in ("text", "vision"):
            return model_id

        resolved_id = self._resolve_model_id(model_id)
        normalized_model_id = resolved_id.lower()
        if any(hint in normalized_model_id for hint in self._VISION_MODEL_HINTS):
            logger.debug(f"Inferred using='vision' for model '{model_id}'")
            return "vision"
        if any(hint in normalized_model_id for hint in self._TEXT_MODEL_HINTS):
            logger.debug(f"Inferred using='text' for model '{model_id}'")
            return "text"

        # The name gave no hint: try loading as a text model first, then as vision.
        logger.warning(f"Could not reliably infer mode for '{model_id}'. Trying text, then vision pipeline loading.")
        try:
            self._get_pipeline(resolved_id, "text")
            logger.info(f"Successfully loaded '{model_id}' as a text model.")
            return "text"
        except Exception:
            logger.warning(f"Failed to load '{model_id}' as text model. Trying vision.")
        try:
            self._get_pipeline(resolved_id, "vision")
            logger.info(f"Successfully loaded '{model_id}' as a vision model.")
            return "vision"
        except Exception as e_vision:
            logger.error(f"Failed to load '{model_id}' as either text or vision model.", exc_info=True)
            raise ClassificationError(f"Cannot determine mode for model '{model_id}'. Please specify `using='text'` or `using='vision'`. Error: {e_vision}") from e_vision

    # --- Result parsing helpers --- #

    @staticmethod
    def _append_score(
        scores: List["CategoryScore"],
        label: Any,
        score_val: Any,
        min_confidence: float,
        source: str,
    ) -> None:
        """Append a CategoryScore if ``score_val`` meets the threshold; log and skip invalid values."""
        try:
            # The comparison is inside the try so non-numeric scores are skipped, not fatal.
            if score_val >= min_confidence:
                scores.append(CategoryScore(label=label, confidence=score_val))
        except (ValueError, TypeError) as score_err:
            logger.warning(f"Skipping invalid score from {source}: label='{label}', score={score_val}. Error: {score_err}")

    def _parse_scores(self, raw_result: Any, min_confidence: float, model_id: str) -> List["CategoryScore"]:
        """Convert one raw pipeline output into CategoryScores at or above ``min_confidence``.

        Two formats are handled. The text format is a dict with parallel 'labels'
        and 'scores' lists. The vision format is a list of {'label', 'score'}
        dicts. Malformed entries are logged and skipped; an unrecognized overall
        format yields an empty list.
        """
        scores: List[CategoryScore] = []
        if isinstance(raw_result, dict) and 'labels' in raw_result and 'scores' in raw_result:
            for label, score_val in zip(raw_result['labels'], raw_result['scores']):
                self._append_score(scores, label, score_val, min_confidence, "text pipeline")
        elif isinstance(raw_result, list):
            for entry in raw_result:
                try:
                    label, score_val = entry['label'], entry['score']
                except (KeyError, TypeError) as entry_err:
                    logger.warning(f"Skipping invalid item in vision result list: {entry}. Error: {entry_err}")
                    continue
                self._append_score(scores, label, score_val, min_confidence, "vision pipeline")
        else:
            logger.warning(f"Unexpected raw result format from pipeline for model '{model_id}': {type(raw_result)}. Cannot extract scores.")
        return scores

    # --- Public classification API --- #

    def classify_item(
        self,
        item_content: Union[str, "Image.Image"],
        categories: List[str],
        model_id: Optional[str] = None,
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        multi_label: bool = False,
        **kwargs
    ) -> "ClassificationResult":
        """Classify a single item (text or image) against ``categories``.

        Args:
            item_content: Text (str) or PIL image to classify.
            categories: Candidate labels; must not be empty.
            model_id: Model ID or alias. If None, the default model for the content
                type is used (str -> text, PIL image -> vision) and ``using`` is ignored.
            using: Optional mode ('text' or 'vision'); inferred from ``model_id`` if omitted.
            min_confidence: Scores below this value are dropped.
            multi_label: Forwarded to the pipeline call.
            **kwargs: Extra keyword arguments forwarded to the pipeline call.

        Returns:
            A ClassificationResult for the item.

        Raises:
            TypeError: If ``model_id`` is None and the content type is unsupported.
            ValueError: If ``categories`` is empty.
            ClassificationError: If the model cannot be loaded or inference fails.
        """
        model_id, effective_using = self._resolve_model_and_mode(
            item_content, model_id, using, "Unsupported item_content type"
        )

        if not categories:
            raise ValueError("Categories list cannot be empty.")

        pipeline_instance = self._get_pipeline(model_id, effective_using)
        timestamp = datetime.now()
        parameters = {  # Store parameters used for this run
            'categories': categories,
            'model_id': model_id,
            'using': effective_using,
            'min_confidence': min_confidence,
            'multi_label': multi_label,
            **kwargs
        }

        logger.debug(f"Classifying content (type: {type(item_content).__name__}) with model '{model_id}'")
        try:
            # The zero-shot pipelines expect `candidate_labels`.
            result_raw = pipeline_instance(item_content, candidate_labels=categories, multi_label=multi_label, **kwargs)
            logger.debug(f"Raw pipeline result: {result_raw}")
            return ClassificationResult(
                model_id=model_id,
                using=effective_using,
                timestamp=timestamp,
                parameters=parameters,
                scores=self._parse_scores(result_raw, min_confidence, model_id),
            )
        except Exception as e:
            logger.error(f"Classification failed for model '{model_id}': {e}", exc_info=True)
            raise ClassificationError(f"Classification failed using model '{model_id}'. Error: {e}") from e

    def classify_batch(
        self,
        item_contents: List[Union[str, "Image.Image"]],
        categories: List[str],
        model_id: Optional[str] = None,
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        multi_label: bool = False,
        batch_size: int = 8,
        progress_bar: bool = True,
        **kwargs
    ) -> List["ClassificationResult"]:
        """Classify a batch of items (text or image) using the pipeline's batching.

        The model and mode are chosen from the first item, so the batch is
        assumed to be homogeneous. All results share one timestamp and one
        parameters dict.

        Args:
            item_contents: Items to classify; an empty list returns [].
            categories: Candidate labels; must not be empty.
            model_id: Model ID or alias (see ``classify_item``).
            using: Optional mode ('text' or 'vision').
            min_confidence: Scores below this value are dropped.
            multi_label: Forwarded to the pipeline call.
            batch_size: Forwarded to the pipeline call.
            progress_bar: Show a tqdm progress bar while results are produced.
            **kwargs: Extra keyword arguments forwarded to the pipeline call.

        Returns:
            One ClassificationResult per input item, in input order. An item whose
            output could not be parsed gets a result with no scores.

        Raises:
            TypeError: If ``model_id`` is None and the first item's type is unsupported.
            ValueError: If ``categories`` is empty.
            ClassificationError: If the model cannot be loaded or inference fails.
        """
        if not item_contents:
            return []

        model_id, effective_using = self._resolve_model_and_mode(
            item_contents[0], model_id, using, "Unsupported item_content type in batch"
        )

        if not categories:
            raise ValueError("Categories list cannot be empty.")

        pipeline_instance = self._get_pipeline(model_id, effective_using)
        timestamp = datetime.now()  # Single timestamp for the batch run
        parameters = {  # Parameters for the whole batch
            'categories': categories,
            'model_id': model_id,
            'using': effective_using,
            'min_confidence': min_confidence,
            'multi_label': multi_label,
            'batch_size': batch_size,
            **kwargs
        }

        total_items = len(item_contents)
        logger.info(f"Classifying batch of {total_items} items with model '{model_id}' (batch size: {batch_size})")
        batch_results_list: List[ClassificationResult] = []

        try:
            # Feed a generator: transformers then returns a lazy iterator that yields
            # one output per input as it is computed. With a plain list, everything
            # is computed before returning, so a progress bar would only tick afterwards.
            results_iterator = pipeline_instance(
                (content for content in item_contents),
                candidate_labels=categories,
                multi_label=multi_label,
                batch_size=batch_size,
                **kwargs
            )

            if progress_bar:
                tqdm_class = get_tqdm()
                results_iterator = tqdm_class(
                    results_iterator,
                    total=total_items,
                    desc=f"Classifying batch ({model_id})",
                    leave=False  # Don't leave progress bar hanging
                )

            for raw_result in results_iterator:
                try:
                    scores_list = self._parse_scores(raw_result, min_confidence, model_id)
                except Exception as proc_err:
                    logger.error(f"Error processing result item in batch: {proc_err}", exc_info=True)
                    scores_list = []  # Keep alignment with the inputs

                batch_results_list.append(ClassificationResult(
                    model_id=model_id,
                    using=effective_using,
                    timestamp=timestamp,
                    parameters=parameters,
                    scores=scores_list
                ))

            if len(batch_results_list) != total_items:
                logger.warning(f"Batch classification returned {len(batch_results_list)} results, but expected {total_items}. Results might be incomplete or misaligned.")

            return batch_results_list

        except Exception as e:
            logger.error(f"Batch classification failed for model '{model_id}': {e}", exc_info=True)
            raise ClassificationError(f"Batch classification failed using model '{model_id}'. Error: {e}") from e
natural_pdf/classification/mixin.py ADDED
@@ -0,0 +1,149 @@
1
+ import logging
2
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ # Assuming PIL is installed as it's needed for vision
5
+ try:
6
+ from PIL import Image
7
+ except ImportError:
8
+ Image = None # type: ignore
9
+
10
+ # Import result classes
11
+ from .results import ClassificationResult # Assuming results.py is in the same dir
12
+
13
+ if TYPE_CHECKING:
14
+ # Avoid runtime import cycle
15
+ from natural_pdf.core.page import Page
16
+ from natural_pdf.elements.region import Region
17
+ from .manager import ClassificationManager
18
+
19
logger = logging.getLogger(__name__)


class ClassificationMixin:
    """
    Mixin class providing classification capabilities to Page and Region objects.

    Host classes must implement ``_get_classification_manager`` and
    ``_get_classification_content``. They should also provide an ``analyses``
    dict, where results are stored as ClassificationResult objects.
    """

    # --- Abstract methods required by the host class (Page, Region) --- #

    def _get_classification_manager(self) -> "ClassificationManager":
        """Return the ClassificationManager used to run classification."""
        raise NotImplementedError

    def _get_classification_content(self, model_type: str, **kwargs) -> Union[str, "Image"]:
        """Return the text (str) or image (PIL.Image) to classify for ``model_type``."""
        raise NotImplementedError

    # Host class needs an 'analyses' attribute initialized as Dict[str, Any].

    # --- End Abstract --- #

    def classify(
        self,
        categories: List[str],
        model: Optional[str] = None,  # Default handled by manager
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        analysis_key: str = 'classification',
        multi_label: bool = False,
        **kwargs
    ) -> "ClassificationMixin":
        """
        Classify this item (Page or Region) using the configured manager.

        The result is stored in ``self.analyses[analysis_key]``, overwriting any
        previous result under that key.

        Args:
            categories: A list of string category names.
            model: Model ID (or an alias understood by the manager). If None, the
                manager's default model for the inferred mode is used.
            using: Optional processing mode ('text' or 'vision'). If None, inferred by the manager.
            min_confidence: Minimum confidence threshold for results (0.0-1.0).
            analysis_key: Key under which to store the result in ``self.analyses``.
            multi_label: Whether to allow multiple labels (passed to the pipeline).
            **kwargs: Passed to both ``_get_classification_content`` and the manager.

        Returns:
            Self for method chaining.

        Raises:
            NotImplementedError: If the host class does not implement the abstract hooks.
                Other errors are logged and swallowed; ``self`` is still returned.
        """
        if not hasattr(self, 'analyses') or self.analyses is None:
            logger.warning("'analyses' attribute not found or is None. Initializing as empty dict.")
            self.analyses = {}

        try:
            manager = self._get_classification_manager()

            # Without an explicit model, infer the mode from the default text model
            # (or from `using` when it is given).
            effective_model_id = model
            inferred_using = manager.infer_using(model if model else manager.DEFAULT_TEXT_MODEL, using)

            if effective_model_id is None:
                effective_model_id = manager.DEFAULT_TEXT_MODEL if inferred_using == 'text' else manager.DEFAULT_VISION_MODEL
                logger.debug(f"No model provided, using default for mode '{inferred_using}': '{effective_model_id}'")

            # Fetch content matching the final mode.
            content = self._get_classification_content(model_type=inferred_using, **kwargs)

            result_obj: ClassificationResult = manager.classify_item(
                item_content=content,
                categories=categories,
                model_id=effective_model_id,
                using=inferred_using,
                min_confidence=min_confidence,
                multi_label=multi_label,
                **kwargs
            )

            self.analyses[analysis_key] = result_obj
            logger.debug(f"Stored classification result under key '{analysis_key}': {result_obj}")

        except NotImplementedError as nie:
            logger.error(f"Classification cannot proceed: {nie}")
            raise
        except Exception as e:
            # Deliberate best-effort: log the failure and keep the chain going.
            logger.error(f"Classification failed: {e}", exc_info=True)

        return self

    @property
    def classification_results(self) -> Optional[ClassificationResult]:
        """Return the ClassificationResult stored under the default 'classification' key, or None."""
        # Reuse the type-checked getter so a foreign object stored under the
        # default key yields None instead of breaking `category`.
        return self.get_classification_result('classification')

    @property
    def category(self) -> Optional[str]:
        """Return the top category label from the default ('classification') result, or None."""
        result_obj = self.classification_results
        return result_obj.top_category if result_obj else None

    @property
    def category_confidence(self) -> Optional[float]:
        """Return the top category confidence from the default ('classification') result, or None."""
        result_obj = self.classification_results
        return result_obj.top_confidence if result_obj else None

    def get_classification_result(self, analysis_key: str = 'classification') -> Optional[ClassificationResult]:
        """Return the ClassificationResult stored under ``analysis_key``.

        Returns None if there is no ``analyses`` dict, the key is absent, or the
        stored value is not a ClassificationResult (a warning is logged in that case).
        """
        if not hasattr(self, 'analyses') or self.analyses is None:
            return None
        result = self.analyses.get(analysis_key)
        if result is not None and not isinstance(result, ClassificationResult):
            logger.warning(f"Item found under key '{analysis_key}' is not a ClassificationResult (type: {type(result)}). Returning None.")
            return None
        return result
natural_pdf/classification/results.py ADDED
@@ -0,0 +1,62 @@
1
# natural_pdf/classification/results.py
import logging
import math
import numbers
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional

logger = logging.getLogger(__name__)


class CategoryScore:
    """Represents the score for a single category."""

    label: str
    confidence: float  # Normalized to [0.0, 1.0] by __init__

    def __init__(self, label: str, confidence: float):
        """
        Args:
            label: Category name. Empty or non-str labels are logged and coerced with str().
            confidence: Numeric score. Out-of-range values are clamped to [0, 1] and
                NaN becomes 0.0, each with a warning.

        Raises:
            TypeError, ValueError: If ``confidence`` cannot be converted to float.
        """
        if not isinstance(label, str) or not label:
            logger.warning(f"Initializing CategoryScore with invalid label: {label}")

        value = float(confidence)
        if math.isnan(value):
            # NaN compares false against everything, so min/max clamping would
            # silently turn it into 1.0 and make it the top category.
            logger.warning(f"Initializing CategoryScore with NaN confidence for label '{label}'. Using 0.0.")
            value = 0.0
        elif not isinstance(confidence, numbers.Real) or not (0.0 <= value <= 1.0):
            # numbers.Real also accepts numpy floating scalars, which are valid scores.
            logger.warning(f"Initializing CategoryScore with invalid confidence: {confidence} for label '{label}'. Clamping to [0, 1].")
            value = max(0.0, min(1.0, value))

        self.label = str(label)
        self.confidence = value

    def __repr__(self):
        return f"<CategoryScore label='{self.label}' confidence={self.confidence:.3f}>"


class ClassificationResult:
    """Holds the structured results of a classification task."""

    model_id: str
    using: str  # 'text' or 'vision'
    timestamp: datetime
    parameters: Dict[str, Any]  # e.g., {'categories': [...], 'min_confidence': 0.1}
    scores: List[CategoryScore]  # Sorted by descending confidence

    def __init__(
        self,
        model_id: str,
        using: str,
        timestamp: datetime,
        parameters: Dict[str, Any],
        scores: Iterable[CategoryScore],
    ):
        """
        Args:
            model_id: Identifier of the model that produced the scores.
            using: Processing mode used ('text' or 'vision').
            timestamp: When the classification ran.
            parameters: Parameters of the run; None is stored as {}.
            scores: List or tuple of CategoryScore objects, in any order.

        Raises:
            TypeError: If ``scores`` is not a list/tuple of CategoryScore objects.
        """
        if not isinstance(scores, (list, tuple)) or not all(isinstance(s, CategoryScore) for s in scores):
            raise TypeError("Scores must be a list of CategoryScore objects.")

        self.model_id = str(model_id)
        self.using = str(using)
        self.timestamp = timestamp
        self.parameters = parameters if parameters is not None else {}
        # Highest confidence first; sorted() is stable for ties.
        self.scores = sorted(scores, key=lambda s: s.confidence, reverse=True)

    @property
    def top_category(self) -> Optional[str]:
        """Returns the label of the category with the highest confidence."""
        return self.scores[0].label if self.scores else None

    @property
    def top_confidence(self) -> Optional[float]:
        """Returns the confidence score of the top category."""
        return self.scores[0].confidence if self.scores else None

    def __repr__(self):
        top_cat = f" top='{self.top_category}' ({self.top_confidence:.2f})" if self.scores else ""
        num_scores = len(self.scores)
        return f"<ClassificationResult model='{self.model_id}' using='{self.using}' scores={num_scores}{top_cat}>"
natural_pdf/collections/mixins.py ADDED
@@ -0,0 +1,63 @@
1
+ import logging
2
+ from typing import Callable, Iterable, Any, TypeVar
3
+ from tqdm.auto import tqdm
4
+
5
logger = logging.getLogger(__name__)

T = TypeVar("T")  # Generic type for items in the collection


class ApplyMixin:
    """
    Mixin class providing an `.apply()` method for collections.

    Assumes the inheriting class implements `__iter__` (and ideally `__len__`,
    used only for the progress-bar total) for the items processed by `apply`.
    """

    def _get_items_for_apply(self) -> Iterable[Any]:
        """
        Return the iterable of items to apply the function to.

        Defaults to iterating over `self`; subclasses override this when the
        default iteration is not suitable for the apply operation.
        """
        return iter(self)

    def apply(self: Any, func: Callable[..., Any], *args, **kwargs) -> None:
        """
        Apply a function to each item in the collection, for its side effects.

        Args:
            func: Called as ``func(item, *args, **kwargs)`` for every item.
            *args: Additional positional arguments passed to func.
            **kwargs: Additional keyword arguments passed to func. The special
                keyword 'show_progress' (bool, default False) is consumed here
                and enables a tqdm progress bar.

        Errors raised by ``func`` are logged per item and processing continues.
        """
        show_progress = kwargs.pop('show_progress', False)
        # functools.partial objects and callable instances may lack __name__.
        func_name = getattr(func, "__name__", repr(func))
        items_iterable = self._get_items_for_apply()

        if show_progress:
            # The length is only needed for the tqdm total.
            total_items = 0
            try:
                total_items = len(self)
            except TypeError:  # __len__ not defined on self
                logger.warning("Could not determine collection length for progress bar.")

            if total_items > 0:
                unit_name = self.__class__.__name__.lower()
                items_iterable = tqdm(items_iterable, total=total_items, desc=f"Applying {func_name}", unit=unit_name)
            else:
                logger.info(f"Applying {func_name} (progress bar disabled for zero/unknown length).")

        for item in items_iterable:
            try:
                func(item, *args, **kwargs)
            except Exception as e:
                # Log and continue so one bad item does not abort the batch.
                logger.error(f"Error applying {func_name} to {item}: {e}", exc_info=True)