natural-pdf 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141) hide show
  1. docs/api/index.md +386 -0
  2. docs/assets/favicon.png +3 -0
  3. docs/assets/favicon.svg +3 -0
  4. docs/assets/javascripts/custom.js +17 -0
  5. docs/assets/logo.svg +3 -0
  6. docs/assets/sample-screen.png +0 -0
  7. docs/assets/social-preview.png +17 -0
  8. docs/assets/social-preview.svg +17 -0
  9. docs/assets/stylesheets/custom.css +65 -0
  10. docs/document-qa/index.ipynb +435 -0
  11. docs/document-qa/index.md +79 -0
  12. docs/element-selection/index.ipynb +915 -0
  13. docs/element-selection/index.md +229 -0
  14. docs/index.md +170 -0
  15. docs/installation/index.md +69 -0
  16. docs/interactive-widget/index.ipynb +962 -0
  17. docs/interactive-widget/index.md +12 -0
  18. docs/layout-analysis/index.ipynb +818 -0
  19. docs/layout-analysis/index.md +185 -0
  20. docs/ocr/index.md +209 -0
  21. docs/pdf-navigation/index.ipynb +314 -0
  22. docs/pdf-navigation/index.md +97 -0
  23. docs/regions/index.ipynb +816 -0
  24. docs/regions/index.md +294 -0
  25. docs/tables/index.ipynb +658 -0
  26. docs/tables/index.md +144 -0
  27. docs/text-analysis/index.ipynb +370 -0
  28. docs/text-analysis/index.md +105 -0
  29. docs/text-extraction/index.ipynb +1478 -0
  30. docs/text-extraction/index.md +292 -0
  31. docs/tutorials/01-loading-and-extraction.ipynb +1710 -0
  32. docs/tutorials/01-loading-and-extraction.md +95 -0
  33. docs/tutorials/02-finding-elements.ipynb +340 -0
  34. docs/tutorials/02-finding-elements.md +149 -0
  35. docs/tutorials/03-extracting-blocks.ipynb +147 -0
  36. docs/tutorials/03-extracting-blocks.md +48 -0
  37. docs/tutorials/04-table-extraction.ipynb +114 -0
  38. docs/tutorials/04-table-extraction.md +50 -0
  39. docs/tutorials/05-excluding-content.ipynb +270 -0
  40. docs/tutorials/05-excluding-content.md +109 -0
  41. docs/tutorials/06-document-qa.ipynb +332 -0
  42. docs/tutorials/06-document-qa.md +91 -0
  43. docs/tutorials/07-layout-analysis.ipynb +288 -0
  44. docs/tutorials/07-layout-analysis.md +66 -0
  45. docs/tutorials/07-working-with-regions.ipynb +413 -0
  46. docs/tutorials/07-working-with-regions.md +151 -0
  47. docs/tutorials/08-spatial-navigation.ipynb +508 -0
  48. docs/tutorials/08-spatial-navigation.md +190 -0
  49. docs/tutorials/09-section-extraction.ipynb +2434 -0
  50. docs/tutorials/09-section-extraction.md +256 -0
  51. docs/tutorials/10-form-field-extraction.ipynb +512 -0
  52. docs/tutorials/10-form-field-extraction.md +201 -0
  53. docs/tutorials/11-enhanced-table-processing.ipynb +54 -0
  54. docs/tutorials/11-enhanced-table-processing.md +9 -0
  55. docs/tutorials/12-ocr-integration.ipynb +604 -0
  56. docs/tutorials/12-ocr-integration.md +175 -0
  57. docs/tutorials/13-semantic-search.ipynb +1328 -0
  58. docs/tutorials/13-semantic-search.md +77 -0
  59. docs/visual-debugging/index.ipynb +2970 -0
  60. docs/visual-debugging/index.md +157 -0
  61. docs/visual-debugging/region.png +0 -0
  62. natural_pdf/__init__.py +50 -33
  63. natural_pdf/analyzers/__init__.py +2 -1
  64. natural_pdf/analyzers/layout/base.py +32 -24
  65. natural_pdf/analyzers/layout/docling.py +131 -72
  66. natural_pdf/analyzers/layout/gemini.py +264 -0
  67. natural_pdf/analyzers/layout/layout_analyzer.py +156 -113
  68. natural_pdf/analyzers/layout/layout_manager.py +125 -58
  69. natural_pdf/analyzers/layout/layout_options.py +43 -17
  70. natural_pdf/analyzers/layout/paddle.py +152 -95
  71. natural_pdf/analyzers/layout/surya.py +164 -92
  72. natural_pdf/analyzers/layout/tatr.py +149 -84
  73. natural_pdf/analyzers/layout/yolo.py +89 -45
  74. natural_pdf/analyzers/text_options.py +22 -15
  75. natural_pdf/analyzers/text_structure.py +131 -85
  76. natural_pdf/analyzers/utils.py +30 -23
  77. natural_pdf/collections/pdf_collection.py +146 -97
  78. natural_pdf/core/__init__.py +1 -1
  79. natural_pdf/core/element_manager.py +419 -337
  80. natural_pdf/core/highlighting_service.py +268 -196
  81. natural_pdf/core/page.py +1044 -521
  82. natural_pdf/core/pdf.py +516 -313
  83. natural_pdf/elements/__init__.py +1 -1
  84. natural_pdf/elements/base.py +307 -225
  85. natural_pdf/elements/collections.py +805 -543
  86. natural_pdf/elements/line.py +39 -36
  87. natural_pdf/elements/rect.py +32 -30
  88. natural_pdf/elements/region.py +889 -879
  89. natural_pdf/elements/text.py +127 -99
  90. natural_pdf/exporters/__init__.py +0 -1
  91. natural_pdf/exporters/searchable_pdf.py +261 -102
  92. natural_pdf/ocr/__init__.py +57 -35
  93. natural_pdf/ocr/engine.py +150 -46
  94. natural_pdf/ocr/engine_easyocr.py +146 -150
  95. natural_pdf/ocr/engine_paddle.py +118 -175
  96. natural_pdf/ocr/engine_surya.py +78 -141
  97. natural_pdf/ocr/ocr_factory.py +114 -0
  98. natural_pdf/ocr/ocr_manager.py +122 -124
  99. natural_pdf/ocr/ocr_options.py +16 -20
  100. natural_pdf/ocr/utils.py +98 -0
  101. natural_pdf/qa/__init__.py +1 -1
  102. natural_pdf/qa/document_qa.py +119 -111
  103. natural_pdf/search/__init__.py +37 -31
  104. natural_pdf/search/haystack_search_service.py +312 -189
  105. natural_pdf/search/haystack_utils.py +186 -122
  106. natural_pdf/search/search_options.py +25 -14
  107. natural_pdf/search/search_service_protocol.py +12 -6
  108. natural_pdf/search/searchable_mixin.py +261 -176
  109. natural_pdf/selectors/__init__.py +2 -1
  110. natural_pdf/selectors/parser.py +159 -316
  111. natural_pdf/templates/__init__.py +1 -1
  112. natural_pdf/templates/spa/css/style.css +334 -0
  113. natural_pdf/templates/spa/index.html +31 -0
  114. natural_pdf/templates/spa/js/app.js +472 -0
  115. natural_pdf/templates/spa/words.txt +235976 -0
  116. natural_pdf/utils/debug.py +32 -0
  117. natural_pdf/utils/highlighting.py +8 -2
  118. natural_pdf/utils/identifiers.py +29 -0
  119. natural_pdf/utils/packaging.py +418 -0
  120. natural_pdf/utils/reading_order.py +65 -63
  121. natural_pdf/utils/text_extraction.py +195 -0
  122. natural_pdf/utils/visualization.py +70 -61
  123. natural_pdf/widgets/__init__.py +2 -3
  124. natural_pdf/widgets/viewer.py +749 -718
  125. {natural_pdf-0.1.4.dist-info → natural_pdf-0.1.6.dist-info}/METADATA +53 -17
  126. natural_pdf-0.1.6.dist-info/RECORD +141 -0
  127. {natural_pdf-0.1.4.dist-info → natural_pdf-0.1.6.dist-info}/WHEEL +1 -1
  128. natural_pdf-0.1.6.dist-info/top_level.txt +4 -0
  129. notebooks/Examples.ipynb +1293 -0
  130. pdfs/.gitkeep +0 -0
  131. pdfs/01-practice.pdf +543 -0
  132. pdfs/0500000US42001.pdf +0 -0
  133. pdfs/0500000US42007.pdf +0 -0
  134. pdfs/2014 Statistics.pdf +0 -0
  135. pdfs/2019 Statistics.pdf +0 -0
  136. pdfs/Atlanta_Public_Schools_GA_sample.pdf +0 -0
  137. pdfs/needs-ocr.pdf +0 -0
  138. natural_pdf/templates/ocr_debug.html +0 -517
  139. natural_pdf-0.1.4.dist-info/RECORD +0 -61
  140. natural_pdf-0.1.4.dist-info/top_level.txt +0 -1
  141. {natural_pdf-0.1.4.dist-info → natural_pdf-0.1.6.dist-info}/licenses/LICENSE +0 -0
