natural-pdf 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- natural_pdf/__init__.py +55 -0
- natural_pdf/analyzers/__init__.py +6 -0
- natural_pdf/analyzers/layout/__init__.py +1 -0
- natural_pdf/analyzers/layout/base.py +151 -0
- natural_pdf/analyzers/layout/docling.py +247 -0
- natural_pdf/analyzers/layout/layout_analyzer.py +166 -0
- natural_pdf/analyzers/layout/layout_manager.py +200 -0
- natural_pdf/analyzers/layout/layout_options.py +78 -0
- natural_pdf/analyzers/layout/paddle.py +240 -0
- natural_pdf/analyzers/layout/surya.py +151 -0
- natural_pdf/analyzers/layout/tatr.py +251 -0
- natural_pdf/analyzers/layout/yolo.py +165 -0
- natural_pdf/analyzers/text_options.py +60 -0
- natural_pdf/analyzers/text_structure.py +270 -0
- natural_pdf/analyzers/utils.py +57 -0
- natural_pdf/core/__init__.py +3 -0
- natural_pdf/core/element_manager.py +457 -0
- natural_pdf/core/highlighting_service.py +698 -0
- natural_pdf/core/page.py +1444 -0
- natural_pdf/core/pdf.py +653 -0
- natural_pdf/elements/__init__.py +3 -0
- natural_pdf/elements/base.py +761 -0
- natural_pdf/elements/collections.py +1345 -0
- natural_pdf/elements/line.py +140 -0
- natural_pdf/elements/rect.py +122 -0
- natural_pdf/elements/region.py +1793 -0
- natural_pdf/elements/text.py +304 -0
- natural_pdf/ocr/__init__.py +56 -0
- natural_pdf/ocr/engine.py +104 -0
- natural_pdf/ocr/engine_easyocr.py +179 -0
- natural_pdf/ocr/engine_paddle.py +204 -0
- natural_pdf/ocr/engine_surya.py +171 -0
- natural_pdf/ocr/ocr_manager.py +191 -0
- natural_pdf/ocr/ocr_options.py +114 -0
- natural_pdf/qa/__init__.py +3 -0
- natural_pdf/qa/document_qa.py +396 -0
- natural_pdf/selectors/__init__.py +4 -0
- natural_pdf/selectors/parser.py +354 -0
- natural_pdf/templates/__init__.py +1 -0
- natural_pdf/templates/ocr_debug.html +517 -0
- natural_pdf/utils/__init__.py +3 -0
- natural_pdf/utils/highlighting.py +12 -0
- natural_pdf/utils/reading_order.py +227 -0
- natural_pdf/utils/visualization.py +223 -0
- natural_pdf/widgets/__init__.py +4 -0
- natural_pdf/widgets/frontend/viewer.js +88 -0
- natural_pdf/widgets/viewer.py +765 -0
- natural_pdf-0.1.0.dist-info/METADATA +295 -0
- natural_pdf-0.1.0.dist-info/RECORD +52 -0
- natural_pdf-0.1.0.dist-info/WHEEL +5 -0
- natural_pdf-0.1.0.dist-info/licenses/LICENSE +21 -0
- natural_pdf-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,251 @@
|
|
1
|
+
# layout_detector_tatr.py
|
2
|
+
import logging
|
3
|
+
import importlib.util
|
4
|
+
import os
|
5
|
+
import tempfile
|
6
|
+
from typing import List, Dict, Any, Optional, Tuple
|
7
|
+
from PIL import Image
|
8
|
+
|
9
|
+
# Assuming base class and options are importable
|
10
|
+
from .base import LayoutDetector
|
11
|
+
from .layout_options import TATRLayoutOptions, BaseLayoutOptions
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)

# Look the optional packages up with find_spec before importing anything, so
# a missing package only leaves the module-level names below as None (which
# is what TableTransformerDetector.is_available() checks) instead of breaking
# the import of this module.
torch_spec = importlib.util.find_spec("torch")
torchvision_spec = importlib.util.find_spec("torchvision")
transformers_spec = importlib.util.find_spec("transformers")
torch = None
transforms = None
AutoModelForObjectDetection = None

