natural-pdf 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/api/index.md +386 -0
- docs/assets/favicon.png +3 -0
- docs/assets/favicon.svg +3 -0
- docs/assets/javascripts/custom.js +17 -0
- docs/assets/logo.svg +3 -0
- docs/assets/sample-screen.png +0 -0
- docs/assets/social-preview.png +17 -0
- docs/assets/social-preview.svg +17 -0
- docs/assets/stylesheets/custom.css +65 -0
- docs/document-qa/index.ipynb +435 -0
- docs/document-qa/index.md +79 -0
- docs/element-selection/index.ipynb +915 -0
- docs/element-selection/index.md +229 -0
- docs/index.md +170 -0
- docs/installation/index.md +69 -0
- docs/interactive-widget/index.ipynb +962 -0
- docs/interactive-widget/index.md +12 -0
- docs/layout-analysis/index.ipynb +818 -0
- docs/layout-analysis/index.md +185 -0
- docs/ocr/index.md +222 -0
- docs/pdf-navigation/index.ipynb +314 -0
- docs/pdf-navigation/index.md +97 -0
- docs/regions/index.ipynb +816 -0
- docs/regions/index.md +294 -0
- docs/tables/index.ipynb +658 -0
- docs/tables/index.md +144 -0
- docs/text-analysis/index.ipynb +370 -0
- docs/text-analysis/index.md +105 -0
- docs/text-extraction/index.ipynb +1478 -0
- docs/text-extraction/index.md +292 -0
- docs/tutorials/01-loading-and-extraction.ipynb +1696 -0
- docs/tutorials/01-loading-and-extraction.md +95 -0
- docs/tutorials/02-finding-elements.ipynb +340 -0
- docs/tutorials/02-finding-elements.md +149 -0
- docs/tutorials/03-extracting-blocks.ipynb +147 -0
- docs/tutorials/03-extracting-blocks.md +48 -0
- docs/tutorials/04-table-extraction.ipynb +114 -0
- docs/tutorials/04-table-extraction.md +50 -0
- docs/tutorials/05-excluding-content.ipynb +270 -0
- docs/tutorials/05-excluding-content.md +109 -0
- docs/tutorials/06-document-qa.ipynb +332 -0
- docs/tutorials/06-document-qa.md +91 -0
- docs/tutorials/07-layout-analysis.ipynb +260 -0
- docs/tutorials/07-layout-analysis.md +66 -0
- docs/tutorials/07-working-with-regions.ipynb +409 -0
- docs/tutorials/07-working-with-regions.md +151 -0
- docs/tutorials/08-spatial-navigation.ipynb +508 -0
- docs/tutorials/08-spatial-navigation.md +190 -0
- docs/tutorials/09-section-extraction.ipynb +2434 -0
- docs/tutorials/09-section-extraction.md +256 -0
- docs/tutorials/10-form-field-extraction.ipynb +484 -0
- docs/tutorials/10-form-field-extraction.md +201 -0
- docs/tutorials/11-enhanced-table-processing.ipynb +54 -0
- docs/tutorials/11-enhanced-table-processing.md +9 -0
- docs/tutorials/12-ocr-integration.ipynb +586 -0
- docs/tutorials/12-ocr-integration.md +188 -0
- docs/tutorials/13-semantic-search.ipynb +1888 -0
- docs/tutorials/13-semantic-search.md +77 -0
- docs/visual-debugging/index.ipynb +2970 -0
- docs/visual-debugging/index.md +157 -0
- docs/visual-debugging/region.png +0 -0
- natural_pdf/__init__.py +39 -20
- natural_pdf/analyzers/__init__.py +2 -1
- natural_pdf/analyzers/layout/base.py +32 -24
- natural_pdf/analyzers/layout/docling.py +131 -72
- natural_pdf/analyzers/layout/layout_analyzer.py +156 -113
- natural_pdf/analyzers/layout/layout_manager.py +98 -58
- natural_pdf/analyzers/layout/layout_options.py +32 -17
- natural_pdf/analyzers/layout/paddle.py +152 -95
- natural_pdf/analyzers/layout/surya.py +164 -92
- natural_pdf/analyzers/layout/tatr.py +149 -84
- natural_pdf/analyzers/layout/yolo.py +84 -44
- natural_pdf/analyzers/text_options.py +22 -15
- natural_pdf/analyzers/text_structure.py +131 -85
- natural_pdf/analyzers/utils.py +30 -23
- natural_pdf/collections/pdf_collection.py +125 -97
- natural_pdf/core/__init__.py +1 -1
- natural_pdf/core/element_manager.py +416 -337
- natural_pdf/core/highlighting_service.py +268 -196
- natural_pdf/core/page.py +907 -513
- natural_pdf/core/pdf.py +385 -287
- natural_pdf/elements/__init__.py +1 -1
- natural_pdf/elements/base.py +302 -214
- natural_pdf/elements/collections.py +708 -508
- natural_pdf/elements/line.py +39 -36
- natural_pdf/elements/rect.py +32 -30
- natural_pdf/elements/region.py +854 -883
- natural_pdf/elements/text.py +122 -99
- natural_pdf/exporters/__init__.py +0 -1
- natural_pdf/exporters/searchable_pdf.py +261 -102
- natural_pdf/ocr/__init__.py +23 -14
- natural_pdf/ocr/engine.py +17 -8
- natural_pdf/ocr/engine_easyocr.py +63 -47
- natural_pdf/ocr/engine_paddle.py +97 -68
- natural_pdf/ocr/engine_surya.py +54 -44
- natural_pdf/ocr/ocr_manager.py +88 -62
- natural_pdf/ocr/ocr_options.py +16 -10
- natural_pdf/qa/__init__.py +1 -1
- natural_pdf/qa/document_qa.py +119 -111
- natural_pdf/search/__init__.py +37 -31
- natural_pdf/search/haystack_search_service.py +312 -189
- natural_pdf/search/haystack_utils.py +186 -122
- natural_pdf/search/search_options.py +25 -14
- natural_pdf/search/search_service_protocol.py +12 -6
- natural_pdf/search/searchable_mixin.py +261 -176
- natural_pdf/selectors/__init__.py +2 -1
- natural_pdf/selectors/parser.py +159 -316
- natural_pdf/templates/__init__.py +1 -1
- natural_pdf/utils/highlighting.py +8 -2
- natural_pdf/utils/reading_order.py +65 -63
- natural_pdf/utils/text_extraction.py +195 -0
- natural_pdf/utils/visualization.py +70 -61
- natural_pdf/widgets/__init__.py +2 -3
- natural_pdf/widgets/viewer.py +749 -718
- {natural_pdf-0.1.4.dist-info → natural_pdf-0.1.5.dist-info}/METADATA +15 -1
- natural_pdf-0.1.5.dist-info/RECORD +134 -0
- natural_pdf-0.1.5.dist-info/top_level.txt +5 -0
- notebooks/Examples.ipynb +1293 -0
- pdfs/.gitkeep +0 -0
- pdfs/01-practice.pdf +543 -0
- pdfs/0500000US42001.pdf +0 -0
- pdfs/0500000US42007.pdf +0 -0
- pdfs/2014 Statistics.pdf +0 -0
- pdfs/2019 Statistics.pdf +0 -0
- pdfs/Atlanta_Public_Schools_GA_sample.pdf +0 -0
- pdfs/needs-ocr.pdf +0 -0
- tests/test_loading.py +50 -0
- tests/test_optional_deps.py +298 -0
- natural_pdf-0.1.4.dist-info/RECORD +0 -61
- natural_pdf-0.1.4.dist-info/top_level.txt +0 -1
- {natural_pdf-0.1.4.dist-info → natural_pdf-0.1.5.dist-info}/WHEEL +0 -0
- {natural_pdf-0.1.4.dist-info → natural_pdf-0.1.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,14 +1,15 @@
|
|
1
1
|
# layout_detector_tatr.py
|
2
|
-
import logging
|
3
2
|
import importlib.util
|
3
|
+
import logging
|
4
4
|
import os
|
5
5
|
import tempfile
|
6
|
-
from typing import
|
6
|
+
from typing import Any, Dict, List, Optional, Tuple
|
7
|
+
|
7
8
|
from PIL import Image
|
8
9
|
|
9
10
|
# Assuming base class and options are importable
|
10
11
|
from .base import LayoutDetector
|
11
|
-
from .layout_options import
|
12
|
+
from .layout_options import BaseLayoutOptions, TATRLayoutOptions
|
12
13
|
|
13
14
|
logger = logging.getLogger(__name__)
|
14
15
|
|
@@ -26,9 +27,13 @@ if torch_spec and torchvision_spec and transformers_spec:
|
|
26
27
|
from torchvision import transforms
|
27
28
|
from transformers import AutoModelForObjectDetection
|
28
29
|
except ImportError as e:
|
29
|
-
logger.warning(
|
30
|
+
logger.warning(
|
31
|
+
f"Could not import TATR dependencies (torch, torchvision, transformers): {e}"
|
32
|
+
)
|
30
33
|
else:
|
31
|
-
logger.warning(
|
34
|
+
logger.warning(
|
35
|
+
"torch, torchvision, or transformers not found. TableTransformerDetector will not be available."
|
36
|
+
)
|
32
37
|
|
33
38
|
|
34
39
|
class TableTransformerDetector(LayoutDetector):
|
@@ -36,26 +41,36 @@ class TableTransformerDetector(LayoutDetector):
|
|
36
41
|
|
37
42
|
# Custom resize transform (keep as nested class or move outside)
|
38
43
|
class MaxResize(object):
|
39
|
-
def __init__(self, max_size=
|
44
|
+
def __init__(self, max_size=2000):
|
40
45
|
self.max_size = max_size
|
46
|
+
|
41
47
|
def __call__(self, image):
|
42
48
|
width, height = image.size
|
43
49
|
current_max_size = max(width, height)
|
44
50
|
scale = self.max_size / current_max_size
|
45
51
|
# Use LANCZOS for resizing
|
46
|
-
resized_image = image.resize(
|
52
|
+
resized_image = image.resize(
|
53
|
+
(int(round(scale * width)), int(round(scale * height))), Image.Resampling.LANCZOS
|
54
|
+
)
|
47
55
|
return resized_image
|
48
56
|
|
49
57
|
def __init__(self):
|
50
58
|
super().__init__()
|
51
59
|
self.supported_classes = {
|
52
|
-
|
60
|
+
"table",
|
61
|
+
"table row",
|
62
|
+
"table column",
|
63
|
+
"table column header",
|
64
|
+
"table projected row header",
|
65
|
+
"table spanning cell", # Add others if supported by models used
|
53
66
|
}
|
54
67
|
# Models are loaded via _get_model
|
55
68
|
|
56
69
|
def is_available(self) -> bool:
|
57
70
|
"""Check if dependencies are installed."""
|
58
|
-
return
|
71
|
+
return (
|
72
|
+
torch is not None and transforms is not None and AutoModelForObjectDetection is not None
|
73
|
+
)
|
59
74
|
|
60
75
|
def _get_cache_key(self, options: TATRLayoutOptions) -> str:
|
61
76
|
"""Generate cache key based on model IDs and device."""
|
@@ -63,26 +78,30 @@ class TableTransformerDetector(LayoutDetector):
|
|
63
78
|
options = TATRLayoutOptions(device=options.device)
|
64
79
|
|
65
80
|
device_key = str(options.device).lower()
|
66
|
-
det_model_key = options.detection_model.replace(
|
67
|
-
struct_model_key = options.structure_model.replace(
|
81
|
+
det_model_key = options.detection_model.replace("/", "_")
|
82
|
+
struct_model_key = options.structure_model.replace("/", "_")
|
68
83
|
return f"{self.__class__.__name__}_{device_key}_{det_model_key}_{struct_model_key}"
|
69
84
|
|
70
85
|
def _load_model_from_options(self, options: TATRLayoutOptions) -> Dict[str, Any]:
|
71
86
|
"""Load the TATR detection and structure models."""
|
72
87
|
if not self.is_available():
|
73
|
-
|
88
|
+
raise RuntimeError(
|
89
|
+
"TATR dependencies (torch, torchvision, transformers) not installed."
|
90
|
+
)
|
74
91
|
|
75
92
|
device = options.device or ("cuda" if torch.cuda.is_available() else "cpu")
|
76
|
-
self.logger.info(
|
93
|
+
self.logger.info(
|
94
|
+
f"Loading TATR models: Detection='{options.detection_model}', Structure='{options.structure_model}' onto device='{device}'"
|
95
|
+
)
|
77
96
|
try:
|
78
97
|
detection_model = AutoModelForObjectDetection.from_pretrained(
|
79
|
-
options.detection_model, revision="no_timm"
|
98
|
+
options.detection_model, revision="no_timm" # Important revision for some versions
|
80
99
|
).to(device)
|
81
100
|
structure_model = AutoModelForObjectDetection.from_pretrained(
|
82
101
|
options.structure_model
|
83
102
|
).to(device)
|
84
103
|
self.logger.info("TATR models loaded.")
|
85
|
-
return {
|
104
|
+
return {"detection": detection_model, "structure": structure_model}
|
86
105
|
except Exception as e:
|
87
106
|
self.logger.error(f"Failed to load TATR models: {e}", exc_info=True)
|
88
107
|
raise
|
@@ -97,19 +116,21 @@ class TableTransformerDetector(LayoutDetector):
|
|
97
116
|
def rescale_bboxes(self, out_bbox, size):
|
98
117
|
img_w, img_h = size
|
99
118
|
boxes = self.box_cxcywh_to_xyxy(out_bbox)
|
100
|
-
boxes = boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(
|
119
|
+
boxes = boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(
|
120
|
+
out_bbox.device
|
121
|
+
) # Ensure tensor on correct device
|
101
122
|
return boxes
|
102
123
|
|
103
124
|
def outputs_to_objects(self, outputs, img_size, id2label):
|
104
125
|
logits = outputs.logits
|
105
126
|
bboxes = outputs.pred_boxes
|
106
127
|
# Use softmax activation function
|
107
|
-
prob = logits.softmax(-1)[0, :, :-1]
|
128
|
+
prob = logits.softmax(-1)[0, :, :-1] # Exclude the "no object" class
|
108
129
|
scores, labels = prob.max(-1)
|
109
130
|
|
110
131
|
# Convert to absolute coordinates
|
111
132
|
img_w, img_h = img_size
|
112
|
-
boxes = self.rescale_bboxes(bboxes[0, ...], (img_w, img_h))
|
133
|
+
boxes = self.rescale_bboxes(bboxes[0, ...], (img_w, img_h)) # Pass tuple size
|
113
134
|
|
114
135
|
# Move results to CPU for list comprehension
|
115
136
|
scores = scores.cpu().tolist()
|
@@ -118,49 +139,62 @@ class TableTransformerDetector(LayoutDetector):
|
|
118
139
|
|
119
140
|
objects = []
|
120
141
|
for score, label_idx, bbox in zip(scores, labels, boxes):
|
121
|
-
class_label = id2label.get(label_idx,
|
122
|
-
if class_label !=
|
123
|
-
objects.append(
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
142
|
+
class_label = id2label.get(label_idx, "unknown") # Use get with default
|
143
|
+
if class_label != "no object" and class_label != "unknown":
|
144
|
+
objects.append(
|
145
|
+
{
|
146
|
+
"label": class_label,
|
147
|
+
"score": float(score),
|
148
|
+
"bbox": [round(float(c), 2) for c in bbox], # Round coordinates
|
149
|
+
}
|
150
|
+
)
|
128
151
|
return objects
|
152
|
+
|
129
153
|
# --- End Helper Methods ---
|
130
154
|
|
131
155
|
def detect(self, image: Image.Image, options: BaseLayoutOptions) -> List[Dict[str, Any]]:
|
132
156
|
"""Detect tables and their structure in an image."""
|
133
157
|
if not self.is_available():
|
134
|
-
raise RuntimeError(
|
158
|
+
raise RuntimeError(
|
159
|
+
"TATR dependencies (torch, torchvision, transformers) not installed."
|
160
|
+
)
|
135
161
|
|
136
162
|
if not isinstance(options, TATRLayoutOptions):
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
163
|
+
self.logger.warning(
|
164
|
+
"Received BaseLayoutOptions, expected TATRLayoutOptions. Using defaults."
|
165
|
+
)
|
166
|
+
options = TATRLayoutOptions(
|
167
|
+
confidence=options.confidence,
|
168
|
+
classes=options.classes,
|
169
|
+
exclude_classes=options.exclude_classes,
|
170
|
+
device=options.device,
|
171
|
+
extra_args=options.extra_args,
|
172
|
+
)
|
143
173
|
|
144
174
|
self.validate_classes(options.classes or [])
|
145
175
|
if options.exclude_classes:
|
146
176
|
self.validate_classes(options.exclude_classes)
|
147
177
|
|
148
178
|
models = self._get_model(options)
|
149
|
-
detection_model = models[
|
150
|
-
structure_model = models[
|
179
|
+
detection_model = models["detection"]
|
180
|
+
structure_model = models["structure"]
|
151
181
|
device = options.device or ("cuda" if torch.cuda.is_available() else "cpu")
|
152
182
|
|
153
183
|
# Prepare transforms based on options
|
154
|
-
detection_transform = transforms.Compose(
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
184
|
+
detection_transform = transforms.Compose(
|
185
|
+
[
|
186
|
+
self.MaxResize(options.max_detection_size),
|
187
|
+
transforms.ToTensor(),
|
188
|
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
189
|
+
]
|
190
|
+
)
|
191
|
+
structure_transform = transforms.Compose(
|
192
|
+
[
|
193
|
+
self.MaxResize(options.max_structure_size),
|
194
|
+
transforms.ToTensor(),
|
195
|
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
196
|
+
]
|
197
|
+
)
|
164
198
|
|
165
199
|
# --- Detect Tables ---
|
166
200
|
self.logger.debug("Running TATR table detection...")
|
@@ -169,38 +203,60 @@ class TableTransformerDetector(LayoutDetector):
|
|
169
203
|
outputs = detection_model(pixel_values)
|
170
204
|
|
171
205
|
id2label_det = detection_model.config.id2label
|
172
|
-
id2label_det[detection_model.config.num_labels] = "no object"
|
206
|
+
id2label_det[detection_model.config.num_labels] = "no object" # Add no object class
|
173
207
|
tables = self.outputs_to_objects(outputs, image.size, id2label_det)
|
174
|
-
tables = [
|
208
|
+
tables = [
|
209
|
+
t for t in tables if t["score"] >= options.confidence and t["label"] == "table"
|
210
|
+
] # Filter for tables
|
175
211
|
self.logger.debug(f"Detected {len(tables)} table regions.")
|
176
212
|
|
177
213
|
all_detections = []
|
178
214
|
|
179
215
|
# Add table detections if requested
|
180
|
-
normalized_classes_req =
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
216
|
+
normalized_classes_req = (
|
217
|
+
{self._normalize_class_name(c) for c in options.classes} if options.classes else None
|
218
|
+
)
|
219
|
+
normalized_classes_excl = (
|
220
|
+
{self._normalize_class_name(c) for c in options.exclude_classes}
|
221
|
+
if options.exclude_classes
|
222
|
+
else set()
|
223
|
+
)
|
224
|
+
|
225
|
+
if normalized_classes_req is None or "table" in normalized_classes_req:
|
226
|
+
if "table" not in normalized_classes_excl:
|
227
|
+
for table in tables:
|
228
|
+
all_detections.append(
|
229
|
+
{
|
230
|
+
"bbox": tuple(table["bbox"]),
|
231
|
+
"class": "table",
|
232
|
+
"confidence": float(table["score"]),
|
233
|
+
"normalized_class": "table",
|
234
|
+
"source": "layout",
|
235
|
+
"model": "tatr",
|
236
|
+
}
|
237
|
+
)
|
194
238
|
|
195
239
|
# --- Process Structure ---
|
196
|
-
structure_class_names = {
|
197
|
-
|
240
|
+
structure_class_names = {
|
241
|
+
"table row",
|
242
|
+
"table column",
|
243
|
+
"table column header",
|
244
|
+
"table projected row header",
|
245
|
+
"table spanning cell",
|
246
|
+
}
|
247
|
+
normalized_structure_classes = {
|
248
|
+
self._normalize_class_name(c) for c in structure_class_names
|
249
|
+
}
|
198
250
|
|
199
251
|
needed_structure = False
|
200
|
-
if normalized_classes_req is None:
|
201
|
-
|
202
|
-
|
203
|
-
|
252
|
+
if normalized_classes_req is None: # If no specific classes requested
|
253
|
+
needed_structure = any(
|
254
|
+
norm_cls not in normalized_classes_excl for norm_cls in normalized_structure_classes
|
255
|
+
)
|
256
|
+
else: # Specific classes requested
|
257
|
+
needed_structure = any(
|
258
|
+
norm_cls in normalized_classes_req for norm_cls in normalized_structure_classes
|
259
|
+
)
|
204
260
|
|
205
261
|
if needed_structure and tables:
|
206
262
|
self.logger.debug("Running TATR structure recognition...")
|
@@ -208,44 +264,53 @@ class TableTransformerDetector(LayoutDetector):
|
|
208
264
|
id2label_struct[structure_model.config.num_labels] = "no object"
|
209
265
|
|
210
266
|
for table in tables:
|
211
|
-
x_min, y_min, x_max, y_max = map(int, table[
|
267
|
+
x_min, y_min, x_max, y_max = map(int, table["bbox"])
|
212
268
|
# Ensure coordinates are within image bounds
|
213
269
|
x_min, y_min = max(0, x_min), max(0, y_min)
|
214
270
|
x_max, y_max = min(image.width, x_max), min(image.height, y_max)
|
215
|
-
if x_max <= x_min or y_max <= y_min:
|
271
|
+
if x_max <= x_min or y_max <= y_min:
|
272
|
+
continue # Skip invalid crop
|
216
273
|
|
217
274
|
cropped_table = image.crop((x_min, y_min, x_max, y_max))
|
218
|
-
if cropped_table.width == 0 or cropped_table.height == 0:
|
275
|
+
if cropped_table.width == 0 or cropped_table.height == 0:
|
276
|
+
continue # Skip empty crop
|
219
277
|
|
220
278
|
pixel_values_struct = structure_transform(cropped_table).unsqueeze(0).to(device)
|
221
279
|
with torch.no_grad():
|
222
280
|
outputs_struct = structure_model(pixel_values_struct)
|
223
281
|
|
224
|
-
structure_elements = self.outputs_to_objects(
|
225
|
-
|
282
|
+
structure_elements = self.outputs_to_objects(
|
283
|
+
outputs_struct, cropped_table.size, id2label_struct
|
284
|
+
)
|
285
|
+
structure_elements = [
|
286
|
+
e for e in structure_elements if e["score"] >= options.confidence
|
287
|
+
]
|
226
288
|
|
227
289
|
for element in structure_elements:
|
228
|
-
element_class_orig = element[
|
290
|
+
element_class_orig = element["label"]
|
229
291
|
normalized_class = self._normalize_class_name(element_class_orig)
|
230
292
|
|
231
293
|
# Apply class filtering
|
232
|
-
if normalized_classes_req and normalized_class not in normalized_classes_req:
|
233
|
-
|
294
|
+
if normalized_classes_req and normalized_class not in normalized_classes_req:
|
295
|
+
continue
|
296
|
+
if normalized_class in normalized_classes_excl:
|
297
|
+
continue
|
234
298
|
|
235
299
|
# Adjust coordinates
|
236
|
-
ex0, ey0, ex1, ey1 = element[
|
300
|
+
ex0, ey0, ex1, ey1 = element["bbox"]
|
237
301
|
adj_bbox = (ex0 + x_min, ey0 + y_min, ex1 + x_min, ey1 + y_min)
|
238
302
|
|
239
|
-
all_detections.append(
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
303
|
+
all_detections.append(
|
304
|
+
{
|
305
|
+
"bbox": adj_bbox,
|
306
|
+
"class": element_class_orig,
|
307
|
+
"confidence": float(element["score"]),
|
308
|
+
"normalized_class": normalized_class,
|
309
|
+
"source": "layout",
|
310
|
+
"model": "tatr",
|
311
|
+
}
|
312
|
+
)
|
247
313
|
self.logger.debug(f"Added {len(all_detections) - len(tables)} structure elements.")
|
248
314
|
|
249
315
|
self.logger.info(f"TATR detected {len(all_detections)} layout elements matching criteria.")
|
250
316
|
return all_detections
|
251
|
-
|
@@ -1,24 +1,38 @@
|
|
1
1
|
# layout_detector_yolo.py
|
2
|
-
import logging
|
3
2
|
import importlib.util
|
3
|
+
import logging
|
4
4
|
import os
|
5
5
|
import tempfile
|
6
|
-
from typing import
|
6
|
+
from typing import Any, Dict, List, Optional
|
7
|
+
|
7
8
|
from PIL import Image
|
8
9
|
|
9
10
|
# Assuming base class and options are importable
|
10
11
|
try:
|
11
12
|
from .base import LayoutDetector
|
12
|
-
from .layout_options import
|
13
|
+
from .layout_options import BaseLayoutOptions, YOLOLayoutOptions
|
13
14
|
except ImportError:
|
14
15
|
# Placeholders if run standalone or imports fail
|
15
|
-
class BaseLayoutOptions:
|
16
|
-
|
16
|
+
class BaseLayoutOptions:
|
17
|
+
pass
|
18
|
+
|
19
|
+
class YOLOLayoutOptions(BaseLayoutOptions):
|
20
|
+
pass
|
21
|
+
|
17
22
|
class LayoutDetector:
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
23
|
+
def __init__(self):
|
24
|
+
self.logger = logging.getLogger()
|
25
|
+
self.supported_classes = set()
|
26
|
+
|
27
|
+
def _get_model(self, options):
|
28
|
+
raise NotImplementedError
|
29
|
+
|
30
|
+
def _normalize_class_name(self, n):
|
31
|
+
return n
|
32
|
+
|
33
|
+
def validate_classes(self, c):
|
34
|
+
pass
|
35
|
+
|
22
36
|
logging.basicConfig()
|
23
37
|
|
24
38
|
logger = logging.getLogger(__name__)
|
@@ -36,7 +50,9 @@ if yolo_spec and hf_spec:
|
|
36
50
|
except ImportError as e:
|
37
51
|
logger.warning(f"Could not import YOLO dependencies: {e}")
|
38
52
|
else:
|
39
|
-
logger.warning(
|
53
|
+
logger.warning(
|
54
|
+
"doclayout_yolo or huggingface_hub not found. YOLODocLayoutDetector will not be available."
|
55
|
+
)
|
40
56
|
|
41
57
|
|
42
58
|
class YOLODocLayoutDetector(LayoutDetector):
|
@@ -45,9 +61,16 @@ class YOLODocLayoutDetector(LayoutDetector):
|
|
45
61
|
def __init__(self):
|
46
62
|
super().__init__()
|
47
63
|
self.supported_classes = {
|
48
|
-
|
49
|
-
|
50
|
-
|
64
|
+
"title",
|
65
|
+
"plain text",
|
66
|
+
"abandon",
|
67
|
+
"figure",
|
68
|
+
"figure_caption",
|
69
|
+
"table",
|
70
|
+
"table_caption",
|
71
|
+
"table_footnote",
|
72
|
+
"isolate_formula",
|
73
|
+
"formula_caption",
|
51
74
|
}
|
52
75
|
|
53
76
|
def is_available(self) -> bool:
|
@@ -58,8 +81,8 @@ class YOLODocLayoutDetector(LayoutDetector):
|
|
58
81
|
"""Generate cache key based on model repo/file and device."""
|
59
82
|
# Ensure options is the correct type
|
60
83
|
if not isinstance(options, YOLOLayoutOptions):
|
61
|
-
|
62
|
-
|
84
|
+
# This shouldn't happen if called correctly, but handle defensively
|
85
|
+
options = YOLOLayoutOptions(device=options.device) # Use base device
|
63
86
|
|
64
87
|
device_key = str(options.device).lower()
|
65
88
|
model_key = f"{options.model_repo.replace('/','_')}_{options.model_file}"
|
@@ -68,7 +91,7 @@ class YOLODocLayoutDetector(LayoutDetector):
|
|
68
91
|
def _load_model_from_options(self, options: YOLOLayoutOptions) -> Any:
|
69
92
|
"""Load the YOLOv10 model based on options."""
|
70
93
|
if not self.is_available():
|
71
|
-
|
94
|
+
raise RuntimeError("YOLO dependencies (doclayout_yolo, huggingface_hub) not installed.")
|
72
95
|
self.logger.info(f"Loading YOLO model: {options.model_repo}/{options.model_file}")
|
73
96
|
try:
|
74
97
|
model_path = hf_hub_download(repo_id=options.model_repo, filename=options.model_file)
|
@@ -86,12 +109,16 @@ class YOLODocLayoutDetector(LayoutDetector):
|
|
86
109
|
|
87
110
|
# Ensure options are the correct type, falling back to defaults if base type passed
|
88
111
|
if not isinstance(options, YOLOLayoutOptions):
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
112
|
+
self.logger.warning(
|
113
|
+
"Received BaseLayoutOptions, expected YOLOLayoutOptions. Using defaults."
|
114
|
+
)
|
115
|
+
options = YOLOLayoutOptions(
|
116
|
+
confidence=options.confidence,
|
117
|
+
classes=options.classes,
|
118
|
+
exclude_classes=options.exclude_classes,
|
119
|
+
device=options.device,
|
120
|
+
extra_args=options.extra_args,
|
121
|
+
)
|
95
122
|
|
96
123
|
# Validate classes before proceeding
|
97
124
|
self.validate_classes(options.classes or [])
|
@@ -108,58 +135,71 @@ class YOLODocLayoutDetector(LayoutDetector):
|
|
108
135
|
temp_image_path = os.path.join(temp_dir, "temp_layout_image.png")
|
109
136
|
try:
|
110
137
|
self.logger.debug(f"Saving temporary image for YOLO detector to: {temp_image_path}")
|
111
|
-
image.convert("RGB").save(temp_image_path)
|
138
|
+
image.convert("RGB").save(temp_image_path) # Ensure RGB
|
112
139
|
|
113
140
|
# Run model prediction
|
114
|
-
self.logger.debug(
|
141
|
+
self.logger.debug(
|
142
|
+
f"Running YOLO prediction (imgsz={options.image_size}, conf={options.confidence}, device={options.device})..."
|
143
|
+
)
|
115
144
|
results = model.predict(
|
116
145
|
temp_image_path,
|
117
146
|
imgsz=options.image_size,
|
118
147
|
conf=options.confidence,
|
119
|
-
device=options.device or
|
148
|
+
device=options.device or "cpu", # Default to cpu if None
|
120
149
|
# Add other predict args from options.extra_args if needed
|
121
150
|
# **options.extra_args
|
122
151
|
)
|
123
152
|
self.logger.debug(f"YOLO prediction returned {len(results)} result objects.")
|
124
153
|
|
125
154
|
# Process results into standardized format
|
126
|
-
img_width, img_height = image.size
|
155
|
+
img_width, img_height = image.size # Get original image size for context if needed
|
127
156
|
for result in results:
|
128
|
-
if result.boxes is None:
|
157
|
+
if result.boxes is None:
|
158
|
+
continue
|
129
159
|
boxes = result.boxes.xyxy
|
130
160
|
labels = result.boxes.cls
|
131
161
|
scores = result.boxes.conf
|
132
|
-
class_names = result.names
|
162
|
+
class_names = result.names # Dictionary mapping index to name
|
133
163
|
|
134
164
|
for box, label_idx_tensor, score_tensor in zip(boxes, labels, scores):
|
135
165
|
x_min, y_min, x_max, y_max = map(float, box.tolist())
|
136
|
-
label_idx = int(label_idx_tensor.item())
|
137
|
-
score = float(score_tensor.item())
|
166
|
+
label_idx = int(label_idx_tensor.item()) # Get int index
|
167
|
+
score = float(score_tensor.item()) # Get float score
|
138
168
|
|
139
169
|
if label_idx not in class_names:
|
140
|
-
|
141
|
-
|
170
|
+
self.logger.warning(
|
171
|
+
f"Label index {label_idx} not found in model names dict. Skipping."
|
172
|
+
)
|
173
|
+
continue
|
142
174
|
label_name = class_names[label_idx]
|
143
175
|
normalized_class = self._normalize_class_name(label_name)
|
144
176
|
|
145
177
|
# Apply class filtering (using normalized names)
|
146
|
-
if options.classes and normalized_class not in [
|
178
|
+
if options.classes and normalized_class not in [
|
179
|
+
self._normalize_class_name(c) for c in options.classes
|
180
|
+
]:
|
147
181
|
continue
|
148
|
-
if options.exclude_classes and normalized_class in [
|
182
|
+
if options.exclude_classes and normalized_class in [
|
183
|
+
self._normalize_class_name(c) for c in options.exclude_classes
|
184
|
+
]:
|
149
185
|
continue
|
150
186
|
|
151
|
-
detections.append(
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
187
|
+
detections.append(
|
188
|
+
{
|
189
|
+
"bbox": (x_min, y_min, x_max, y_max),
|
190
|
+
"class": label_name,
|
191
|
+
"confidence": score,
|
192
|
+
"normalized_class": normalized_class,
|
193
|
+
"source": "layout",
|
194
|
+
"model": "yolo",
|
195
|
+
}
|
196
|
+
)
|
197
|
+
self.logger.info(
|
198
|
+
f"YOLO detected {len(detections)} layout elements matching criteria."
|
199
|
+
)
|
160
200
|
|
161
201
|
except Exception as e:
|
162
202
|
self.logger.error(f"Error during YOLO detection: {e}", exc_info=True)
|
163
|
-
raise
|
203
|
+
raise # Re-raise the exception
|
164
204
|
|
165
205
|
return detections
|