@@ -1,14 +1,15 @@
1
1
  # layout_detector_tatr.py
2
- import logging
3
2
  import importlib.util
3
+ import logging
4
4
  import os
5
5
  import tempfile
6
- from typing import List, Dict, Any, Optional, Tuple
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
7
8
  from PIL import Image
8
9
 
9
10
  # Assuming base class and options are importable
10
11
  from .base import LayoutDetector
11
- from .layout_options import TATRLayoutOptions, BaseLayoutOptions
12
+ from .layout_options import BaseLayoutOptions, TATRLayoutOptions
12
13
 
13
14
  logger = logging.getLogger(__name__)
14
15
 
@@ -26,9 +27,13 @@ if torch_spec and torchvision_spec and transformers_spec:
26
27
  from torchvision import transforms
27
28
  from transformers import AutoModelForObjectDetection
28
29
  except ImportError as e:
29
- logger.warning(f"Could not import TATR dependencies (torch, torchvision, transformers): {e}")
30
+ logger.warning(
31
+ f"Could not import TATR dependencies (torch, torchvision, transformers): {e}"
32
+ )
30
33
  else:
31
- logger.warning("torch, torchvision, or transformers not found. TableTransformerDetector will not be available.")
34
+ logger.warning(
35
+ "torch, torchvision, or transformers not found. TableTransformerDetector will not be available."
36
+ )
32
37
 
