natural-pdf 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- natural_pdf/__init__.py +3 -0
- natural_pdf/analyzers/layout/base.py +1 -5
- natural_pdf/analyzers/layout/gemini.py +61 -51
- natural_pdf/analyzers/layout/layout_analyzer.py +40 -11
- natural_pdf/analyzers/layout/layout_manager.py +26 -84
- natural_pdf/analyzers/layout/layout_options.py +7 -0
- natural_pdf/analyzers/layout/pdfplumber_table_finder.py +142 -0
- natural_pdf/analyzers/layout/surya.py +46 -123
- natural_pdf/analyzers/layout/tatr.py +51 -4
- natural_pdf/analyzers/text_structure.py +3 -5
- natural_pdf/analyzers/utils.py +3 -3
- natural_pdf/classification/manager.py +422 -0
- natural_pdf/classification/mixin.py +163 -0
- natural_pdf/classification/results.py +80 -0
- natural_pdf/collections/mixins.py +111 -0
- natural_pdf/collections/pdf_collection.py +434 -15
- natural_pdf/core/element_manager.py +83 -0
- natural_pdf/core/highlighting_service.py +13 -22
- natural_pdf/core/page.py +578 -93
- natural_pdf/core/pdf.py +912 -460
- natural_pdf/elements/base.py +134 -40
- natural_pdf/elements/collections.py +712 -109
- natural_pdf/elements/region.py +722 -69
- natural_pdf/elements/text.py +4 -1
- natural_pdf/export/mixin.py +137 -0
- natural_pdf/exporters/base.py +3 -3
- natural_pdf/exporters/paddleocr.py +5 -4
- natural_pdf/extraction/manager.py +135 -0
- natural_pdf/extraction/mixin.py +279 -0
- natural_pdf/extraction/result.py +23 -0
- natural_pdf/ocr/__init__.py +5 -5
- natural_pdf/ocr/engine_doctr.py +346 -0
- natural_pdf/ocr/engine_easyocr.py +6 -3
- natural_pdf/ocr/ocr_factory.py +24 -4
- natural_pdf/ocr/ocr_manager.py +122 -26
- natural_pdf/ocr/ocr_options.py +94 -11
- natural_pdf/ocr/utils.py +19 -6
- natural_pdf/qa/document_qa.py +0 -4
- natural_pdf/search/__init__.py +20 -34
- natural_pdf/search/haystack_search_service.py +309 -265
- natural_pdf/search/haystack_utils.py +99 -75
- natural_pdf/search/search_service_protocol.py +11 -12
- natural_pdf/selectors/parser.py +431 -230
- natural_pdf/utils/debug.py +3 -3
- natural_pdf/utils/identifiers.py +1 -1
- natural_pdf/utils/locks.py +8 -0
- natural_pdf/utils/packaging.py +8 -6
- natural_pdf/utils/text_extraction.py +60 -1
- natural_pdf/utils/tqdm_utils.py +51 -0
- natural_pdf/utils/visualization.py +18 -0
- natural_pdf/widgets/viewer.py +4 -25
- {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.9.dist-info}/METADATA +17 -3
- natural_pdf-0.1.9.dist-info/RECORD +80 -0
- {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.9.dist-info}/WHEEL +1 -1
- {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.9.dist-info}/top_level.txt +0 -2
- docs/api/index.md +0 -386
- docs/assets/favicon.png +0 -3
- docs/assets/favicon.svg +0 -3
- docs/assets/javascripts/custom.js +0 -17
- docs/assets/logo.svg +0 -3
- docs/assets/sample-screen.png +0 -0
- docs/assets/social-preview.png +0 -17
- docs/assets/social-preview.svg +0 -17
- docs/assets/stylesheets/custom.css +0 -65
- docs/document-qa/index.ipynb +0 -435
- docs/document-qa/index.md +0 -79
- docs/element-selection/index.ipynb +0 -915
- docs/element-selection/index.md +0 -229
- docs/finetuning/index.md +0 -176
- docs/index.md +0 -170
- docs/installation/index.md +0 -69
- docs/interactive-widget/index.ipynb +0 -962
- docs/interactive-widget/index.md +0 -12
- docs/layout-analysis/index.ipynb +0 -818
- docs/layout-analysis/index.md +0 -185
- docs/ocr/index.md +0 -209
- docs/pdf-navigation/index.ipynb +0 -314
- docs/pdf-navigation/index.md +0 -97
- docs/regions/index.ipynb +0 -816
- docs/regions/index.md +0 -294
- docs/tables/index.ipynb +0 -658
- docs/tables/index.md +0 -144
- docs/text-analysis/index.ipynb +0 -370
- docs/text-analysis/index.md +0 -105
- docs/text-extraction/index.ipynb +0 -1478
- docs/text-extraction/index.md +0 -292
- docs/tutorials/01-loading-and-extraction.ipynb +0 -194
- docs/tutorials/01-loading-and-extraction.md +0 -95
- docs/tutorials/02-finding-elements.ipynb +0 -340
- docs/tutorials/02-finding-elements.md +0 -149
- docs/tutorials/03-extracting-blocks.ipynb +0 -147
- docs/tutorials/03-extracting-blocks.md +0 -48
- docs/tutorials/04-table-extraction.ipynb +0 -114
- docs/tutorials/04-table-extraction.md +0 -50
- docs/tutorials/05-excluding-content.ipynb +0 -270
- docs/tutorials/05-excluding-content.md +0 -109
- docs/tutorials/06-document-qa.ipynb +0 -332
- docs/tutorials/06-document-qa.md +0 -91
- docs/tutorials/07-layout-analysis.ipynb +0 -288
- docs/tutorials/07-layout-analysis.md +0 -66
- docs/tutorials/07-working-with-regions.ipynb +0 -413
- docs/tutorials/07-working-with-regions.md +0 -151
- docs/tutorials/08-spatial-navigation.ipynb +0 -508
- docs/tutorials/08-spatial-navigation.md +0 -190
- docs/tutorials/09-section-extraction.ipynb +0 -2434
- docs/tutorials/09-section-extraction.md +0 -256
- docs/tutorials/10-form-field-extraction.ipynb +0 -512
- docs/tutorials/10-form-field-extraction.md +0 -201
- docs/tutorials/11-enhanced-table-processing.ipynb +0 -54
- docs/tutorials/11-enhanced-table-processing.md +0 -9
- docs/tutorials/12-ocr-integration.ipynb +0 -604
- docs/tutorials/12-ocr-integration.md +0 -175
- docs/tutorials/13-semantic-search.ipynb +0 -1328
- docs/tutorials/13-semantic-search.md +0 -77
- docs/visual-debugging/index.ipynb +0 -2970
- docs/visual-debugging/index.md +0 -157
- docs/visual-debugging/region.png +0 -0
- natural_pdf/templates/finetune/fine_tune_paddleocr.md +0 -415
- natural_pdf/templates/spa/css/style.css +0 -334
- natural_pdf/templates/spa/index.html +0 -31
- natural_pdf/templates/spa/js/app.js +0 -472
- natural_pdf/templates/spa/words.txt +0 -235976
- natural_pdf/widgets/frontend/viewer.js +0 -88
- natural_pdf-0.1.7.dist-info/RECORD +0 -145
- notebooks/Examples.ipynb +0 -1293
- pdfs/.gitkeep +0 -0
- pdfs/01-practice.pdf +0 -543
- pdfs/0500000US42001.pdf +0 -0
- pdfs/0500000US42007.pdf +0 -0
- pdfs/2014 Statistics.pdf +0 -0
- pdfs/2019 Statistics.pdf +0 -0
- pdfs/Atlanta_Public_Schools_GA_sample.pdf +0 -0
- pdfs/needs-ocr.pdf +0 -0
- {natural_pdf-0.1.7.dist-info → natural_pdf-0.1.9.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,422 @@
|
|
1
|
+
import logging
|
2
|
+
import time
|
3
|
+
from datetime import datetime
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
5
|
+
|
6
|
+
from PIL import Image
|
7
|
+
|
8
|
+
# Use try-except for robustness if dependencies are missing
|
9
|
+
try:
|
10
|
+
import torch
|
11
|
+
from transformers import (
|
12
|
+
AutoModelForSequenceClassification,
|
13
|
+
AutoModelForZeroShotImageClassification,
|
14
|
+
AutoTokenizer,
|
15
|
+
pipeline,
|
16
|
+
)
|
17
|
+
|
18
|
+
_CLASSIFICATION_AVAILABLE = True
|
19
|
+
except ImportError:
|
20
|
+
_CLASSIFICATION_AVAILABLE = False
|
21
|
+
# Define dummy types for type hinting if imports fail
|
22
|
+
pipeline = object
|
23
|
+
AutoTokenizer = object
|
24
|
+
AutoModelForZeroShotImageClassification = object
|
25
|
+
AutoModelForSequenceClassification = object
|
26
|
+
torch = None
|
27
|
+
|
28
|
+
from natural_pdf.utils.tqdm_utils import get_tqdm
|
29
|
+
|
30
|
+
# Import result classes
|
31
|
+
from .results import CategoryScore, ClassificationResult
|
32
|
+
|
33
|
+
if TYPE_CHECKING:
|
34
|
+
from transformers import Pipeline
|
35
|
+
|
36
|
+
|
37
|
+
logger = logging.getLogger(__name__)
|
38
|
+
|
39
|
+
# Global cache for models/pipelines
|
40
|
+
_PIPELINE_CACHE: Dict[str, "Pipeline"] = {}
|
41
|
+
_TOKENIZER_CACHE: Dict[str, Any] = {}
|
42
|
+
_MODEL_CACHE: Dict[str, Any] = {}
|
43
|
+
|
44
|
+
|
45
|
+
class ClassificationError(Exception):
    """Error type raised by this module when a pipeline cannot be loaded or run."""
|
49
|
+
|
50
|
+
|
51
|
+
class ClassificationManager:
    """Manages zero-shot classification pipelines and runs them on text or images.

    Pipelines are created with ``pipeline(...)`` and memoized in the
    module-level ``_PIPELINE_CACHE`` under a key built from the model ID, the
    processing mode ('text' or 'vision') and the manager's device.
    """

    DEFAULT_TEXT_MODEL = "facebook/bart-large-mnli"
    DEFAULT_VISION_MODEL = "openai/clip-vit-base-patch16"

    def __init__(
        self,
        model_mapping: Optional[Dict[str, str]] = None,
        default_device: Optional[str] = None,
    ):
        """
        Initialize the ClassificationManager.

        Args:
            model_mapping: Optional dictionary mapping aliases ('text', 'vision') to model IDs.
                Accepted for interface compatibility; it is not read by this class.
            default_device: Device ('cpu', 'cuda', ...) handed to every pipeline
                this manager creates.

        Raises:
            ImportError: If the optional classification dependencies failed to import.
        """
        if not _CLASSIFICATION_AVAILABLE:
            raise ImportError(
                "Classification dependencies missing. "
                'Install with: pip install "natural-pdf[classification]"'
            )

        # Kept for backward compatibility. The methods below cache pipelines in
        # the module-level _PIPELINE_CACHE and never read this attribute.
        self.pipelines: Dict[Tuple[str, str], "Pipeline"] = {}

        self.device = default_device
        logger.info(f"ClassificationManager initialized on device: {self.device}")

    def is_available(self) -> bool:
        """Check if required dependencies are installed."""
        return _CLASSIFICATION_AVAILABLE

    def _get_pipeline(self, model_id: str, using: str) -> "Pipeline":
        """Return the cached pipeline for (model, mode, device), loading it on first use.

        Args:
            model_id: Model identifier passed to ``pipeline(model=...)``.
            using: 'text' selects the zero-shot text task; any other value
                selects the zero-shot image task.

        Raises:
            ClassificationError: If the pipeline cannot be created.
        """
        cache_key = f"{model_id}_{using}_{self.device}"
        if cache_key in _PIPELINE_CACHE:
            return _PIPELINE_CACHE[cache_key]

        # Computed outside the try block so the error message below can always
        # reference it.
        task = "zero-shot-classification" if using == "text" else "zero-shot-image-classification"

        logger.info(
            f"Loading {using} classification pipeline for model '{model_id}' on device '{self.device}'..."
        )
        start_time = time.time()
        try:
            loaded = pipeline(task, model=model_id, device=self.device)
        except Exception as e:
            logger.error(
                f"Failed to load pipeline for model '{model_id}' (using: {using}): {e}",
                exc_info=True,
            )
            raise ClassificationError(
                f"Failed to load pipeline for model '{model_id}'. Ensure the model ID is correct and supports the {task} task."
            ) from e

        _PIPELINE_CACHE[cache_key] = loaded
        logger.info(
            f"Pipeline for '{model_id}' loaded in {time.time() - start_time:.2f} seconds."
        )
        return loaded

    def infer_using(self, model_id: str, using: Optional[str] = None) -> str:
        """Infers processing mode ('text' or 'vision') if not provided.

        An explicit ``using`` of 'text' or 'vision' is returned unchanged.
        Otherwise the model ID is matched against well-known name fragments,
        and as a last resort the model is loaded as a text pipeline, then as a
        vision pipeline.

        Raises:
            ClassificationError: If the model loads as neither kind of pipeline.
        """
        if using in ["text", "vision"]:
            return using

        # Simple inference based on common model names
        normalized_model_id = model_id.lower()
        if any(hint in normalized_model_id for hint in ("clip", "vit", "siglip")):
            logger.debug(f"Inferred using='vision' for model '{model_id}'")
            return "vision"
        if any(
            hint in normalized_model_id for hint in ("bart", "bert", "mnli", "xnli", "deberta")
        ):
            logger.debug(f"Inferred using='text' for model '{model_id}'")
            return "text"

        # No name hint matched: probe by actually loading the pipelines.
        logger.warning(
            f"Could not reliably infer mode for '{model_id}'. Trying text, then vision pipeline loading."
        )
        try:
            self._get_pipeline(model_id, "text")
            logger.info(f"Successfully loaded '{model_id}' as a text model.")
            return "text"
        except Exception:
            logger.warning(f"Failed to load '{model_id}' as text model. Trying vision.")

        try:
            self._get_pipeline(model_id, "vision")
            logger.info(f"Successfully loaded '{model_id}' as a vision model.")
            return "vision"
        except Exception as e_vision:
            logger.error(
                f"Failed to load '{model_id}' as either text or vision model.", exc_info=True
            )
            raise ClassificationError(
                f"Cannot determine mode for model '{model_id}'. Please specify `using='text'` or `using='vision'`. Error: {e_vision}"
            ) from e_vision

    def _resolve_model_and_mode(
        self,
        sample: Any,
        model_id: Optional[str],
        using: Optional[str],
        in_batch: bool = False,
    ) -> Tuple[str, str]:
        """Return ``(model_id, using)``, filling defaults from the type of ``sample``.

        When ``model_id`` is None the content type alone decides both values
        (str -> text default, PIL image -> vision default) and ``using`` is
        ignored; otherwise the mode comes from ``infer_using``.
        """
        if model_id is not None:
            return model_id, self.infer_using(model_id, using)
        if isinstance(sample, str):
            return self.DEFAULT_TEXT_MODEL, "text"
        if isinstance(sample, Image.Image):
            return self.DEFAULT_VISION_MODEL, "vision"
        where = " in batch" if in_batch else ""
        raise TypeError(f"Unsupported item_content type{where}: {type(sample)}")

    @staticmethod
    def _extract_scores(
        raw_result: Any, min_confidence: float, model_id: str
    ) -> List[CategoryScore]:
        """Convert one raw pipeline output into CategoryScore objects.

        Two shapes are recognized: a dict with parallel 'labels' and 'scores'
        lists, and a list of dicts each carrying 'label' and 'score'. Entries
        scoring below ``min_confidence`` are dropped; malformed entries are
        logged and skipped. Any other shape yields an empty list.
        """
        pairs: List[Tuple[Any, Any]] = []
        if isinstance(raw_result, dict) and "labels" in raw_result and "scores" in raw_result:
            pairs = list(zip(raw_result["labels"], raw_result["scores"]))
        elif isinstance(raw_result, list):
            for entry in raw_result:
                try:
                    pairs.append((entry["label"], entry["score"]))
                except (KeyError, TypeError) as entry_err:
                    logger.warning(
                        f"Skipping invalid item in pipeline result list: {entry}. Error: {entry_err}"
                    )
        else:
            logger.warning(
                f"Unexpected raw result format from pipeline for model '{model_id}': {type(raw_result)}. Cannot extract scores."
            )
            return []

        scores: List[CategoryScore] = []
        for label, score_val in pairs:
            # Only the comparison is guarded: a non-numeric score is the one
            # thing expected to fail here, and a wider net would hide real bugs.
            try:
                keep = score_val >= min_confidence
            except TypeError as score_err:
                logger.warning(
                    f"Skipping invalid score from pipeline: label='{label}', score={score_val}. Error: {score_err}"
                )
                continue
            if keep:
                scores.append(CategoryScore(category=label, score=score_val))
        return scores

    @staticmethod
    def _build_result(
        model_id: str,
        using: str,
        timestamp: datetime,
        parameters: Dict[str, Any],
        scores: List[CategoryScore],
    ) -> ClassificationResult:
        """Wrap ``scores`` in a ClassificationResult, promoting the best entry.

        When no score survived the confidence filter the top category is None
        and the top score is 0.0, so the result's ``__repr__`` still formats.
        """
        best = max(scores, key=lambda entry: entry.score, default=None)
        return ClassificationResult(
            category=best.category if best is not None else None,
            score=best.score if best is not None else 0.0,
            scores=scores,
            model_id=model_id,
            using=using,
            parameters=parameters,
            timestamp=timestamp,
        )

    def classify_item(
        self,
        item_content: Union[str, Image.Image],
        categories: List[str],
        model_id: Optional[str] = None,
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        multi_label: bool = False,
        **kwargs,
    ) -> ClassificationResult:
        """Classifies a single item (text or image).

        Args:
            item_content: Text or PIL image to classify.
            categories: Candidate labels; must be non-empty.
            model_id: Model to use. When None, a default is chosen from the
                type of ``item_content``.
            using: 'text' or 'vision'; inferred from ``model_id`` when None.
            min_confidence: Scores below this value are dropped from the result.
            multi_label: Forwarded to the pipeline call.
            **kwargs: Forwarded to the pipeline call and recorded in the
                result's ``parameters``.

        Returns:
            A ClassificationResult holding the surviving scores and the top one.

        Raises:
            TypeError: If ``model_id`` is None and the content is neither str nor image.
            ValueError: If ``categories`` is empty.
            ClassificationError: If loading or running the pipeline fails.
        """
        model_id, effective_using = self._resolve_model_and_mode(item_content, model_id, using)

        if not categories:
            raise ValueError("Categories list cannot be empty.")

        pipeline_instance = self._get_pipeline(model_id, effective_using)
        timestamp = datetime.now()
        parameters = {  # Store parameters used for this run
            "categories": categories,
            "model_id": model_id,
            "using": effective_using,
            "min_confidence": min_confidence,
            "multi_label": multi_label,
            **kwargs,
        }

        logger.debug(
            f"Classifying content (type: {type(item_content).__name__}) with model '{model_id}'"
        )
        try:
            # The zero-shot pipelines expect `candidate_labels`
            result_raw = pipeline_instance(
                item_content, candidate_labels=categories, multi_label=multi_label, **kwargs
            )
            logger.debug(f"Raw pipeline result: {result_raw}")

            scores_list = self._extract_scores(result_raw, min_confidence, model_id)
            return self._build_result(
                model_id, effective_using, timestamp, parameters, scores_list
            )
        except Exception as e:
            logger.error(f"Classification failed for model '{model_id}': {e}", exc_info=True)
            raise ClassificationError(
                f"Classification failed using model '{model_id}'. Error: {e}"
            ) from e

    def classify_batch(
        self,
        item_contents: List[Union[str, Image.Image]],
        categories: List[str],
        model_id: Optional[str] = None,
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        multi_label: bool = False,
        batch_size: int = 8,
        progress_bar: bool = True,
        **kwargs,
    ) -> List[ClassificationResult]:
        """Classifies a batch of items (text or image) using the pipeline's batching.

        Args:
            item_contents: Items to classify; the first item decides the
                default model, so the batch is assumed to be of uniform type.
            categories: Candidate labels; must be non-empty.
            model_id: Model to use. When None, a default is chosen from the
                type of the first item.
            using: 'text' or 'vision'; inferred from ``model_id`` when None.
            min_confidence: Scores below this value are dropped from each result.
            multi_label: Forwarded to the pipeline call.
            batch_size: Forwarded to the pipeline call.
            progress_bar: Wrap result iteration in the tqdm class from ``get_tqdm()``.
            **kwargs: Forwarded to the pipeline call and recorded in ``parameters``.

        Returns:
            One ClassificationResult per raw pipeline output, in pipeline order.
            All results share the same timestamp and parameters dict. An empty
            input list returns an empty list without loading any model.

        Raises:
            TypeError: If ``model_id`` is None and the first item is neither str nor image.
            ValueError: If ``categories`` is empty.
            ClassificationError: If loading or running the pipeline fails.
        """
        if not item_contents:
            return []

        model_id, effective_using = self._resolve_model_and_mode(
            item_contents[0], model_id, using, in_batch=True
        )

        if not categories:
            raise ValueError("Categories list cannot be empty.")

        pipeline_instance = self._get_pipeline(model_id, effective_using)
        timestamp = datetime.now()  # Single timestamp for the batch run
        parameters = {  # Parameters for the whole batch
            "categories": categories,
            "model_id": model_id,
            "using": effective_using,
            "min_confidence": min_confidence,
            "multi_label": multi_label,
            "batch_size": batch_size,
            **kwargs,
        }

        total_items = len(item_contents)
        logger.info(
            f"Classifying batch of {total_items} items with model '{model_id}' (batch size: {batch_size})"
        )
        batch_results_list: List[ClassificationResult] = []

        try:
            # Use pipeline directly for batching
            results_iterator = pipeline_instance(
                item_contents,
                candidate_labels=categories,
                multi_label=multi_label,
                batch_size=batch_size,
                **kwargs,
            )

            # NOTE(review): a bare dict here would be a single result rather
            # than a sequence of results (confirm whether the installed
            # transformers version does this for one-item batches). Iterating
            # it would yield its keys, so wrap it.
            if isinstance(results_iterator, dict):
                results_iterator = [results_iterator]

            if progress_bar:
                tqdm_class = get_tqdm()
                results_iterator = tqdm_class(
                    results_iterator,
                    total=total_items,
                    desc=f"Classifying batch ({model_id})",
                    leave=False,  # Don't leave progress bar hanging
                )

            for raw_result in results_iterator:
                # Each raw result corresponds to ONE input item. A failure while
                # parsing it must not abort the rest of the batch.
                try:
                    scores_list = self._extract_scores(raw_result, min_confidence, model_id)
                except Exception as proc_err:
                    logger.error(
                        f"Error processing result item in batch: {proc_err}", exc_info=True
                    )
                    scores_list = []

                batch_results_list.append(
                    self._build_result(
                        model_id, effective_using, timestamp, parameters, scores_list
                    )
                )

            if len(batch_results_list) != total_items:
                logger.warning(
                    f"Batch classification returned {len(batch_results_list)} results, but expected {total_items}. Results might be incomplete or misaligned."
                )

            return batch_results_list

        except Exception as e:
            logger.error(f"Batch classification failed for model '{model_id}': {e}", exc_info=True)
            raise ClassificationError(
                f"Batch classification failed using model '{model_id}'. Error: {e}"
            ) from e
|
@@ -0,0 +1,163 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
3
|
+
|
4
|
+
# Assuming PIL is installed as it's needed for vision
|
5
|
+
try:
|
6
|
+
from PIL import Image
|
7
|
+
except ImportError:
|
8
|
+
Image = None # type: ignore
|
9
|
+
|
10
|
+
# Import result classes
|
11
|
+
from .results import ClassificationResult # Assuming results.py is in the same dir
|
12
|
+
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
# Avoid runtime import cycle
|
15
|
+
from natural_pdf.core.page import Page
|
16
|
+
from natural_pdf.elements.region import Region
|
17
|
+
|
18
|
+
from .manager import ClassificationManager
|
19
|
+
|
20
|
+
logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
|
23
|
+
class ClassificationMixin:
    """
    Mixin that adds a `classify()` method and result accessors to a host object.

    The host supplies the manager and the content to classify through the two
    hook methods below, and keeps results in an `analyses` dictionary.
    """

    # --- Hooks the host class must override --- #

    def _get_classification_manager(self) -> "ClassificationManager":
        """Hook: return the ClassificationManager to use."""
        raise NotImplementedError

    def _get_classification_content(self, model_type: str, **kwargs) -> Union[str, "Image"]:
        """Hook: return the text (str) or image to classify for the given mode."""
        raise NotImplementedError

    # The host is expected to expose `analyses` as a dict; classify() creates
    # one if it is missing or None.

    # --- End hooks --- #

    def classify(
        self,
        categories: List[str],
        model: Optional[str] = None,  # Default handled by manager
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        analysis_key: str = "classification",
        multi_label: bool = False,
        **kwargs,
    ) -> "ClassificationMixin":
        """
        Classify this object and store the result in `self.analyses[analysis_key]`.

        Any earlier value under `analysis_key` is overwritten. Failures other
        than NotImplementedError are logged and swallowed, in which case
        nothing is stored and `self` is still returned.

        Args:
            categories: A list of string category names.
            model: Model identifier; when falsy the manager's default for the
                resolved mode is used.
            using: Optional processing mode ('text' or 'vision'). If None, inferred by manager.
            min_confidence: Minimum confidence threshold passed to the manager.
            analysis_key: Key under which to store the result in `self.analyses`.
            multi_label: Whether to allow multiple labels (passed to the manager).
            **kwargs: Passed both to `_get_classification_content` and to the manager.

        Returns:
            Self for method chaining.

        Raises:
            NotImplementedError: If the host did not override a required hook.
        """
        if getattr(self, "analyses", None) is None:
            logger.warning("'analyses' attribute not found or is None. Initializing as empty dict.")
            self.analyses = {}

        try:
            manager = self._get_classification_manager()

            # With no model given, the mode is inferred from the default text model.
            probe_model = model or manager.DEFAULT_TEXT_MODEL
            inferred_using = manager.infer_using(probe_model, using)

            effective_model_id = model
            if effective_model_id is None:
                if inferred_using == "text":
                    effective_model_id = manager.DEFAULT_TEXT_MODEL
                else:
                    effective_model_id = manager.DEFAULT_VISION_MODEL
                logger.debug(
                    f"No model provided, using default for mode '{inferred_using}': '{effective_model_id}'"
                )

            # Content is fetched only after the mode is final.
            content = self._get_classification_content(model_type=inferred_using, **kwargs)

            result_obj: ClassificationResult = manager.classify_item(
                item_content=content,
                categories=categories,
                model_id=effective_model_id,
                using=inferred_using,
                min_confidence=min_confidence,
                multi_label=multi_label,
                **kwargs,
            )

            self.analyses[analysis_key] = result_obj
            logger.debug(f"Stored classification result under key '{analysis_key}': {result_obj}")

        except NotImplementedError as nie:
            # A missing hook is a programming error in the host: surface it.
            logger.error(f"Classification cannot proceed: {nie}")
            raise
        except Exception as e:
            # Best effort: log and fall through to `return self`.
            logger.error(f"Classification failed: {e}", exc_info=True)

        return self

    @property
    def classification_results(self) -> Optional[ClassificationResult]:
        """Whatever is stored under the default 'classification' key, or None."""
        store = getattr(self, "analyses", None)
        if store is None:
            return None
        return store.get("classification")

    @property
    def category(self) -> Optional[str]:
        """Top category of the default-key result, or None when there is no result."""
        default_result = self.classification_results
        if not default_result:
            return None
        return default_result.top_category

    @property
    def category_confidence(self) -> Optional[float]:
        """Top confidence of the default-key result, or None when there is no result."""
        default_result = self.classification_results
        if not default_result:
            return None
        return default_result.top_confidence

    def get_classification_result(
        self, analysis_key: str = "classification"
    ) -> Optional[ClassificationResult]:
        """Return the ClassificationResult stored under `analysis_key`, if any.

        Returns None when `analyses` is missing, the key is absent, or the
        stored value is not a ClassificationResult (the latter is logged).
        """
        store = getattr(self, "analyses", None)
        if store is None:
            return None

        found = store.get(analysis_key)
        if found is None or isinstance(found, ClassificationResult):
            return found

        logger.warning(
            f"Item found under key '{analysis_key}' is not a ClassificationResult (type: {type(found)}). Returning None."
        )
        return None
|
@@ -0,0 +1,80 @@
|
|
1
|
+
# natural_pdf/classification/results.py
|
2
|
+
import logging
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from datetime import datetime
|
5
|
+
from typing import Any, Dict, List, Optional
|
6
|
+
|
7
|
+
logger = logging.getLogger(__name__)
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass
class CategoryScore:
    """A single candidate category paired with the score it received."""

    category: str
    score: float

    def to_dict(self) -> Dict[str, Any]:
        """Return this entry as a plain dict with 'category' and 'score' keys."""
        serialized: Dict[str, Any] = {}
        serialized["category"] = self.category
        serialized["score"] = self.score
        return serialized
|
20
|
+
|
21
|
+
|
22
|
+
@dataclass
class ClassificationResult:
    """Outcome of one classification run: the winning category plus all scores."""

    category: str
    score: float
    scores: List[CategoryScore]
    model_id: str
    timestamp: datetime
    using: str  # 'text' or 'vision'
    parameters: Optional[Dict[str, Any]] = None

    # This hand-written __init__ is the one in effect: @dataclass does not
    # replace methods already defined in the class body.
    def __init__(
        self,
        category: str,
        score: float,
        scores: List[CategoryScore],
        model_id: str,
        using: str,
        parameters: Optional[Dict[str, Any]] = None,
        timestamp: Optional[datetime] = None,
    ):
        self.category, self.score, self.scores = category, score, scores
        self.model_id, self.using = model_id, using
        # Falsy parameters become a fresh dict; a missing timestamp becomes "now".
        self.parameters = parameters if parameters else {}
        self.timestamp = timestamp if timestamp else datetime.now()

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this result to a dictionary.

        Returns:
            Dictionary with the scalar fields, the scores converted through
            their own `to_dict()`, and the timestamp in ISO format.
        """
        serialized_scores = []
        for entry in self.scores:
            serialized_scores.append(entry.to_dict())

        payload: Dict[str, Any] = {
            "category": self.category,
            "score": self.score,
            "scores": serialized_scores,
            "model_id": self.model_id,
            "using": self.using,
            "parameters": self.parameters,
        }
        payload["timestamp"] = self.timestamp.isoformat()
        return payload

    @property
    def top_category(self) -> str:
        """Alias for `category`."""
        return self.category

    @property
    def top_confidence(self) -> float:
        """Alias for `score`."""
        return self.score

    def __repr__(self) -> str:
        return f"<ClassificationResult category='{self.category}' score={self.score:.3f} model='{self.model_id}'>"
|