natural-pdf 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/categorizing-documents/index.md +168 -0
- docs/data-extraction/index.md +87 -0
- docs/element-selection/index.ipynb +218 -164
- docs/element-selection/index.md +20 -0
- docs/index.md +19 -0
- docs/ocr/index.md +63 -16
- docs/tutorials/01-loading-and-extraction.ipynb +1713 -34
- docs/tutorials/02-finding-elements.ipynb +123 -46
- docs/tutorials/03-extracting-blocks.ipynb +24 -19
- docs/tutorials/04-table-extraction.ipynb +17 -12
- docs/tutorials/05-excluding-content.ipynb +37 -32
- docs/tutorials/06-document-qa.ipynb +36 -31
- docs/tutorials/07-layout-analysis.ipynb +45 -40
- docs/tutorials/07-working-with-regions.ipynb +61 -60
- docs/tutorials/08-spatial-navigation.ipynb +76 -71
- docs/tutorials/09-section-extraction.ipynb +160 -155
- docs/tutorials/10-form-field-extraction.ipynb +71 -66
- docs/tutorials/11-enhanced-table-processing.ipynb +11 -6
- docs/tutorials/12-ocr-integration.ipynb +3420 -312
- docs/tutorials/12-ocr-integration.md +68 -106
- docs/tutorials/13-semantic-search.ipynb +641 -251
- natural_pdf/__init__.py +2 -0
- natural_pdf/classification/manager.py +343 -0
- natural_pdf/classification/mixin.py +149 -0
- natural_pdf/classification/results.py +62 -0
- natural_pdf/collections/mixins.py +63 -0
- natural_pdf/collections/pdf_collection.py +321 -15
- natural_pdf/core/element_manager.py +67 -0
- natural_pdf/core/page.py +227 -64
- natural_pdf/core/pdf.py +387 -378
- natural_pdf/elements/collections.py +272 -41
- natural_pdf/elements/region.py +99 -15
- natural_pdf/elements/text.py +5 -2
- natural_pdf/exporters/paddleocr.py +1 -1
- natural_pdf/extraction/manager.py +134 -0
- natural_pdf/extraction/mixin.py +246 -0
- natural_pdf/extraction/result.py +37 -0
- natural_pdf/ocr/engine_easyocr.py +6 -3
- natural_pdf/ocr/ocr_manager.py +85 -25
- natural_pdf/ocr/ocr_options.py +33 -10
- natural_pdf/ocr/utils.py +14 -3
- natural_pdf/qa/document_qa.py +0 -4
- natural_pdf/selectors/parser.py +363 -238
- natural_pdf/templates/finetune/fine_tune_paddleocr.md +10 -5
- natural_pdf/utils/locks.py +8 -0
- natural_pdf/utils/text_extraction.py +52 -1
- natural_pdf/utils/tqdm_utils.py +43 -0
- {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.8.dist-info}/METADATA +6 -1
- {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.8.dist-info}/RECORD +52 -41
- {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.8.dist-info}/WHEEL +1 -1
- {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.8.dist-info}/top_level.txt +0 -0
natural_pdf/__init__.py
CHANGED
@@ -0,0 +1,343 @@
|
|
1
|
+
import logging
|
2
|
+
import time
|
3
|
+
from datetime import datetime
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Tuple
|
5
|
+
|
6
|
+
# Use try-except for robustness if dependencies are missing
|
7
|
+
try:
|
8
|
+
import torch
|
9
|
+
from PIL import Image
|
10
|
+
from transformers import pipeline, AutoTokenizer, AutoModelForZeroShotImageClassification, AutoModelForSequenceClassification
|
11
|
+
_CLASSIFICATION_AVAILABLE = True
|
12
|
+
except ImportError:
|
13
|
+
_CLASSIFICATION_AVAILABLE = False
|
14
|
+
# Define dummy types for type hinting if imports fail
|
15
|
+
Image = type("Image", (), {})
|
16
|
+
pipeline = object
|
17
|
+
AutoTokenizer = object
|
18
|
+
AutoModelForZeroShotImageClassification = object
|
19
|
+
AutoModelForSequenceClassification = object
|
20
|
+
torch = None
|
21
|
+
|
22
|
+
# Import result classes
|
23
|
+
from .results import ClassificationResult, CategoryScore
|
24
|
+
from natural_pdf.utils.tqdm_utils import get_tqdm
|
25
|
+
|
26
|
+
if TYPE_CHECKING:
|
27
|
+
from transformers import Pipeline
|
28
|
+
|
29
|
+
|
30
|
+
logger = logging.getLogger(__name__)
|
31
|
+
|
32
|
+
# Global cache for models/pipelines
|
33
|
+
_PIPELINE_CACHE: Dict[str, "Pipeline"] = {}
|
34
|
+
_TOKENIZER_CACHE: Dict[str, Any] = {}
|
35
|
+
_MODEL_CACHE: Dict[str, Any] = {}
|
36
|
+
|
37
|
+
class ClassificationError(Exception):
    """Error raised when a classification model cannot be loaded or executed."""