33
38
 
34
39
  class TableTransformerDetector(LayoutDetector):
@@ -36,26 +41,36 @@ class TableTransformerDetector(LayoutDetector):
36
41
 
37
42
  # Custom resize transform (keep as nested class or move outside)
38
43
  class MaxResize(object):
39
- def __init__(self, max_size=800):
44
+ def __init__(self, max_size=2000):
40
45
  self.max_size = max_size
46
+
41
47
  def __call__(self, image):
42
48
  width, height = image.size
43
49
  current_max_size = max(width, height)
44
50
  scale = self.max_size / current_max_size
45
51
  # Use LANCZOS for resizing
46
- resized_image = image.resize((int(round(scale*width)), int(round(scale*height))), Image.Resampling.LANCZOS)
52
+ resized_image = image.resize(
53
+ (int(round(scale * width)), int(round(scale * height))), Image.Resampling.LANCZOS
54
+ )
47
55
  return resized_image
48
56
 
49
57
  def __init__(self):
50
58
  super().__init__()
51
59
  self.supported_classes = {
52
- 'table', 'table row', 'table column', 'table column header', 'table projected row header', 'table spanning cell' # Add others if supported by models used
60
+ "table",
61
+ "table row",
62
+ "table column",
63
+ "table column header",
64
+ "table projected row header",
65
+ "table spanning cell", # Add others if supported by models used
53
66
  }
54
67
  # Models are loaded via _get_model
55
68
 
56
69
  def is_available(self) -> bool:
57
70
  """Check if dependencies are installed."""
58
- return torch is not None and transforms is not None and AutoModelForObjectDetection is not None
71
+ return (
72
+ torch is not None and transforms is not None and AutoModelForObjectDetection is not None
73
+ )
59
74
 
60
75
  def _get_cache_key(self, options: TATRLayoutOptions) -> str:
61
76
  """Generate cache key based on model IDs and device."""
@@ -63,26 +78,30 @@ class TableTransformerDetector(LayoutDetector):
63
78
  options = TATRLayoutOptions(device=options.device)
64
79
 
65
80
  device_key = str(options.device).lower()
66
- det_model_key = options.detection_model.replace('/','_')
67
- struct_model_key = options.structure_model.replace('/','_')
81
+ det_model_key = options.detection_model.replace("/", "_")
82
+ struct_model_key = options.structure_model.replace("/", "_")
68
83
  return f"{self.__class__.__name__}_{device_key}_{det_model_key}_{struct_model_key}"
69
84
 
70
85
  def _load_model_from_options(self, options: TATRLayoutOptions) -> Dict[str, Any]:
71
86
  """Load the TATR detection and structure models."""
72
87
  if not self.is_available():
73
- raise RuntimeError("TATR dependencies (torch, torchvision, transformers) not installed.")
88
+ raise RuntimeError(
89
+ "TATR dependencies (torch, torchvision, transformers) not installed."
90
+ )
74
91
 
75
92
  device = options.device or ("cuda" if torch.cuda.is_available() else "cpu")
76
- self.logger.info(f"Loading TATR models: Detection='{options.detection_model}', Structure='{options.structure_model}' onto device='{device}'")
93
+ self.logger.info(
94
+ f"Loading TATR models: Detection='{options.detection_model}', Structure='{options.structure_model}' onto device='{device}'"
95
+ )
77
96
  try:
78
97
  detection_model = AutoModelForObjectDetection.from_pretrained(
79
- options.detection_model, revision="no_timm" # Important revision for some versions
98
+ options.detection_model, revision="no_timm" # Important revision for some versions
80
99
  ).to(device)
81
100
  structure_model = AutoModelForObjectDetection.from_pretrained(
82
101
  options.structure_model
83
102
  ).to(device)
84
103
  self.logger.info("TATR models loaded.")
85
- return {'detection': detection_model, 'structure': structure_model}
104
+ return {"detection": detection_model, "structure": structure_model}
86
105
  except Exception as e:
87
106
  self.logger.error(f"Failed to load TATR models: {e}", exc_info=True)
88
107
  raise
@@ -97,19 +116,21 @@ class TableTransformerDetector(LayoutDetector):
97
116
  def rescale_bboxes(self, out_bbox, size):
98
117
  img_w, img_h = size
99
118
  boxes = self.box_cxcywh_to_xyxy(out_bbox)
100
- boxes = boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(out_bbox.device) # Ensure tensor on correct device
119
+ boxes = boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(
120
+ out_bbox.device
121
+ ) # Ensure tensor on correct device
101
122
  return boxes
102
123
 
103
124
  def outputs_to_objects(self, outputs, img_size, id2label):
104
125
  logits = outputs.logits
105
126
  bboxes = outputs.pred_boxes
106
127
  # Use softmax activation function
107
- prob = logits.softmax(-1)[0, :, :-1] # Exclude the "no object" class
128
+ prob = logits.softmax(-1)[0, :, :-1] # Exclude the "no object" class
108
129
  scores, labels = prob.max(-1)
109
130
 
110
131
  # Convert to absolute coordinates
111
132
  img_w, img_h = img_size
112
- boxes = self.rescale_bboxes(bboxes[0, ...], (img_w, img_h)) # Pass tuple size
133
+ boxes = self.rescale_bboxes(bboxes[0, ...], (img_w, img_h)) # Pass tuple size
113
134
 
114
135
  # Move results to CPU for list comprehension
115
136
  scores = scores.cpu().tolist()
@@ -118,49 +139,62 @@ class TableTransformerDetector(LayoutDetector):
118
139
 
119
140
  objects = []
120
141
  for score, label_idx, bbox in zip(scores, labels, boxes):
121
- class_label = id2label.get(label_idx, 'unknown') # Use get with default
122
- if class_label != 'no object' and class_label != 'unknown':
123
- objects.append({
124
- 'label': class_label,
125
- 'score': float(score),
126
- 'bbox': [round(float(c), 2) for c in bbox] # Round coordinates
127
- })
142
+ class_label = id2label.get(label_idx, "unknown") # Use get with default
143
+ if class_label != "no object" and class_label != "unknown":
144
+ objects.append(
145
+ {
146
+ "label": class_label,
147
+ "score": float(score),
148
+ "bbox": [round(float(c), 2) for c in bbox], # Round coordinates
149
+ }
150
+ )
128
151
  return objects
152
+
129
153
  # --- End Helper Methods ---
130
154
 
131
155
  def detect(self, image: Image.Image, options: BaseLayoutOptions) -> List[Dict[str, Any]]:
132
156
  """Detect tables and their structure in an image."""
133
157
  if not self.is_available():
134
- raise RuntimeError("TATR dependencies (torch, torchvision, transformers) not installed.")
158
+ raise RuntimeError(
159
+ "TATR dependencies (torch, torchvision, transformers) not installed."
160
+ )
135
161
 
136
162
  if not isinstance(options, TATRLayoutOptions):
137
- self.logger.warning("Received BaseLayoutOptions, expected TATRLayoutOptions. Using defaults.")
138
- options = TATRLayoutOptions(
139
- confidence=options.confidence, classes=options.classes,
140
- exclude_classes=options.exclude_classes, device=options.device,
141
- extra_args=options.extra_args
142
- )
163
+ self.logger.warning(
164
+ "Received BaseLayoutOptions, expected TATRLayoutOptions. Using defaults."
165
+ )
166
+ options = TATRLayoutOptions(
167
+ confidence=options.confidence,
168
+ classes=options.classes,
169
+ exclude_classes=options.exclude_classes,
170
+ device=options.device,
171
+ extra_args=options.extra_args,
172
+ )
143
173
 
144
174
  self.validate_classes(options.classes or [])
145
175
  if options.exclude_classes:
146
176
  self.validate_classes(options.exclude_classes)
147
177
 
148
178
  models = self._get_model(options)
149
- detection_model = models['detection']
150
- structure_model = models['structure']
179
+ detection_model = models["detection"]
180
+ structure_model = models["structure"]
151
181
  device = options.device or ("cuda" if torch.cuda.is_available() else "cpu")