if not (torch_spec and torchvision_spec and transformers_spec):
    logger.warning("torch, torchvision, or transformers not found. TableTransformerDetector will not be available.")
else:
    try:
        import torch
        from torchvision import transforms
        from transformers import AutoModelForObjectDetection
    except ImportError as e:
        # Packages were found but failed to import; leave the
        # not-yet-bound names as None and keep going.
        logger.warning(f"Could not import TATR dependencies (torch, torchvision, transformers): {e}")
|
32
|
+
|
33
|
+
|
34
|
+
class TableTransformerDetector(LayoutDetector):
    """Table structure detector using Microsoft's Table Transformer (TATR) models.

    Two object-detection models are used: ``detection_model`` finds whole
    tables on the page image, and ``structure_model`` is run on each cropped
    table to find rows, columns, headers and spanning cells. ``detect``
    returns a list of dicts with the keys ``bbox`` (x0, y0, x1, y1 in the
    coordinate space of the input image), ``class``, ``confidence``,
    ``normalized_class``, ``source`` and ``model``.
    """

    # Labels reported by the structure model; shared by ``supported_classes``
    # and the structure pass in ``detect`` so the two lists cannot drift apart.
    _STRUCTURE_CLASSES = frozenset({
        'table row', 'table column', 'table column header',
        'table projected row header', 'table spanning cell',
    })

    # Standard ImageNet mean/std, applied to the input of both models.
    _NORMALIZE_MEAN = [0.485, 0.456, 0.406]
    _NORMALIZE_STD = [0.229, 0.224, 0.225]

    # Custom resize transform (keep as nested class or move outside)
    class MaxResize:
        """Resize a PIL image so its longer side equals ``max_size``, keeping the aspect ratio."""

        def __init__(self, max_size=800):
            self.max_size = max_size

        def __call__(self, image):
            width, height = image.size
            current_max_size = max(width, height)
            if current_max_size == 0:
                # Degenerate image: nothing to scale, and dividing would raise.
                return image
            scale = self.max_size / current_max_size
            # Clamp to 1 px so very thin images never request a zero-sized resize.
            new_size = (max(1, int(round(scale * width))), max(1, int(round(scale * height))))
            # Use LANCZOS for resizing
            return image.resize(new_size, Image.Resampling.LANCZOS)

    def __init__(self):
        super().__init__()
        # Add others if supported by models used
        self.supported_classes = {'table', *self._STRUCTURE_CLASSES}
        # Models are loaded via _get_model

    def is_available(self) -> bool:
        """Return True when torch, torchvision and transformers were all imported."""
        return torch is not None and transforms is not None and AutoModelForObjectDetection is not None

    def _get_cache_key(self, options: TATRLayoutOptions) -> str:
        """Generate cache key based on model IDs and device."""
        if not isinstance(options, TATRLayoutOptions):
            options = TATRLayoutOptions(device=options.device)

        device_key = str(options.device).lower()
        det_model_key = options.detection_model.replace('/', '_')
        struct_model_key = options.structure_model.replace('/', '_')
        return f"{self.__class__.__name__}_{device_key}_{det_model_key}_{struct_model_key}"

    @staticmethod
    def _resolve_device(options):
        """Return ``options.device`` or, when unset, 'cuda' if available else 'cpu'."""
        return options.device or ("cuda" if torch.cuda.is_available() else "cpu")

    def _load_model_from_options(self, options: TATRLayoutOptions) -> Dict[str, Any]:
        """Load the TATR detection and structure models.

        Returns:
            ``{'detection': model, 'structure': model}``, both moved to the
            resolved device.

        Raises:
            RuntimeError: If the TATR dependencies are not installed.
            Exception: Any loading error is logged and re-raised.
        """
        if not self.is_available():
            raise RuntimeError("TATR dependencies (torch, torchvision, transformers) not installed.")

        device = self._resolve_device(options)
        self.logger.info(f"Loading TATR models: Detection='{options.detection_model}', Structure='{options.structure_model}' onto device='{device}'")
        try:
            detection_model = AutoModelForObjectDetection.from_pretrained(
                options.detection_model, revision="no_timm"  # Important revision for some versions
            ).to(device)
            structure_model = AutoModelForObjectDetection.from_pretrained(
                options.structure_model
            ).to(device)
            self.logger.info("TATR models loaded.")
            return {'detection': detection_model, 'structure': structure_model}
        except Exception as e:
            self.logger.error(f"Failed to load TATR models: {e}", exc_info=True)
            raise

    # --- Helper methods (box_cxcywh_to_xyxy, rescale_bboxes, outputs_to_objects) ---
    def box_cxcywh_to_xyxy(self, x):
        """Convert (cx, cy, w, h) boxes along the last dim to (x0, y0, x1, y1)."""
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=1)

    def rescale_bboxes(self, out_bbox, size):
        """Scale model boxes to absolute xyxy coordinates for an image of ``size`` (w, h).

        Assumes the model emits relative (0-1) cx/cy/w/h boxes, which is what
        multiplying by the image width/height implies.
        """
        img_w, img_h = size
        boxes = self.box_cxcywh_to_xyxy(out_bbox)
        scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32, device=out_bbox.device)
        return boxes * scale

    def outputs_to_objects(self, outputs, img_size, id2label):
        """Turn raw model outputs for batch item 0 into label/score/bbox dicts.

        The last logit column (the "no object" class) is dropped before
        taking the per-query argmax. Queries whose label is not in
        ``id2label`` or maps to "no object" are skipped.
        """
        logits = outputs.logits
        bboxes = outputs.pred_boxes
        # Use softmax activation function
        prob = logits.softmax(-1)[0, :, :-1]  # Exclude the "no object" class
        scores, labels = prob.max(-1)

        # Convert to absolute coordinates
        img_w, img_h = img_size
        boxes = self.rescale_bboxes(bboxes[0, ...], (img_w, img_h))

        # Move results to CPU for list comprehension
        scores = scores.cpu().tolist()
        labels = labels.cpu().tolist()
        boxes = boxes.cpu().tolist()

        objects = []
        for score, label_idx, bbox in zip(scores, labels, boxes):
            class_label = id2label.get(label_idx, 'unknown')
            if class_label != 'no object' and class_label != 'unknown':
                objects.append({
                    'label': class_label,
                    'score': float(score),
                    'bbox': [round(float(c), 2) for c in bbox],  # Round coordinates
                })
        return objects
    # --- End Helper Methods ---

    @staticmethod
    def _id2label_with_no_object(model) -> Dict[int, str]:
        """Return a copy of ``model.config.id2label`` plus a trailing "no object" entry.

        A copy is used so the config of the model object returned by
        ``_get_model`` (which may be reused across calls) is not mutated.
        """
        id2label = dict(model.config.id2label)
        id2label[model.config.num_labels] = "no object"
        return id2label

    def _build_transform(self, max_size):
        """Build the resize -> tensor -> normalize pipeline used by both models."""
        return transforms.Compose([
            self.MaxResize(max_size),
            transforms.ToTensor(),
            transforms.Normalize(self._NORMALIZE_MEAN, self._NORMALIZE_STD),
        ])

    @staticmethod
    def _make_detection(bbox, class_name, score, normalized_class) -> Dict[str, Any]:
        """Build one detection dict in this detector's output format."""
        return {
            'bbox': tuple(bbox),
            'class': class_name,
            'confidence': float(score),
            'normalized_class': normalized_class,
            'source': 'layout',
            'model': 'tatr',
        }

    def _detect_structure(self, rgb_image, tables, structure_model, structure_transform,
                          device, confidence, classes_req, classes_excl):
        """Run the structure model on each table crop and return detections in page coordinates."""
        self.logger.debug("Running TATR structure recognition...")
        id2label_struct = self._id2label_with_no_object(structure_model)

        detections = []
        for table in tables:
            x_min, y_min, x_max, y_max = map(int, table['bbox'])
            # Ensure coordinates are within image bounds
            x_min, y_min = max(0, x_min), max(0, y_min)
            x_max, y_max = min(rgb_image.width, x_max), min(rgb_image.height, y_max)
            if x_max <= x_min or y_max <= y_min:
                continue  # Skip invalid/empty crop

            cropped_table = rgb_image.crop((x_min, y_min, x_max, y_max))
            pixel_values = structure_transform(cropped_table).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = structure_model(pixel_values)

            elements = self.outputs_to_objects(outputs, cropped_table.size, id2label_struct)
            for element in elements:
                if element['score'] < confidence:
                    continue
                label = element['label']
                normalized_class = self._normalize_class_name(label)

                # Apply class filtering
                if classes_req and normalized_class not in classes_req:
                    continue
                if normalized_class in classes_excl:
                    continue

                # Boxes are relative to the crop; shift them back onto the page.
                ex0, ey0, ex1, ey1 = element['bbox']
                adj_bbox = (ex0 + x_min, ey0 + y_min, ex1 + x_min, ey1 + y_min)
                detections.append(self._make_detection(adj_bbox, label, element['score'], normalized_class))
        return detections

    def detect(self, image: Image.Image, options: BaseLayoutOptions) -> List[Dict[str, Any]]:
        """Detect tables and their structure in an image.

        Args:
            image: Page image in any PIL mode; it is converted to RGB because
                both models are normalized with 3-channel statistics.
            options: Ideally ``TATRLayoutOptions``. A plain ``BaseLayoutOptions``
                is upgraded using TATR defaults for the model-specific fields.

        Returns:
            Table detections (when 'table' is requested and not excluded) and
            structure-element detections (when any structure class is needed),
            all filtered by ``options.confidence``.

        Raises:
            RuntimeError: If torch/torchvision/transformers are unavailable.
        """
        if not self.is_available():
            raise RuntimeError("TATR dependencies (torch, torchvision, transformers) not installed.")

        if not isinstance(options, TATRLayoutOptions):
            self.logger.warning("Received BaseLayoutOptions, expected TATRLayoutOptions. Using defaults.")
            options = TATRLayoutOptions(
                confidence=options.confidence, classes=options.classes,
                exclude_classes=options.exclude_classes, device=options.device,
                extra_args=options.extra_args,
            )

        self.validate_classes(options.classes or [])
        if options.exclude_classes:
            self.validate_classes(options.exclude_classes)

        models = self._get_model(options)
        detection_model = models['detection']
        structure_model = models['structure']
        device = self._resolve_device(options)

        # Convert once: the per-table crops below must be RGB too, otherwise an
        # RGBA/L/P page yields a tensor whose channel count does not match
        # the 3-channel Normalize.
        rgb_image = image.convert("RGB")

        detection_transform = self._build_transform(options.max_detection_size)
        structure_transform = self._build_transform(options.max_structure_size)

        # --- Detect Tables ---
        self.logger.debug("Running TATR table detection...")
        pixel_values = detection_transform(rgb_image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = detection_model(pixel_values)

        id2label_det = self._id2label_with_no_object(detection_model)
        tables = self.outputs_to_objects(outputs, rgb_image.size, id2label_det)
        tables = [t for t in tables if t['score'] >= options.confidence and t['label'] == 'table']
        self.logger.debug(f"Detected {len(tables)} table regions.")

        normalized_classes_req = (
            {self._normalize_class_name(c) for c in options.classes} if options.classes else None
        )
        normalized_classes_excl = (
            {self._normalize_class_name(c) for c in options.exclude_classes} if options.exclude_classes else set()
        )

        all_detections = []

        # Add table detections if requested
        table_requested = normalized_classes_req is None or 'table' in normalized_classes_req
        if table_requested and 'table' not in normalized_classes_excl:
            all_detections.extend(
                self._make_detection(table['bbox'], 'table', table['score'], 'table') for table in tables
            )

        # --- Process Structure ---
        normalized_structure_classes = {self._normalize_class_name(c) for c in self._STRUCTURE_CLASSES}
        if normalized_classes_req is None:
            # No explicit request: run unless every structure class is excluded.
            needed_structure = any(c not in normalized_classes_excl for c in normalized_structure_classes)
        else:
            needed_structure = any(c in normalized_classes_req for c in normalized_structure_classes)

        if needed_structure and tables:
            structure_detections = self._detect_structure(
                rgb_image, tables, structure_model, structure_transform, device,
                options.confidence, normalized_classes_req, normalized_classes_excl,
            )
            all_detections.extend(structure_detections)
            # Count directly; subtracting len(tables) is wrong when tables were filtered out.
            self.logger.debug(f"Added {len(structure_detections)} structure elements.")

        self.logger.info(f"TATR detected {len(all_detections)} layout elements matching criteria.")
        return all_detections
|
251
|
+
|
@@ -0,0 +1,165 @@
|
|
1
|
+
# layout_detector_yolo.py
|
2
|
+
import logging
|
3
|
+
import importlib.util
|
4
|
+
import os
|
5
|
+
import tempfile
|
6
|
+
from typing import List, Dict, Any, Optional
|
7
|
+
from PIL import Image
|
8
|
+
|
9
|
+
# Assuming base class and options are importable
|
10
|
+
try:
    from .base import LayoutDetector
    from .layout_options import YOLOLayoutOptions, BaseLayoutOptions
except ImportError:
    # Placeholders if run standalone or imports fail. They provide only the
    # attributes and methods the YOLO detector in this module touches.

    class BaseLayoutOptions:
        """Stand-in for the package's base options class."""
        pass

    class YOLOLayoutOptions(BaseLayoutOptions):
        """Stand-in for the package's YOLO options class."""
        pass

    class LayoutDetector:
        """Minimal stand-in for the package's LayoutDetector base class."""

        def __init__(self):
            self.logger = logging.getLogger()
            self.supported_classes = set()

        def _get_model(self, options):
            raise NotImplementedError

        def _normalize_class_name(self, n):
            # Identity: no normalization in standalone mode.
            return n

        def validate_classes(self, c):
            # Accept anything in standalone mode.
            pass

    logging.basicConfig()
|
23
|
+
|
24
|
+
logger = logging.getLogger(__name__)

# Look the optional packages up with find_spec before importing, so that a
# missing package leaves YOLOv10 / hf_hub_download as None (checked by
# YOLODocLayoutDetector.is_available()) instead of failing this module's import.
yolo_spec = importlib.util.find_spec("doclayout_yolo")
hf_spec = importlib.util.find_spec("huggingface_hub")
YOLOv10 = None
hf_hub_download = None

if not (yolo_spec and hf_spec):
    logger.warning("doclayout_yolo or huggingface_hub not found. YOLODocLayoutDetector will not be available.")
else:
    try:
        from doclayout_yolo import YOLOv10
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        logger.warning(f"Could not import YOLO dependencies: {e}")
|
40
|
+
|
41
|
+
|
42
|
+
class YOLODocLayoutDetector(LayoutDetector):
    """Document layout detector using a DocLayout-YOLO (``YOLOv10``) model.

    Weights are fetched with ``hf_hub_download`` from
    ``options.model_repo`` / ``options.model_file``. ``detect`` returns a list
    of dicts with ``bbox`` (x0, y0, x1, y1 floats taken from the model's xyxy
    boxes), ``class``, ``confidence``, ``normalized_class``, ``source`` and
    ``model`` keys.
    """

    def __init__(self):
        super().__init__()
        # Class labels this detector reports as supported.
        self.supported_classes = {
            'title', 'plain text', 'abandon', 'figure', 'figure_caption',
            'table', 'table_caption', 'table_footnote', 'isolate_formula',
            'formula_caption',
        }

    def is_available(self) -> bool:
        """Return True when doclayout_yolo and huggingface_hub were both imported."""
        return YOLOv10 is not None and hf_hub_download is not None

    def _get_cache_key(self, options: YOLOLayoutOptions) -> str:
        """Generate cache key based on model repo/file and device."""
        # Ensure options is the correct type
        if not isinstance(options, YOLOLayoutOptions):
            # This shouldn't happen if called correctly, but handle defensively
            options = YOLOLayoutOptions(device=options.device)  # Use base device

        device_key = str(options.device).lower()
        model_key = f"{options.model_repo.replace('/', '_')}_{options.model_file}"
        return f"{self.__class__.__name__}_{device_key}_{model_key}"

    def _load_model_from_options(self, options: YOLOLayoutOptions) -> Any:
        """Download (via hf_hub_download) and load the YOLOv10 model described by ``options``.

        Raises:
            RuntimeError: If the YOLO dependencies are not installed.
            Exception: Any download/load error is logged and re-raised.
        """
        if not self.is_available():
            raise RuntimeError("YOLO dependencies (doclayout_yolo, huggingface_hub) not installed.")
        self.logger.info(f"Loading YOLO model: {options.model_repo}/{options.model_file}")
        try:
            model_path = hf_hub_download(repo_id=options.model_repo, filename=options.model_file)
            model = YOLOv10(model_path)
            self.logger.info("YOLO model loaded.")
            return model
        except Exception as e:
            self.logger.error(f"Failed to download or load YOLO model: {e}", exc_info=True)
            raise

    def detect(self, image: Image.Image, options: BaseLayoutOptions) -> List[Dict[str, Any]]:
        """Detect layout elements in an image using YOLO.

        Args:
            image: Page image; converted to RGB and written to a temporary PNG
                that is passed to the model by path.
            options: Ideally ``YOLOLayoutOptions``. A plain ``BaseLayoutOptions``
                is upgraded using YOLO defaults for the model-specific fields.

        Returns:
            One dict per detected box that passes the include/exclude class filters.

        Raises:
            RuntimeError: If doclayout_yolo / huggingface_hub are unavailable.
            Exception: Errors while saving the image or running/parsing the
                prediction are logged and re-raised.
        """
        if not self.is_available():
            raise RuntimeError("YOLO dependencies (doclayout_yolo, huggingface_hub) not installed.")

        # Ensure options are the correct type, falling back to defaults if base type passed
        if not isinstance(options, YOLOLayoutOptions):
            self.logger.warning("Received BaseLayoutOptions, expected YOLOLayoutOptions. Using defaults.")
            options = YOLOLayoutOptions(
                confidence=options.confidence, classes=options.classes,
                exclude_classes=options.exclude_classes, device=options.device,
                extra_args=options.extra_args,
            )

        # Validate classes before proceeding
        self.validate_classes(options.classes or [])
        if options.exclude_classes:
            self.validate_classes(options.exclude_classes)

        # Get the cached/loaded model
        model = self._get_model(options)

        detections = []
        # The image is handed to the model as a file path; the temporary
        # directory guarantees the file is cleaned up afterwards.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_image_path = os.path.join(temp_dir, "temp_layout_image.png")
            try:
                self.logger.debug(f"Saving temporary image for YOLO detector to: {temp_image_path}")
                image.convert("RGB").save(temp_image_path)  # Ensure RGB

                # Run model prediction
                self.logger.debug(f"Running YOLO prediction (imgsz={options.image_size}, conf={options.confidence}, device={options.device})...")
                results = model.predict(
                    temp_image_path,
                    imgsz=options.image_size,
                    conf=options.confidence,
                    device=options.device or 'cpu',  # Default to cpu if None
                    # NOTE(review): options.extra_args is not forwarded to predict().
                )
                self.logger.debug(f"YOLO prediction returned {len(results)} result objects.")

                # Normalize the class filters once instead of rebuilding them for every box.
                include = (
                    {self._normalize_class_name(c) for c in options.classes} if options.classes else None
                )
                exclude = (
                    {self._normalize_class_name(c) for c in options.exclude_classes} if options.exclude_classes else set()
                )

                # Process results into standardized format
                for result in results:
                    if result.boxes is None:
                        continue
                    class_names = result.names  # Dictionary mapping index to name

                    for box, label_idx_tensor, score_tensor in zip(result.boxes.xyxy, result.boxes.cls, result.boxes.conf):
                        x_min, y_min, x_max, y_max = map(float, box.tolist())
                        label_idx = int(label_idx_tensor.item())
                        score = float(score_tensor.item())

                        if label_idx not in class_names:
                            self.logger.warning(f"Label index {label_idx} not found in model names dict. Skipping.")
                            continue
                        label_name = class_names[label_idx]
                        normalized_class = self._normalize_class_name(label_name)

                        # Apply class filtering (using normalized names)
                        if include is not None and normalized_class not in include:
                            continue
                        if normalized_class in exclude:
                            continue

                        detections.append({
                            'bbox': (x_min, y_min, x_max, y_max),
                            'class': label_name,
                            'confidence': score,
                            'normalized_class': normalized_class,
                            'source': 'layout',
                            'model': 'yolo',
                        })
                self.logger.info(f"YOLO detected {len(detections)} layout elements matching criteria.")

            except Exception as e:
                self.logger.error(f"Error during YOLO detection: {e}", exc_info=True)
                raise  # Re-raise the exception

        return detections
|
@@ -0,0 +1,60 @@
|
|
1
|
+
import logging
|
2
|
+
from dataclasses import dataclass, field
|
3
|
+
from typing import List, Optional
|
4
|
+
|
5
|
+
logger = logging.getLogger(__name__)


@dataclass
class TextStyleOptions:
    """Options for configuring text style analysis.

    ``__post_init__`` sanitizes the values:

    * a non-positive (or NaN) ``size_tolerance`` is replaced by 0.1;
    * ``group_by`` is copied into a new list, so the caller's sequence is never
      modified and tuples are accepted;
    * ``'size'`` is always added to ``group_by``;
    * ``'color'`` is removed from ``group_by`` when ``ignore_color`` is True;
    * keys outside the allowed set are dropped with a warning.
    """

    # Properties to consider when grouping elements by style
    group_by: List[str] = field(default_factory=lambda: ['size', 'fontname', 'is_bold', 'is_italic', 'color'])

    # Tolerance for comparing font sizes (e.g., 0.5 rounds to nearest 0.5 point)
    size_tolerance: float = 0.5

    # If True, ignores text color during grouping
    ignore_color: bool = False

    # If True, ignores small variations often found in font names (e.g., '+ArialMT')
    normalize_fontname: bool = True

    # If True, generates descriptive labels (e.g., "12pt-Bold-Arial")
    # If False, uses simple numeric labels ("Style 1")
    descriptive_labels: bool = True

    # Prefix for generated labels (used if descriptive_labels is False or as fallback)
    label_prefix: str = "Style"

    # Format string for descriptive labels. Placeholders match keys in style_properties dict.
    # Example: "{size}pt {weight}{style} {family} ({color})"
    # Available keys: size, fontname, is_bold, is_italic, color, weight, style, family
    label_format: str = "{size}pt {weight}{style} {family}"  # Default format without color

    def __post_init__(self):
        # Work on a private copy. The normalization below appends to and filters
        # this list, and appending in place would silently change a list the
        # caller still holds (and would fail outright for a tuple).
        self.group_by = list(self.group_by)

        # Validate size_tolerance. `not x > 0` also rejects NaN, which `x <= 0` lets through.
        if not self.size_tolerance > 0:
            logger.warning(f"size_tolerance must be positive, setting to 0.1. Original value: {self.size_tolerance}")
            self.size_tolerance = 0.1

        # size_tolerance is guaranteed positive here, so size always takes part in grouping.
        if 'size' not in self.group_by:
            logger.debug("Adding 'size' to group_by keys because size_tolerance is set.")
            self.group_by.append('size')

        if self.ignore_color and 'color' in self.group_by:
            logger.debug("Removing 'color' from group_by keys because ignore_color is True.")
            self.group_by = [key for key in self.group_by if key != 'color']

        # Basic validation for group_by keys
        allowed_keys = {'size', 'fontname', 'is_bold', 'is_italic', 'color'}
        invalid_keys = set(self.group_by) - allowed_keys
        if invalid_keys:
            logger.warning(f"Invalid keys found in group_by: {invalid_keys}. Allowed keys: {allowed_keys}. Ignoring invalid keys.")
            self.group_by = [key for key in self.group_by if key in allowed_keys]
|