|
40
|
+
|
41
|
+
|
42
|
+
class ClassificationManager:
    """Manages classification models and execution.

    Pipelines are created lazily through the ``pipeline`` factory and stored
    in the module-level ``_PIPELINE_CACHE``, keyed by model id, mode
    ('text' or 'vision') and device.
    """

    DEFAULT_TEXT_MODEL = "facebook/bart-large-mnli"
    DEFAULT_VISION_MODEL = "openai/clip-vit-base-patch16"

    def __init__(
        self,
        model_mapping: Optional[Dict[str, str]] = None,
        default_device: Optional[str] = None,
    ):
        """
        Initialize the ClassificationManager.

        Args:
            model_mapping: Optional dictionary mapping aliases ('text', 'vision') to model IDs.
                NOTE(review): accepted but not read anywhere in this class.
            default_device: Default device ('cpu', 'cuda') used when loading pipelines.

        Raises:
            ImportError: If the optional classification dependencies are not installed.
        """
        if not _CLASSIFICATION_AVAILABLE:
            raise ImportError(
                "Classification dependencies missing. "
                "Install with: pip install \"natural-pdf[classification]\""
            )

        # NOTE(review): this per-instance dict is never consulted below; the
        # module-level _PIPELINE_CACHE is what _get_pipeline actually uses.
        self.pipelines: Dict[Tuple[str, str], "Pipeline"] = {}  # Cache: (model_id, device) -> pipeline

        self.device = default_device
        logger.info(f"ClassificationManager initialized on device: {self.device}")

    def is_available(self) -> bool:
        """Check if required dependencies are installed."""
        return _CLASSIFICATION_AVAILABLE

    def _get_pipeline(self, model_id: str, using: str) -> "Pipeline":
        """Get or create a classification pipeline.

        Args:
            model_id: Model identifier handed to the pipeline factory.
            using: 'text' selects the zero-shot text task; any other value
                selects the zero-shot image task.

        Returns:
            The cached (or newly created) pipeline.

        Raises:
            ClassificationError: If the pipeline cannot be created.
        """
        cache_key = f"{model_id}_{using}_{self.device}"
        if cache_key not in _PIPELINE_CACHE:
            logger.info(f"Loading {using} classification pipeline for model '{model_id}' on device '{self.device}'...")
            start_time = time.time()
            # Resolved before the try so the error message below can always name it.
            task = (
                "zero-shot-classification"
                if using == "text"
                else "zero-shot-image-classification"
            )
            try:
                _PIPELINE_CACHE[cache_key] = pipeline(
                    task,
                    model=model_id,
                    device=self.device
                )
                end_time = time.time()
                logger.info(f"Pipeline for '{model_id}' loaded in {end_time - start_time:.2f} seconds.")
            except Exception as e:
                logger.error(f"Failed to load pipeline for model '{model_id}' (using: {using}): {e}", exc_info=True)
                raise ClassificationError(f"Failed to load pipeline for model '{model_id}'. Ensure the model ID is correct and supports the {task} task.") from e
        return _PIPELINE_CACHE[cache_key]

    def infer_using(self, model_id: str, using: Optional[str] = None) -> str:
        """Infers processing mode ('text' or 'vision') if not provided.

        An explicit ``using`` of 'text' or 'vision' is returned unchanged.
        Otherwise the model id is matched against known name fragments, and
        as a last resort the model is loaded as a text pipeline, then as a
        vision pipeline.

        Raises:
            ClassificationError: If the model loads as neither kind of pipeline.
        """
        if using in ["text", "vision"]:
            return using

        # Simple inference based on common model names
        normalized_model_id = model_id.lower()
        if "clip" in normalized_model_id or "vit" in normalized_model_id or "siglip" in normalized_model_id:
            logger.debug(f"Inferred using='vision' for model '{model_id}'")
            return "vision"
        if "bart" in normalized_model_id or "bert" in normalized_model_id or "mnli" in normalized_model_id or "xnli" in normalized_model_id or "deberta" in normalized_model_id:
            logger.debug(f"Inferred using='text' for model '{model_id}'")
            return "text"

        # No name match: probe by actually loading, text first, then vision.
        logger.warning(f"Could not reliably infer mode for '{model_id}'. Trying text, then vision pipeline loading.")
        try:
            self._get_pipeline(model_id, "text")
            logger.info(f"Successfully loaded '{model_id}' as a text model.")
            return "text"
        except Exception:
            logger.warning(f"Failed to load '{model_id}' as text model. Trying vision.")
            try:
                self._get_pipeline(model_id, "vision")
                logger.info(f"Successfully loaded '{model_id}' as a vision model.")
                return "vision"
            except Exception as e_vision:
                logger.error(f"Failed to load '{model_id}' as either text or vision model.", exc_info=True)
                raise ClassificationError(f"Cannot determine mode for model '{model_id}'. Please specify `using='text'` or `using='vision'`. Error: {e_vision}") from e_vision

    def classify_item(
        self,
        item_content: Union[str, "Image.Image"],
        categories: List[str],
        model_id: Optional[str] = None,
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        multi_label: bool = False,
        **kwargs
    ) -> "ClassificationResult":
        """Classifies a single item (text or image).

        Args:
            item_content: The text or image to classify.
            categories: Candidate labels; must not be empty.
            model_id: Model to use. When None, the default text or vision
                model is chosen from the type of ``item_content``.
            using: Processing mode; inferred via ``infer_using`` when a
                ``model_id`` is given.
            min_confidence: Scores below this value are dropped.
            multi_label: Forwarded to the pipeline call.
            **kwargs: Forwarded to the pipeline call and recorded in the
                result's parameters.

        Returns:
            A ClassificationResult holding the scores that passed the threshold.

        Raises:
            TypeError: If no model is given and the content type is unsupported.
            ValueError: If ``categories`` is empty.
            ClassificationError: If loading or running the pipeline fails.
        """
        # Determine model and mode. After this if/else model_id is never None.
        if model_id is None:
            if isinstance(item_content, str):
                effective_using = "text"
                model_id = self.DEFAULT_TEXT_MODEL
            elif isinstance(item_content, Image.Image):
                effective_using = "vision"
                model_id = self.DEFAULT_VISION_MODEL
            else:
                raise TypeError(f"Unsupported item_content type: {type(item_content)}")
        else:
            effective_using = self.infer_using(model_id, using)

        if not categories:
            raise ValueError("Categories list cannot be empty.")

        pipeline_instance = self._get_pipeline(model_id, effective_using)
        timestamp = datetime.now()
        parameters = {  # Store parameters used for this run
            'categories': categories,
            'model_id': model_id,
            'using': effective_using,
            'min_confidence': min_confidence,
            'multi_label': multi_label,
            **kwargs
        }

        logger.debug(f"Classifying content (type: {type(item_content).__name__}) with model '{model_id}'")
        try:
            # The zero-shot pipelines expect `candidate_labels`
            result_raw = pipeline_instance(item_content, candidate_labels=categories, multi_label=multi_label, **kwargs)
            logger.debug(f"Raw pipeline result: {result_raw}")

            scores_list: List["CategoryScore"] = []

            # Text pipeline format: dict with parallel 'labels' and 'scores' lists
            if isinstance(result_raw, dict) and 'labels' in result_raw and 'scores' in result_raw:
                for label, score_val in zip(result_raw['labels'], result_raw['scores']):
                    if score_val >= min_confidence:
                        try:
                            scores_list.append(CategoryScore(label=label, confidence=score_val))
                        except (ValueError, TypeError) as score_err:
                            logger.warning(f"Skipping invalid score from text pipeline: label='{label}', score={score_val}. Error: {score_err}")
            # Vision pipeline format: list of dicts with 'label' and 'score'
            elif isinstance(result_raw, list) and all(isinstance(item, dict) and 'label' in item and 'score' in item for item in result_raw):
                for item in result_raw:
                    score_val = item['score']
                    label = item['label']
                    if score_val >= min_confidence:
                        try:
                            scores_list.append(CategoryScore(label=label, confidence=score_val))
                        except (ValueError, TypeError) as score_err:
                            logger.warning(f"Skipping invalid score from vision pipeline: label='{label}', score={score_val}. Error: {score_err}")
            else:
                # Unknown shape: fall through and return a result with no scores.
                logger.warning(f"Unexpected raw result format from pipeline for model '{model_id}': {type(result_raw)}. Cannot extract scores.")

            return ClassificationResult(
                model_id=model_id,
                using=effective_using,
                timestamp=timestamp,
                parameters=parameters,
                scores=scores_list
            )

        except Exception as e:
            logger.error(f"Classification failed for model '{model_id}': {e}", exc_info=True)
            raise ClassificationError(f"Classification failed using model '{model_id}'. Error: {e}") from e

    def classify_batch(
        self,
        item_contents: List[Union[str, "Image.Image"]],
        categories: List[str],
        model_id: Optional[str] = None,
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        multi_label: bool = False,
        batch_size: int = 8,
        progress_bar: bool = True,
        **kwargs
    ) -> List["ClassificationResult"]:
        """Classifies a batch of items (text or image) using the pipeline's batching.

        The model and mode are chosen from the first item when ``model_id``
        is None, so the batch is assumed to be of uniform type.

        Args:
            item_contents: Texts or images to classify. An empty list returns [].
            categories: Candidate labels; must not be empty.
            model_id: Model to use; defaults by the type of the first item.
            using: Processing mode; inferred via ``infer_using`` when a
                ``model_id`` is given.
            min_confidence: Scores below this value are dropped.
            multi_label: Forwarded to the pipeline call.
            batch_size: Forwarded to the pipeline call.
            progress_bar: Wrap result iteration in a tqdm progress bar.
            **kwargs: Forwarded to the pipeline call and recorded in each
                result's parameters.

        Returns:
            One ClassificationResult per raw result yielded by the pipeline.

        Raises:
            TypeError: If no model is given and the first item's type is unsupported.
            ValueError: If ``categories`` is empty.
            ClassificationError: If loading or running the pipeline fails.
        """
        if not item_contents:
            return []

        # Determine model and mode (assuming uniform type in batch).
        # After this if/else model_id is never None.
        first_item = item_contents[0]
        if model_id is None:
            if isinstance(first_item, str):
                effective_using = "text"
                model_id = self.DEFAULT_TEXT_MODEL
            elif isinstance(first_item, Image.Image):
                effective_using = "vision"
                model_id = self.DEFAULT_VISION_MODEL
            else:
                raise TypeError(f"Unsupported item_content type in batch: {type(first_item)}")
        else:
            effective_using = self.infer_using(model_id, using)

        if not categories:
            raise ValueError("Categories list cannot be empty.")

        pipeline_instance = self._get_pipeline(model_id, effective_using)
        timestamp = datetime.now()  # Single timestamp for the batch run
        parameters = {  # Parameters for the whole batch
            'categories': categories,
            'model_id': model_id,
            'using': effective_using,
            'min_confidence': min_confidence,
            'multi_label': multi_label,
            'batch_size': batch_size,
            **kwargs
        }

        logger.info(f"Classifying batch of {len(item_contents)} items with model '{model_id}' (batch size: {batch_size})")
        batch_results_list: List["ClassificationResult"] = []

        try:
            # Use pipeline directly for batching
            results_iterator = pipeline_instance(
                item_contents,
                candidate_labels=categories,
                multi_label=multi_label,
                batch_size=batch_size,
                **kwargs
            )

            # NOTE(review): presumably some pipeline versions collapse a
            # one-element batch into a bare result dict - confirm against the
            # installed transformers version. Iterating such a dict would walk
            # its keys and yield bogus empty results, so wrap it.
            if isinstance(results_iterator, dict):
                results_iterator = [results_iterator]

            # Wrap with tqdm for progress if requested
            total_items = len(item_contents)
            if progress_bar:
                tqdm_class = get_tqdm()
                results_iterator = tqdm_class(
                    results_iterator,
                    total=total_items,
                    desc=f"Classifying batch ({model_id})",
                    leave=False  # Don't leave progress bar hanging
                )

            for raw_result in results_iterator:
                # Each raw result corresponds to ONE input item.
                scores_list: List["CategoryScore"] = []
                try:
                    # Text format: dict with parallel 'labels' and 'scores' lists
                    if isinstance(raw_result, dict) and 'labels' in raw_result and 'scores' in raw_result:
                        for label, score_val in zip(raw_result['labels'], raw_result['scores']):
                            if score_val >= min_confidence:
                                try:
                                    scores_list.append(CategoryScore(label=label, confidence=score_val))
                                except (ValueError, TypeError) as score_err:
                                    logger.warning(f"Skipping invalid score from text pipeline batch: label='{label}', score={score_val}. Error: {score_err}")
                    # Vision format: list of dicts with 'label' and 'score'
                    elif isinstance(raw_result, list):
                        for item in raw_result:
                            try:
                                score_val = item['score']
                                label = item['label']
                                if score_val >= min_confidence:
                                    scores_list.append(CategoryScore(label=label, confidence=score_val))
                            except (KeyError, ValueError, TypeError) as item_err:
                                logger.warning(f"Skipping invalid item in vision result list from batch: {item}. Error: {item_err}")
                    else:
                        logger.warning(f"Unexpected raw result format in batch item from model '{model_id}': {type(raw_result)}. Cannot extract scores.")

                except Exception as proc_err:
                    # scores_list stays as collected so far; the item still gets a result.
                    logger.error(f"Error processing result item in batch: {proc_err}", exc_info=True)

                batch_results_list.append(ClassificationResult(
                    model_id=model_id,
                    using=effective_using,
                    timestamp=timestamp,  # Use same timestamp for batch
                    # Each result gets its own copy so mutating one result's
                    # parameters cannot leak into its siblings.
                    parameters=dict(parameters),
                    scores=scores_list
                ))

            if len(batch_results_list) != total_items:
                logger.warning(f"Batch classification returned {len(batch_results_list)} results, but expected {total_items}. Results might be incomplete or misaligned.")

            return batch_results_list

        except Exception as e:
            logger.error(f"Batch classification failed for model '{model_id}': {e}", exc_info=True)
            raise ClassificationError(f"Batch classification failed using model '{model_id}'. Error: {e}") from e