152
182
 
153
183
  # Prepare transforms based on options
154
- detection_transform = transforms.Compose([
155
- self.MaxResize(options.max_detection_size),
156
- transforms.ToTensor(),
157
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
158
- ])
159
- structure_transform = transforms.Compose([
160
- self.MaxResize(options.max_structure_size),
161
- transforms.ToTensor(),
162
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
163
- ])
184
+ detection_transform = transforms.Compose(
185
+ [
186
+ self.MaxResize(options.max_detection_size),
187
+ transforms.ToTensor(),
188
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
189
+ ]
190
+ )
191
+ structure_transform = transforms.Compose(
192
+ [
193
+ self.MaxResize(options.max_structure_size),
194
+ transforms.ToTensor(),
195
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
196
+ ]
197
+ )
164
198
 
165
199
  # --- Detect Tables ---
166
200
  self.logger.debug("Running TATR table detection...")
@@ -169,38 +203,60 @@ class TableTransformerDetector(LayoutDetector):
169
203
  outputs = detection_model(pixel_values)
170
204
 
171
205
  id2label_det = detection_model.config.id2label
172
- id2label_det[detection_model.config.num_labels] = "no object" # Add no object class
206
+ id2label_det[detection_model.config.num_labels] = "no object" # Add no object class
173
207
  tables = self.outputs_to_objects(outputs, image.size, id2label_det)
174
- tables = [t for t in tables if t['score'] >= options.confidence and t['label'] == 'table'] # Filter for tables
208
+ tables = [
209
+ t for t in tables if t["score"] >= options.confidence and t["label"] == "table"
210
+ ] # Filter for tables
175
211
  self.logger.debug(f"Detected {len(tables)} table regions.")
176
212
 
177
213
  all_detections = []
178
214
 
179
215
  # Add table detections if requested
180
- normalized_classes_req = {self._normalize_class_name(c) for c in options.classes} if options.classes else None
181
- normalized_classes_excl = {self._normalize_class_name(c) for c in options.exclude_classes} if options.exclude_classes else set()
182
-
183
- if normalized_classes_req is None or 'table' in normalized_classes_req:
184
- if 'table' not in normalized_classes_excl:
185
- for table in tables:
186
- all_detections.append({
187
- 'bbox': tuple(table['bbox']),
188
- 'class': 'table',
189
- 'confidence': float(table['score']),
190
- 'normalized_class': 'table',
191
- 'source': 'layout',
192
- 'model': 'tatr'
193
- })
216
+ normalized_classes_req = (
217
+ {self._normalize_class_name(c) for c in options.classes} if options.classes else None
218
+ )
219
+ normalized_classes_excl = (
220
+ {self._normalize_class_name(c) for c in options.exclude_classes}
221
+ if options.exclude_classes
222
+ else set()
223
+ )
224
+
225
+ if normalized_classes_req is None or "table" in normalized_classes_req:
226
+ if "table" not in normalized_classes_excl:
227
+ for table in tables:
228
+ all_detections.append(
229
+ {
230
+ "bbox": tuple(table["bbox"]),
231
+ "class": "table",
232
+ "confidence": float(table["score"]),
233
+ "normalized_class": "table",
234
+ "source": "layout",
235
+ "model": "tatr",
236
+ }
237
+ )
194
238
 
195
239
  # --- Process Structure ---
196
- structure_class_names = {'table row', 'table column', 'table column header', 'table projected row header', 'table spanning cell'}
197
- normalized_structure_classes = {self._normalize_class_name(c) for c in structure_class_names}
240
+ structure_class_names = {
241
+ "table row",
242
+ "table column",
243
+ "table column header",
244
+ "table projected row header",
245
+ "table spanning cell",
246
+ }
247
+ normalized_structure_classes = {
248
+ self._normalize_class_name(c) for c in structure_class_names
249
+ }
198
250
 
199
251
  needed_structure = False
200
- if normalized_classes_req is None: # If no specific classes requested
201
- needed_structure = any(norm_cls not in normalized_classes_excl for norm_cls in normalized_structure_classes)
202
- else: # Specific classes requested
203
- needed_structure = any(norm_cls in normalized_classes_req for norm_cls in normalized_structure_classes)
252
+ if normalized_classes_req is None: # If no specific classes requested
253
+ needed_structure = any(
254
+ norm_cls not in normalized_classes_excl for norm_cls in normalized_structure_classes
255
+ )
256
+ else: # Specific classes requested
257
+ needed_structure = any(
258
+ norm_cls in normalized_classes_req for norm_cls in normalized_structure_classes
259
+ )
204
260
 
205
261
  if needed_structure and tables:
206
262
  self.logger.debug("Running TATR structure recognition...")
@@ -208,44 +264,53 @@ class TableTransformerDetector(LayoutDetector):
208
264
  id2label_struct[structure_model.config.num_labels] = "no object"
209
265
 
210
266
  for table in tables:
211
- x_min, y_min, x_max, y_max = map(int, table['bbox'])
267
+ x_min, y_min, x_max, y_max = map(int, table["bbox"])
212
268
  # Ensure coordinates are within image bounds
213
269
  x_min, y_min = max(0, x_min), max(0, y_min)
214
270
  x_max, y_max = min(image.width, x_max), min(image.height, y_max)
215
- if x_max <= x_min or y_max <= y_min: continue # Skip invalid crop
271
+ if x_max <= x_min or y_max <= y_min:
272
+ continue # Skip invalid crop
216
273
 
217
274
  cropped_table = image.crop((x_min, y_min, x_max, y_max))
218
- if cropped_table.width == 0 or cropped_table.height == 0: continue # Skip empty crop
275
+ if cropped_table.width == 0 or cropped_table.height == 0:
276
+ continue # Skip empty crop
219
277
 
220
278
  pixel_values_struct = structure_transform(cropped_table).unsqueeze(0).to(device)
221
279
  with torch.no_grad():
222
280
  outputs_struct = structure_model(pixel_values_struct)
223
281
 
224
- structure_elements = self.outputs_to_objects(outputs_struct, cropped_table.size, id2label_struct)
225
- structure_elements = [e for e in structure_elements if e['score'] >= options.confidence]
282
+ structure_elements = self.outputs_to_objects(
283
+ outputs_struct, cropped_table.size, id2label_struct
284
+ )
285
+ structure_elements = [
286
+ e for e in structure_elements if e["score"] >= options.confidence
287
+ ]
226
288
 
227
289
  for element in structure_elements:
228
- element_class_orig = element['label']
290
+ element_class_orig = element["label"]
229
291
  normalized_class = self._normalize_class_name(element_class_orig)
230
292
 
231
293
  # Apply class filtering
232
- if normalized_classes_req and normalized_class not in normalized_classes_req: continue
233
- if normalized_class in normalized_classes_excl: continue
294
+ if normalized_classes_req and normalized_class not in normalized_classes_req:
295
+ continue
296
+ if normalized_class in normalized_classes_excl:
297
+ continue
234
298
 
235
299
  # Adjust coordinates
236
- ex0, ey0, ex1, ey1 = element['bbox']
300
+ ex0, ey0, ex1, ey1 = element["bbox"]
237
301
  adj_bbox = (ex0 + x_min, ey0 + y_min, ex1 + x_min, ey1 + y_min)
238
302
 
239
- all_detections.append({
240
- 'bbox': adj_bbox,
241
- 'class': element_class_orig,
242
- 'confidence': float(element['score']),
243
- 'normalized_class': normalized_class,
244
- 'source': 'layout',
245
- 'model': 'tatr'
246
- })
303
+ all_detections.append(
304
+ {
305
+ "bbox": adj_bbox,
306
+ "class": element_class_orig,
307
+ "confidence": float(element["score"]),
308
+ "normalized_class": normalized_class,
309
+ "source": "layout",
310
+ "model": "tatr",
311
+ }
312
+ )
247
313
  self.logger.debug(f"Added {len(all_detections) - len(tables)} structure elements.")
248
314
 
249
315
  self.logger.info(f"TATR detected {len(all_detections)} layout elements matching criteria.")
250
316
  return all_detections
251
-
@@ -1,24 +1,38 @@
1
1
  # layout_detector_yolo.py
2
- import logging
3
2
  import importlib.util
3
+ import logging
4
4
  import os
5
5
  import tempfile
6
- from typing import List, Dict, Any, Optional
6
+ from typing import Any, Dict, List, Optional
7
+
7
8
  from PIL import Image
8
9
 
9
10
  # Assuming base class and options are importable
10
11
  try:
11
12
  from .base import LayoutDetector
12
- from .layout_options import YOLOLayoutOptions, BaseLayoutOptions
13
+ from .layout_options import BaseLayoutOptions, YOLOLayoutOptions
13
14
  except ImportError:
14
15
  # Placeholders if run standalone or imports fail
15
- class BaseLayoutOptions: pass
16
- class YOLOLayoutOptions(BaseLayoutOptions): pass
16
+ class BaseLayoutOptions:
17
+ pass
18
+
19
+ class YOLOLayoutOptions(BaseLayoutOptions):
20
+ pass
21
+
17
22
  class LayoutDetector:
18
- def __init__(self): self.logger=logging.getLogger(); self.supported_classes=set()
19
- def _get_model(self, options): raise NotImplementedError
20
- def _normalize_class_name(self, n): return n
21
- def validate_classes(self, c): pass
23
+ def __init__(self):
24
+ self.logger = logging.getLogger()
25
+ self.supported_classes = set()
26
+
27
+ def _get_model(self, options):
28
+ raise NotImplementedError
29
+
30
+ def _normalize_class_name(self, n):
31
+ return n
32
+
33
+ def validate_classes(self, c):
34
+ pass
35
+
22
36
  logging.basicConfig()
23
37
 
24
38
  logger = logging.getLogger(__name__)
@@ -36,7 +50,9 @@ if yolo_spec and hf_spec:
36
50
  except ImportError as e:
37
51
  logger.warning(f"Could not import YOLO dependencies: {e}")
38
52
  else:
39
- logger.warning("doclayout_yolo or huggingface_hub not found. YOLODocLayoutDetector will not be available.")
53
+ logger.warning(
54
+ "doclayout_yolo or huggingface_hub not found. YOLODocLayoutDetector will not be available."
55
+ )
40
56
 
41
57
 
42
58
  class YOLODocLayoutDetector(LayoutDetector):
@@ -45,9 +61,16 @@ class YOLODocLayoutDetector(LayoutDetector):
45
61
  def __init__(self):
46
62
  super().__init__()
47
63
  self.supported_classes = {
48
- 'title', 'plain text', 'abandon', 'figure', 'figure_caption',
49
- 'table', 'table_caption', 'table_footnote', 'isolate_formula',
50
- 'formula_caption'
64
+ "title",
65
+ "plain text",
66
+ "abandon",
67
+ "figure",
68
+ "figure_caption",
69
+ "table",
70
+ "table_caption",
71
+ "table_footnote",
72
+ "isolate_formula",
73
+ "formula_caption",
51
74
  }
52
75
 
53
76
  def is_available(self) -> bool:
@@ -58,8 +81,8 @@ class YOLODocLayoutDetector(LayoutDetector):
58
81
  """Generate cache key based on model repo/file and device."""
59
82
  # Ensure options is the correct type
60
83
  if not isinstance(options, YOLOLayoutOptions):
61
- # This shouldn't happen if called correctly, but handle defensively
62
- options = YOLOLayoutOptions(device=options.device) # Use base device
84
+ # This shouldn't happen if called correctly, but handle defensively
85
+ options = YOLOLayoutOptions(device=options.device) # Use base device
63
86
 
64
87
  device_key = str(options.device).lower()
65
88
  model_key = f"{options.model_repo.replace('/','_')}_{options.model_file}"
@@ -68,7 +91,9 @@ class YOLODocLayoutDetector(LayoutDetector):
68
91
  def _load_model_from_options(self, options: YOLOLayoutOptions) -> Any:
69
92
  """Load the YOLOv10 model based on options."""
70
93
  if not self.is_available():
71
- raise RuntimeError("YOLO dependencies (doclayout_yolo, huggingface_hub) not installed.")
94
+ raise RuntimeError(
95
+ "YOLO dependencies not installed. Please run: pip install 'natural-pdf[layout_yolo]'"
96
+ )
72
97
  self.logger.info(f"Loading YOLO model: {options.model_repo}/{options.model_file}")
73
98
  try:
74
99
  model_path = hf_hub_download(repo_id=options.model_repo, filename=options.model_file)
@@ -82,16 +107,22 @@ class YOLODocLayoutDetector(LayoutDetector):
82
107
  def detect(self, image: Image.Image, options: BaseLayoutOptions) -> List[Dict[str, Any]]:
83
108
  """Detect layout elements in an image using YOLO."""
84
109
  if not self.is_available():
85
- raise RuntimeError("YOLO dependencies (doclayout_yolo, huggingface_hub) not installed.")
110
+ raise RuntimeError(
111
+ "YOLO dependencies not installed. Please run: pip install 'natural-pdf[layout_yolo]'"
112
+ )
86
113
 
87
114
  # Ensure options are the correct type, falling back to defaults if base type passed
88
115
  if not isinstance(options, YOLOLayoutOptions):
89
- self.logger.warning("Received BaseLayoutOptions, expected YOLOLayoutOptions. Using defaults.")
90
- options = YOLOLayoutOptions(
91
- confidence=options.confidence, classes=options.classes,
92
- exclude_classes=options.exclude_classes, device=options.device,
93
- extra_args=options.extra_args
94
- )
116
+ self.logger.warning(
117
+ "Received BaseLayoutOptions, expected YOLOLayoutOptions. Using defaults."
118
+ )
119
+ options = YOLOLayoutOptions(
120
+ confidence=options.confidence,
121
+ classes=options.classes,
122
+ exclude_classes=options.exclude_classes,
123
+ device=options.device,
124
+ extra_args=options.extra_args,
125
+ )
95
126
 