|
@@ -0,0 +1,149 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
3
|
+
|
4
|
+
# Assuming PIL is installed as it's needed for vision
|
5
|
+
try:
|
6
|
+
from PIL import Image
|
7
|
+
except ImportError:
|
8
|
+
Image = None # type: ignore
|
9
|
+
|
10
|
+
# Import result classes
|
11
|
+
from .results import ClassificationResult # Assuming results.py is in the same dir
|
12
|
+
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
# Avoid runtime import cycle
|
15
|
+
from natural_pdf.core.page import Page
|
16
|
+
from natural_pdf.elements.region import Region
|
17
|
+
from .manager import ClassificationManager
|
18
|
+
|
19
|
+
logger = logging.getLogger(__name__)
|
20
|
+
|
21
|
+
class ClassificationMixin:
    """
    Adds classification helpers to a host object (Page or Region).

    The host supplies the manager and the content to classify through the two
    hook methods below, and is expected to carry an `analyses` dict in which
    results are stored.
    """

    # --- Hooks the host class must override --- #

    def _get_classification_manager(self) -> "ClassificationManager":
        """Hook: return the ClassificationManager to use."""
        raise NotImplementedError

    def _get_classification_content(self, model_type: str, **kwargs) -> Union[str, "Image"]:
        """Hook: return the text (str) or image (PIL.Image) to classify."""
        raise NotImplementedError

    # --- End hooks --- #

    def classify(
        self,
        categories: List[str],
        model: Optional[str] = None,
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        analysis_key: str = 'classification',
        multi_label: bool = False,
        **kwargs
    ) -> "ClassificationMixin":
        """
        Classify this item and store the outcome in `self.analyses[analysis_key]`.

        Any previous result under the same key is overwritten.

        Args:
            categories: Candidate category names.
            model: Model identifier; when None the manager's default for the
                resolved mode is used.
            using: Processing mode ('text' or 'vision'); resolved by the
                manager's `infer_using` when None.
            min_confidence: Minimum confidence threshold passed to the manager.
            analysis_key: Key under which the result is stored.
            multi_label: Passed through to the manager.
            **kwargs: Passed both to the content hook and to the manager.

        Returns:
            Self, for chaining. Failures other than NotImplementedError are
            logged and swallowed; NotImplementedError is logged and re-raised.
        """
        # Guarantee somewhere to put the result.
        if getattr(self, 'analyses', None) is None:
            logger.warning("'analyses' attribute not found or is None. Initializing as empty dict.")
            self.analyses = {}

        try:
            manager = self._get_classification_manager()

            # Resolve the mode first; a missing model is probed as the text default.
            probe_model = model if model else manager.DEFAULT_TEXT_MODEL
            mode = manager.infer_using(probe_model, using)

            if model is None:
                if mode == 'text':
                    resolved_model = manager.DEFAULT_TEXT_MODEL
                else:
                    resolved_model = manager.DEFAULT_VISION_MODEL
                logger.debug(f"No model provided, using default for mode '{mode}': '{resolved_model}'")
            else:
                resolved_model = model

            # Content is fetched for the mode that was finally resolved.
            payload = self._get_classification_content(model_type=mode, **kwargs)

            outcome: ClassificationResult = manager.classify_item(
                item_content=payload,
                categories=categories,
                model_id=resolved_model,
                using=mode,
                min_confidence=min_confidence,
                multi_label=multi_label,
                **kwargs
            )

            self.analyses[analysis_key] = outcome
            logger.debug(f"Stored classification result under key '{analysis_key}': {outcome}")

        except NotImplementedError as nie:
            logger.error(f"Classification cannot proceed: {nie}")
            raise
        except Exception as e:
            # Best effort: log and fall through to return self.
            logger.error(f"Classification failed: {e}", exc_info=True)

        return self

    @property
    def classification_results(self) -> Optional[ClassificationResult]:
        """The entry stored under the default 'classification' key, or None."""
        store = getattr(self, 'analyses', None)
        if store is None:
            return None
        return store.get('classification')

    @property
    def category(self) -> Optional[str]:
        """Top category label of the default result, or None when there is none."""
        default_result = self.classification_results
        if not default_result:
            return None
        return default_result.top_category

    @property
    def category_confidence(self) -> Optional[float]:
        """Top category confidence of the default result, or None when there is none."""
        default_result = self.classification_results
        if not default_result:
            return None
        return default_result.top_confidence

    def get_classification_result(self, analysis_key: str = 'classification') -> Optional[ClassificationResult]:
        """Return the ClassificationResult stored under `analysis_key`, or None.

        An entry of any other type is reported with a warning and treated as absent.
        """
        store = getattr(self, 'analyses', None)
        if store is None:
            return None
        result = store.get(analysis_key)
        if result is None or isinstance(result, ClassificationResult):
            return result
        logger.warning(f"Item found under key '{analysis_key}' is not a ClassificationResult (type: {type(result)}). Returning None.")
        return None