96
127
  # Validate classes before proceeding
97
128
  self.validate_classes(options.classes or [])
@@ -108,58 +139,71 @@ class YOLODocLayoutDetector(LayoutDetector):
108
139
  temp_image_path = os.path.join(temp_dir, "temp_layout_image.png")
109
140
  try:
110
141
  self.logger.debug(f"Saving temporary image for YOLO detector to: {temp_image_path}")
111
- image.convert("RGB").save(temp_image_path) # Ensure RGB
142
+ image.convert("RGB").save(temp_image_path) # Ensure RGB
112
143
 
113
144
  # Run model prediction
114
- self.logger.debug(f"Running YOLO prediction (imgsz={options.image_size}, conf={options.confidence}, device={options.device})...")
145
+ self.logger.debug(
146
+ f"Running YOLO prediction (imgsz={options.image_size}, conf={options.confidence}, device={options.device})..."
147
+ )
115
148
  results = model.predict(
116
149
  temp_image_path,
117
150
  imgsz=options.image_size,
118
151
  conf=options.confidence,
119
- device=options.device or 'cpu' # Default to cpu if None
152
+ device=options.device or "cpu", # Default to cpu if None
120
153
  # Add other predict args from options.extra_args if needed
121
154
  # **options.extra_args
122
155
  )
123
156
  self.logger.debug(f"YOLO prediction returned {len(results)} result objects.")
124
157
 
125
158
  # Process results into standardized format
126
- img_width, img_height = image.size # Get original image size for context if needed
159
+ img_width, img_height = image.size # Get original image size for context if needed
127
160
  for result in results:
128
- if result.boxes is None: continue
161
+ if result.boxes is None:
162
+ continue
129
163
  boxes = result.boxes.xyxy
130
164
  labels = result.boxes.cls
131
165
  scores = result.boxes.conf
132
- class_names = result.names # Dictionary mapping index to name
166
+ class_names = result.names # Dictionary mapping index to name
133
167
 
134
168
  for box, label_idx_tensor, score_tensor in zip(boxes, labels, scores):
135
169
  x_min, y_min, x_max, y_max = map(float, box.tolist())
136
- label_idx = int(label_idx_tensor.item()) # Get int index
137
- score = float(score_tensor.item()) # Get float score
170
+ label_idx = int(label_idx_tensor.item()) # Get int index
171
+ score = float(score_tensor.item()) # Get float score
138
172
 
139
173
  if label_idx not in class_names:
140
- self.logger.warning(f"Label index {label_idx} not found in model names dict. Skipping.")
141
- continue
174
+ self.logger.warning(
175
+ f"Label index {label_idx} not found in model names dict. Skipping."
176
+ )
177
+ continue
142
178
  label_name = class_names[label_idx]
143
179
  normalized_class = self._normalize_class_name(label_name)
144
180
 
145
181
  # Apply class filtering (using normalized names)
146
- if options.classes and normalized_class not in [self._normalize_class_name(c) for c in options.classes]:
182
+ if options.classes and normalized_class not in [
183
+ self._normalize_class_name(c) for c in options.classes
184
+ ]:
147
185
  continue
148
- if options.exclude_classes and normalized_class in [self._normalize_class_name(c) for c in options.exclude_classes]:
186
+ if options.exclude_classes and normalized_class in [
187
+ self._normalize_class_name(c) for c in options.exclude_classes
188
+ ]:
149
189
  continue
150
190
 
151
- detections.append({
152
- 'bbox': (x_min, y_min, x_max, y_max),
153
- 'class': label_name,
154
- 'confidence': score,
155
- 'normalized_class': normalized_class,
156
- 'source': 'layout',
157
- 'model': 'yolo'
158
- })
159
- self.logger.info(f"YOLO detected {len(detections)} layout elements matching criteria.")
191
+ detections.append(
192
+ {
193
+ "bbox": (x_min, y_min, x_max, y_max),
194
+ "class": label_name,
195
+ "confidence": score,
196
+ "normalized_class": normalized_class,
197
+ "source": "layout",
198
+ "model": "yolo",
199
+ }
200
+ )
201
+ self.logger.info(
202
+ f"YOLO detected {len(detections)} layout elements matching criteria."
203
+ )
160
204
 
161
205
  except Exception as e:
162
206
  self.logger.error(f"Error during YOLO detection: {e}", exc_info=True)
163
- raise # Re-raise the exception
207
+ raise # Re-raise the exception
164
208
 
165
209
  return detections