|
@@ -0,0 +1,62 @@
|
|
1
|
+
# natural_pdf/classification/results.py
|
2
|
+
from typing import List, Optional, Dict, Any
|
3
|
+
from datetime import datetime
|
4
|
+
import logging
|
5
|
+
|
6
|
+
logger = logging.getLogger(__name__)
|
7
|
+
|
8
|
+
class CategoryScore:
    """Represents the score for a single category."""
    label: str
    confidence: float  # Score between 0.0 and 1.0

    def __init__(self, label: str, confidence: float):
        """Store a label/confidence pair, normalising the confidence.

        Args:
            label: Category name. An empty or non-string label is logged but
                still accepted (stored via ``str()``).
            confidence: Score for the label. Values that are not float/int or
                lie outside [0.0, 1.0] are logged and clamped into that range;
                NaN is stored as 0.0.

        Raises:
            TypeError, ValueError: If ``confidence`` cannot be converted by ``float()``.
        """
        # Invalid labels are tolerated on purpose: log, do not raise.
        if not isinstance(label, str) or not label:
            logger.warning(f"Initializing CategoryScore with invalid label: {label}")

        if not isinstance(confidence, (float, int)) or not (0.0 <= confidence <= 1.0):
            logger.warning(f"Initializing CategoryScore with invalid confidence: {confidence} for label '{label}'. Clamping to [0, 1].")
            numeric = float(confidence)
            if numeric != numeric:
                # NaN fails every comparison, so max(0.0, min(1.0, nan)) would
                # yield 1.0 and rank a junk score first. Use 0.0 instead.
                confidence = 0.0
            else:
                confidence = max(0.0, min(1.0, numeric))

        self.label = str(label)
        self.confidence = float(confidence)

    def __repr__(self):
        return f"<CategoryScore label='{self.label}' confidence={self.confidence:.3f}>"
|
29
|
+
|
30
|
+
class ClassificationResult:
    """Structured outcome of one classification run."""
    model_id: str
    using: str  # 'text' or 'vision'
    timestamp: datetime
    parameters: Dict[str, Any]  # e.g., {'categories': [...], 'min_confidence': 0.1}
    scores: List[CategoryScore]  # kept sorted by confidence, highest first

    def __init__(self, model_id: str, using: str, timestamp: datetime, parameters: Dict[str, Any], scores: List[CategoryScore]):
        """Validate and store the run details.

        Raises:
            TypeError: If `scores` is not a list made up solely of CategoryScore objects.
        """
        scores_ok = isinstance(scores, list) and all(isinstance(entry, CategoryScore) for entry in scores)
        if not scores_ok:
            raise TypeError("Scores must be a list of CategoryScore objects.")

        self.model_id = str(model_id)
        self.using = str(using)
        self.timestamp = timestamp
        self.parameters = {} if parameters is None else parameters

        # Work on a copy so the caller's list keeps its order; highest confidence first.
        ranked = list(scores)
        ranked.sort(key=lambda entry: entry.confidence, reverse=True)
        self.scores = ranked

    @property
    def top_category(self) -> Optional[str]:
        """Label of the highest-confidence score, or None when there are no scores."""
        if not self.scores:
            return None
        return self.scores[0].label

    @property
    def top_confidence(self) -> Optional[float]:
        """Confidence of the highest-confidence score, or None when there are no scores."""
        if not self.scores:
            return None
        return self.scores[0].confidence

    def __repr__(self):
        count = len(self.scores)
        summary = ""
        if self.scores:
            summary = f" top='{self.top_category}' ({self.top_confidence:.2f})"
        return f"<ClassificationResult model='{self.model_id}' using='{self.using}' scores={count}{summary}>"
|
@@ -0,0 +1,63 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Callable, Iterable, Any, TypeVar
|
3
|
+
from tqdm.auto import tqdm
|
4
|
+
|
5
|
+
logger = logging.getLogger(__name__)
|
6
|
+
|
7
|
+
T = TypeVar("T") # Generic type for items in the collection
|
8
|
+
|
9
|
+
class ApplyMixin:
    """
    Mixin class providing an `.apply()` method for collections.

    Assumes the inheriting class implements `__iter__` and `__len__` appropriately
    for the items to be processed by `apply`.
    """

    def _get_items_for_apply(self) -> Iterable[Any]:
        """
        Returns the iterable of items to apply the function to.

        Defaults to iterating over `self`. Subclasses should override this
        if the default iteration is not suitable for the apply operation.
        """
        return iter(self)

    def apply(self: Any, func: Callable[..., Any], *args, **kwargs) -> None:
        """
        Applies a function to each item in the collection.

        An exception raised for one item is logged and the remaining items
        are still processed.

        Args:
            func: The callable to apply to each item. The item itself
                will be passed as the first argument. Callables without a
                `__name__` (e.g. `functools.partial` objects) are accepted.
            *args: Additional positional arguments to pass to func.
            **kwargs: Additional keyword arguments to pass to func.
                A special keyword argument 'show_progress' (bool, default=False)
                can be used to display a progress bar; it is not forwarded to func.

        Returns:
            None. The method is used for its side effects.
        """
        show_progress = kwargs.pop('show_progress', False)
        # Derive unit name from class name
        unit_name = self.__class__.__name__.lower()
        items_iterable = self._get_items_for_apply()

        # partial objects and callable instances have no __name__; resolve a
        # printable name once so logging can never raise from the handler below.
        func_name = getattr(func, "__name__", None) or repr(func)

        # tqdm needs a total; fall back to 0 when the host has no usable __len__.
        total_items = 0
        try:
            total_items = len(self)
        except TypeError:
            logger.warning("Could not determine collection length for progress bar.")

        if show_progress and total_items > 0:
            items_iterable = tqdm(items_iterable, total=total_items, desc=f"Applying {func_name}", unit=unit_name)
        elif show_progress:
            logger.info(f"Applying {func_name} (progress bar disabled for zero/unknown length).")

        for item in items_iterable:
            try:
                func(item, *args, **kwargs)
            except Exception as e:
                # Log and continue for batch operations
                logger.error(f"Error applying {func_name} to {item}: {e}", exc_info=True)
|