doc-page-extractor 0.0.9__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of doc-page-extractor might be problematic. Click here for more details.
- doc_page_extractor/__init__.py +12 -2
- doc_page_extractor/extractor.py +61 -155
- doc_page_extractor/latex.py +57 -0
- doc_page_extractor/layout_order.py +240 -0
- doc_page_extractor/ocr.py +1 -3
- doc_page_extractor/overlap.py +1 -1
- doc_page_extractor/struct_eqtable/__init__.py +49 -0
- doc_page_extractor/struct_eqtable/internvl/__init__.py +2 -0
- doc_page_extractor/struct_eqtable/internvl/conversation.py +394 -0
- doc_page_extractor/struct_eqtable/internvl/internvl.py +198 -0
- doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +81 -0
- doc_page_extractor/struct_eqtable/pix2s/__init__.py +3 -0
- doc_page_extractor/struct_eqtable/pix2s/pix2s.py +76 -0
- doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +1047 -0
- doc_page_extractor/table.py +71 -0
- doc_page_extractor/types.py +34 -3
- doc_page_extractor/utils.py +23 -1
- {doc_page_extractor-0.0.9.dist-info → doc_page_extractor-0.1.0.dist-info}/METADATA +6 -2
- {doc_page_extractor-0.0.9.dist-info → doc_page_extractor-0.1.0.dist-info}/RECORD +22 -11
- {doc_page_extractor-0.0.9.dist-info → doc_page_extractor-0.1.0.dist-info}/WHEEL +0 -0
- {doc_page_extractor-0.0.9.dist-info → doc_page_extractor-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {doc_page_extractor-0.0.9.dist-info → doc_page_extractor-0.1.0.dist-info}/top_level.txt +0 -0
doc_page_extractor/__init__.py
CHANGED
|
@@ -1,5 +1,15 @@
|
|
|
1
1
|
from .extractor import DocExtractor
|
|
2
2
|
from .clipper import clip, clip_from_image
|
|
3
3
|
from .plot import plot
|
|
4
|
-
from .
|
|
5
|
-
from .
|
|
4
|
+
from .rectangle import Point, Rectangle
|
|
5
|
+
from .types import (
|
|
6
|
+
ExtractedResult,
|
|
7
|
+
OCRFragment,
|
|
8
|
+
LayoutClass,
|
|
9
|
+
TableLayoutParsedFormat,
|
|
10
|
+
Layout,
|
|
11
|
+
BaseLayout,
|
|
12
|
+
PlainLayout,
|
|
13
|
+
FormulaLayout,
|
|
14
|
+
TableLayout,
|
|
15
|
+
)
|
doc_page_extractor/extractor.py
CHANGED
|
@@ -1,20 +1,30 @@
|
|
|
1
1
|
import os
|
|
2
2
|
|
|
3
|
-
from typing import Literal,
|
|
3
|
+
from typing import Literal, Generator
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
from PIL.Image import Image
|
|
6
|
-
from transformers import LayoutLMv3ForTokenClassification
|
|
7
6
|
from doclayout_yolo import YOLOv10
|
|
8
7
|
|
|
9
|
-
from .layoutreader import prepare_inputs, boxes2inputs, parse_logits
|
|
10
8
|
from .ocr import OCR
|
|
11
9
|
from .ocr_corrector import correct_fragments
|
|
12
10
|
from .raw_optimizer import RawOptimizer
|
|
13
11
|
from .rectangle import intersection_area, Rectangle
|
|
14
|
-
from .types import ExtractedResult, OCRFragment, LayoutClass, Layout
|
|
15
12
|
from .downloader import download
|
|
16
|
-
from .
|
|
17
|
-
from .
|
|
13
|
+
from .table import Table
|
|
14
|
+
from .latex import LaTeX
|
|
15
|
+
from .layout_order import LayoutOrder
|
|
16
|
+
from .overlap import merge_fragments_as_line, remove_overlap_layouts
|
|
17
|
+
from .clipper import clip_from_image
|
|
18
|
+
from .types import (
|
|
19
|
+
ExtractedResult,
|
|
20
|
+
OCRFragment,
|
|
21
|
+
TableLayoutParsedFormat,
|
|
22
|
+
Layout,
|
|
23
|
+
LayoutClass,
|
|
24
|
+
PlainLayout,
|
|
25
|
+
TableLayout,
|
|
26
|
+
FormulaLayout,
|
|
27
|
+
)
|
|
18
28
|
|
|
19
29
|
|
|
20
30
|
class DocExtractor:
|
|
@@ -23,15 +33,29 @@ class DocExtractor:
|
|
|
23
33
|
model_dir_path: str,
|
|
24
34
|
device: Literal["cpu", "cuda"] = "cpu",
|
|
25
35
|
ocr_for_each_layouts: bool = True,
|
|
26
|
-
|
|
36
|
+
extract_formula: bool = True,
|
|
37
|
+
extract_table_format: TableLayoutParsedFormat | None = None,
|
|
27
38
|
):
|
|
28
39
|
self._model_dir_path: str = model_dir_path
|
|
29
40
|
self._device: Literal["cpu", "cuda"] = device
|
|
30
41
|
self._ocr_for_each_layouts: bool = ocr_for_each_layouts
|
|
31
|
-
self.
|
|
32
|
-
self.
|
|
42
|
+
self._extract_formula: bool = extract_formula
|
|
43
|
+
self._extract_table_format: TableLayoutParsedFormat | None = extract_table_format
|
|
33
44
|
self._yolo: YOLOv10 | None = None
|
|
34
|
-
self.
|
|
45
|
+
self._ocr: OCR = OCR(
|
|
46
|
+
device=device,
|
|
47
|
+
model_dir_path=os.path.join(model_dir_path, "onnx_ocr"),
|
|
48
|
+
)
|
|
49
|
+
self._table: Table = Table(
|
|
50
|
+
device=device,
|
|
51
|
+
model_path=os.path.join(model_dir_path, "struct_eqtable"),
|
|
52
|
+
)
|
|
53
|
+
self._latex: LaTeX = LaTeX(
|
|
54
|
+
model_path=os.path.join(model_dir_path, "latex"),
|
|
55
|
+
)
|
|
56
|
+
self._layout_order: LayoutOrder = LayoutOrder(
|
|
57
|
+
model_path=os.path.join(model_dir_path, "layoutreader"),
|
|
58
|
+
)
|
|
35
59
|
|
|
36
60
|
def extract(
|
|
37
61
|
self,
|
|
@@ -42,26 +66,21 @@ class DocExtractor:
|
|
|
42
66
|
raw_optimizer = RawOptimizer(image, adjust_points)
|
|
43
67
|
fragments = list(self._ocr.search_fragments(raw_optimizer.image_np))
|
|
44
68
|
raw_optimizer.receive_raw_fragments(fragments)
|
|
45
|
-
|
|
46
|
-
layouts = self._get_layouts(raw_optimizer.image)
|
|
69
|
+
layouts = list(self._yolo_extract_layouts(raw_optimizer.image))
|
|
47
70
|
layouts = self._layouts_matched_by_fragments(fragments, layouts)
|
|
48
71
|
layouts = remove_overlap_layouts(layouts)
|
|
49
72
|
|
|
50
73
|
if self._ocr_for_each_layouts:
|
|
51
74
|
self._correct_fragments_by_ocr_layouts(raw_optimizer.image, layouts)
|
|
52
75
|
|
|
53
|
-
|
|
54
|
-
width, height = raw_optimizer.image.size
|
|
55
|
-
self._order_fragments_by_ai(width, height, layouts)
|
|
56
|
-
else:
|
|
57
|
-
self._order_fragments_by_y(layouts)
|
|
58
|
-
|
|
76
|
+
layouts = self._layout_order.sort(layouts, raw_optimizer.image.size)
|
|
59
77
|
layouts = [layout for layout in layouts if self._should_keep_layout(layout)]
|
|
78
|
+
|
|
79
|
+
self._parse_table_and_formula_layouts(layouts, raw_optimizer)
|
|
80
|
+
|
|
60
81
|
for layout in layouts:
|
|
61
|
-
layout.fragments =
|
|
62
|
-
layout.fragments.sort(key=lambda fragment: fragment.order)
|
|
82
|
+
layout.fragments = merge_fragments_as_line(layout.fragments)
|
|
63
83
|
|
|
64
|
-
layouts = self._sort_layouts(layouts)
|
|
65
84
|
raw_optimizer.receive_raw_layouts(layouts)
|
|
66
85
|
|
|
67
86
|
return ExtractedResult(
|
|
@@ -71,7 +90,7 @@ class DocExtractor:
|
|
|
71
90
|
adjusted_image=raw_optimizer.adjusted_image,
|
|
72
91
|
)
|
|
73
92
|
|
|
74
|
-
def
|
|
93
|
+
def _yolo_extract_layouts(self, source: Image) -> Generator[Layout, None, None]:
|
|
75
94
|
# about source parameter to see:
|
|
76
95
|
# https://github.com/opendatalab/DocLayout-YOLO/blob/7c4be36bc61f11b67cf4a44ee47f3c41e9800a91/doclayout_yolo/data/build.py#L157-L175
|
|
77
96
|
det_res = self._get_yolo().predict(
|
|
@@ -81,7 +100,6 @@ class DocExtractor:
|
|
|
81
100
|
device=self._device # Device to use (e.g., "cuda" or "cpu")
|
|
82
101
|
)
|
|
83
102
|
boxes = det_res[0].__dict__["boxes"]
|
|
84
|
-
layouts: list[Layout] = []
|
|
85
103
|
|
|
86
104
|
for cls_id, rect in zip(boxes.cls, boxes.xyxy):
|
|
87
105
|
cls_id = cls_id.item()
|
|
@@ -98,9 +116,12 @@ class DocExtractor:
|
|
|
98
116
|
lb=(x1, y2),
|
|
99
117
|
rb=(x2, y2),
|
|
100
118
|
)
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
119
|
+
if cls == LayoutClass.TABLE:
|
|
120
|
+
yield TableLayout(cls=cls, rect=rect, fragments=[], parsed=None)
|
|
121
|
+
elif cls == LayoutClass.ISOLATE_FORMULA:
|
|
122
|
+
yield FormulaLayout(cls=cls, rect=rect, fragments=[], latex=None)
|
|
123
|
+
else:
|
|
124
|
+
yield PlainLayout(cls=cls, rect=rect, fragments=[])
|
|
104
125
|
|
|
105
126
|
def _layouts_matched_by_fragments(self, fragments: list[OCRFragment], layouts: list[Layout]):
|
|
106
127
|
layouts_group = self._split_layouts_by_group(layouts)
|
|
@@ -116,6 +137,17 @@ class DocExtractor:
|
|
|
116
137
|
for layout in layouts:
|
|
117
138
|
correct_fragments(self._ocr, source, layout)
|
|
118
139
|
|
|
140
|
+
def _parse_table_and_formula_layouts(self, layouts: list[Layout], raw_optimizer: RawOptimizer):
    """Fill in the parsed content of formula and table layouts, in place.

    Formula layouts receive ``latex`` when formula extraction is enabled.
    Table layouts receive ``parsed`` as ``(content, format)`` when a table
    format was requested and the table model returned something. All other
    layouts are left untouched.

    Args:
        layouts: Layouts to inspect; matching ones are mutated.
        raw_optimizer: Supplies the page image the layout rects refer to.
    """
    for layout in layouts:
        wants_formula = isinstance(layout, FormulaLayout) and self._extract_formula
        wants_table = (
            not wants_formula
            and isinstance(layout, TableLayout)
            and self._extract_table_format is not None
        )
        if not (wants_formula or wants_table):
            continue

        # Both recognisers work on the region cropped out of the page image.
        cropped = clip_from_image(raw_optimizer.image, layout.rect)

        if wants_formula:
            layout.latex = self._latex.extract(cropped)
            continue

        table_content = self._table.predict(cropped, self._extract_table_format)
        if table_content is not None:
            layout.parsed = (table_content, self._extract_table_format)
|
|
150
|
+
|
|
119
151
|
def _split_layouts_by_group(self, layouts: list[Layout]):
|
|
120
152
|
texts_layouts: list[Layout] = []
|
|
121
153
|
abandon_layouts: list[Layout] = []
|
|
@@ -158,67 +190,16 @@ class DocExtractor:
|
|
|
158
190
|
|
|
159
191
|
def _get_yolo(self) -> YOLOv10:
    """Return the layout-detection YOLO model, creating it on first use.

    The weights live in ``<model_dir_path>/yolo``; the directory is created
    if needed and the weights file is downloaded when it is not present.
    The constructed model is cached on the instance.
    """
    if self._yolo is not None:
        return self._yolo

    yolo_dir = os.path.join(self._model_dir_path, "yolo")
    os.makedirs(yolo_dir, exist_ok=True)

    weights_path = Path(os.path.join(yolo_dir, "doclayout_yolo_ft.pt"))
    if not weights_path.exists():
        download(
            "https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/resolve/main/models/Layout/YOLO/doclayout_yolo_ft.pt",
            weights_path,
        )

    self._yolo = YOLOv10(str(weights_path))
    return self._yolo
|
|
168
202
|
|
|
169
|
-
def _order_fragments_by_y(self, layouts: list[Layout]):
|
|
170
|
-
fragments = list(self._iter_fragments(layouts))
|
|
171
|
-
fragments.sort(key=lambda f: f.rect.lt[1] + f.rect.rt[1])
|
|
172
|
-
for i, fragment in enumerate(fragments):
|
|
173
|
-
fragment.order = i
|
|
174
|
-
|
|
175
|
-
def _order_fragments_by_ai(self, width: int, height: int, layouts: list[Layout]):
|
|
176
|
-
if width == 0 or height == 0:
|
|
177
|
-
return
|
|
178
|
-
|
|
179
|
-
layout_model = self._get_layout()
|
|
180
|
-
boxes: list[list[int]] = []
|
|
181
|
-
steps: float = 1000.0 # max value of layoutreader
|
|
182
|
-
x_rate: float = 1.0
|
|
183
|
-
y_rate: float = 1.0
|
|
184
|
-
x_offset: float = 0.0
|
|
185
|
-
y_offset: float = 0.0
|
|
186
|
-
if width > height:
|
|
187
|
-
y_rate = height / width
|
|
188
|
-
y_offset = (1.0 - y_rate) / 2.0
|
|
189
|
-
else:
|
|
190
|
-
x_rate = width / height
|
|
191
|
-
x_offset = (1.0 - x_rate) / 2.0
|
|
192
|
-
|
|
193
|
-
for left, top, right, bottom in self._collect_rate_boxes(
|
|
194
|
-
fragments=self._iter_fragments(layouts),
|
|
195
|
-
):
|
|
196
|
-
boxes.append([
|
|
197
|
-
round((left * x_rate + x_offset) * steps),
|
|
198
|
-
round((top * y_rate + y_offset) * steps),
|
|
199
|
-
round((right * x_rate + x_offset) * steps),
|
|
200
|
-
round((bottom * y_rate + y_offset) * steps),
|
|
201
|
-
])
|
|
202
|
-
inputs = boxes2inputs(boxes)
|
|
203
|
-
inputs = prepare_inputs(inputs, layout_model)
|
|
204
|
-
logits = layout_model(**inputs).logits.cpu().squeeze(0)
|
|
205
|
-
orders: list[int] = parse_logits(logits, len(boxes))
|
|
206
|
-
|
|
207
|
-
for order, fragment in zip(orders, self._iter_fragments(layouts)):
|
|
208
|
-
fragment.order = order
|
|
209
|
-
|
|
210
|
-
def _get_layout(self) -> LayoutLMv3ForTokenClassification:
|
|
211
|
-
if self._layout is None:
|
|
212
|
-
cache_dir = ensure_dir(
|
|
213
|
-
os.path.join(self._model_dir_path, "layoutreader"),
|
|
214
|
-
)
|
|
215
|
-
self._layout = LayoutLMv3ForTokenClassification.from_pretrained(
|
|
216
|
-
pretrained_model_name_or_path="hantian/layoutreader",
|
|
217
|
-
cache_dir=cache_dir,
|
|
218
|
-
local_files_only=os.path.exists(os.path.join(cache_dir, "models--hantian--layoutreader")),
|
|
219
|
-
)
|
|
220
|
-
return self._layout
|
|
221
|
-
|
|
222
203
|
def _should_keep_layout(self, layout: Layout) -> bool:
|
|
223
204
|
if len(layout.fragments) > 0:
|
|
224
205
|
return True
|
|
@@ -229,78 +210,3 @@ class DocExtractor:
|
|
|
229
210
|
cls == LayoutClass.ISOLATE_FORMULA
|
|
230
211
|
)
|
|
231
212
|
|
|
232
|
-
def _sort_layouts(self, layouts: list[Layout]) -> list[Layout]:
|
|
233
|
-
layouts.sort(key=lambda layout: layout.rect.lt[1] + layout.rect.rt[1])
|
|
234
|
-
|
|
235
|
-
sorted_layouts: list[tuple[int, Layout]] = []
|
|
236
|
-
empty_layouts: list[tuple[int, Layout]] = []
|
|
237
|
-
|
|
238
|
-
for i, layout in enumerate(layouts):
|
|
239
|
-
if len(layout.fragments) > 0:
|
|
240
|
-
sorted_layouts.append((i, layout))
|
|
241
|
-
else:
|
|
242
|
-
empty_layouts.append((i, layout))
|
|
243
|
-
|
|
244
|
-
# try to maintain the order of empty layouts and other layouts as much as possible
|
|
245
|
-
for i, layout in empty_layouts:
|
|
246
|
-
max_less_index: int = -1
|
|
247
|
-
max_less_layout: Layout | None = None
|
|
248
|
-
max_less_index_in_enumerated: int = -1
|
|
249
|
-
for j, (k, sorted_layout) in enumerate(sorted_layouts):
|
|
250
|
-
if k < i and k > max_less_index:
|
|
251
|
-
max_less_index = k
|
|
252
|
-
max_less_layout = sorted_layout
|
|
253
|
-
max_less_index_in_enumerated = j
|
|
254
|
-
|
|
255
|
-
if max_less_layout is None:
|
|
256
|
-
sorted_layouts.insert(0, (i, layout))
|
|
257
|
-
else:
|
|
258
|
-
sorted_layouts.insert(max_less_index_in_enumerated + 1, (i, layout))
|
|
259
|
-
|
|
260
|
-
return [layout for _, layout in sorted_layouts]
|
|
261
|
-
|
|
262
|
-
def _collect_rate_boxes(self, fragments: Iterable[OCRFragment]):
|
|
263
|
-
boxes = self._get_boxes(fragments)
|
|
264
|
-
left = float("inf")
|
|
265
|
-
top = float("inf")
|
|
266
|
-
right = float("-inf")
|
|
267
|
-
bottom = float("-inf")
|
|
268
|
-
|
|
269
|
-
for _left, _top, _right, _bottom in boxes:
|
|
270
|
-
left = min(left, _left)
|
|
271
|
-
top = min(top, _top)
|
|
272
|
-
right = max(right, _right)
|
|
273
|
-
bottom = max(bottom, _bottom)
|
|
274
|
-
|
|
275
|
-
width = right - left
|
|
276
|
-
height = bottom - top
|
|
277
|
-
|
|
278
|
-
if width == 0 or height == 0:
|
|
279
|
-
return
|
|
280
|
-
|
|
281
|
-
for _left, _top, _right, _bottom in boxes:
|
|
282
|
-
yield (
|
|
283
|
-
(_left - left) / width,
|
|
284
|
-
(_top - top) / height,
|
|
285
|
-
(_right - left) / width,
|
|
286
|
-
(_bottom - top) / height,
|
|
287
|
-
)
|
|
288
|
-
|
|
289
|
-
def _get_boxes(self, fragments: Iterable[OCRFragment]):
|
|
290
|
-
boxes: list[tuple[float, float, float, float]] = []
|
|
291
|
-
for fragment in fragments:
|
|
292
|
-
left: float = float("inf")
|
|
293
|
-
top: float = float("inf")
|
|
294
|
-
right: float = float("-inf")
|
|
295
|
-
bottom: float = float("-inf")
|
|
296
|
-
for x, y in fragment.rect:
|
|
297
|
-
left = min(left, x)
|
|
298
|
-
top = min(top, y)
|
|
299
|
-
right = max(right, x)
|
|
300
|
-
bottom = max(bottom, y)
|
|
301
|
-
boxes.append((left, top, right, bottom))
|
|
302
|
-
return boxes
|
|
303
|
-
|
|
304
|
-
def _iter_fragments(self, layouts: list[Layout]):
|
|
305
|
-
for layout in layouts:
|
|
306
|
-
yield from layout.fragments
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
import requests
|
|
4
|
+
|
|
5
|
+
from munch import Munch
|
|
6
|
+
from pix2tex.cli import LatexOCR
|
|
7
|
+
from PIL.Image import Image
|
|
8
|
+
from .utils import expand_image
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LaTeX:
    """Lazily loaded pix2tex ``LatexOCR`` wrapper: formula image in, LaTeX out.

    The model files are kept in ``model_path``. When the checkpoint is not
    there yet, the files are fetched from the LaTeX-OCR GitHub release the
    first time the model is needed.
    """

    _RELEASE_TAG = "v0.0.1"
    _CHECKPOINT_FILE = "weights.pth"
    # The checkpoint is deliberately LAST: its presence is what marks the
    # download as complete, so an interrupted download is retried next time.
    _MODEL_FILES = ("image_resizer.pth", _CHECKPOINT_FILE)

    def __init__(self, model_path: str):
        """Remember where the model files live; nothing is loaded yet."""
        self._model_path: str = model_path
        self._model: LatexOCR | None = None

    def extract(self, image: Image) -> str | None:
        """Return what the model recognises in ``image``.

        The image is first enlarged via ``expand_image(image, 0.1)``; adding
        a margin around the formula improves recognition accuracy.
        """
        image = expand_image(image, 0.1)
        model = self._get_model()
        with torch.no_grad():
            return model(image)

    def _get_model(self) -> LatexOCR:
        """Build the ``LatexOCR`` model on first use and cache it."""
        if self._model is None:
            # Absolute path: NOTE(review) pix2tex appears to resolve its paths
            # from inside its own package directory, which would break a
            # relative model_path -- confirm; abspath is harmless either way.
            checkpoint = os.path.abspath(
                os.path.join(self._model_path, self._CHECKPOINT_FILE)
            )
            # Test the checkpoint file rather than the directory: the
            # directory is created before downloading, so it can exist while
            # the files inside it are missing.
            if not os.path.exists(checkpoint):
                self._download_model()

            self._model = LatexOCR(Munch({
                "config": os.path.join("settings", "config.yaml"),
                "checkpoint": checkpoint,
                "no_cuda": True,
                "no_resize": False,
            }))
        return self._model

    # from https://github.com/lukas-blecher/LaTeX-OCR/blob/5c1ac929bd19a7ecf86d5fb8d94771c8969fcb80/pix2tex/model/checkpoints/get_latest_checkpoint.py#L37-L45
    def _download_model(self):
        """Download every model file that is not already in ``model_path``.

        Raises:
            Whatever the HTTP request or the file system raises; partially
            written files are removed before the exception propagates.
        """
        os.makedirs(self._model_path, exist_ok=True)
        for file_name in self._MODEL_FILES:
            file_path = os.path.join(self._model_path, file_name)
            if os.path.exists(file_path):
                continue
            url = (
                "https://github.com/lukas-blecher/LaTeX-OCR/releases/download/"
                f"{self._RELEASE_TAG}/{file_name}"
            )
            self._download_file(url, file_path)

    def _download_file(self, url: str, file_path: str) -> None:
        """Stream ``url`` to ``file_path`` so the final file is never partial."""
        partial_path = file_path + ".part"
        try:
            # The context manager releases the connection even on failure.
            with requests.get(url, stream=True, timeout=15) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as file:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:  # skip keep-alive chunks
                            file.write(chunk)
            # Publish under the real name only once the download is complete.
            os.replace(partial_path, file_path)
        except BaseException:
            # BaseException on purpose: also clean up on Ctrl-C.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
|
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from typing import Generator
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from transformers import LayoutLMv3ForTokenClassification
|
|
7
|
+
|
|
8
|
+
from .types import Layout, LayoutClass
|
|
9
|
+
from .layoutreader import prepare_inputs, boxes2inputs, parse_logits
|
|
10
|
+
from .utils import ensure_dir
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class _BBox:
    """One box handed to the layoutreader model."""

    # Index of the owning layout in the list given to ``LayoutOrder.sort``.
    layout_index: int
    # Real box: index into ``layout.fragments``.
    # Virtual box: number of the generated line inside its layout.
    fragment_index: int
    # True for a line synthesised from the layout rect, False for an OCR fragment.
    virtual: bool
    # Reading-order rank; stays 0 until the model output is applied.
    order: int
    # (x0, y0, x1, y1). Page coordinates when created, rescaled to the
    # 0..1000 layoutreader grid before inference.
    value: tuple[float, float, float, float]


class LayoutOrder:
    """Put layouts and their fragments into reading order using layoutreader."""

    def __init__(self, model_path: str):
        """Remember the model cache directory; the model is loaded lazily."""
        self._model_path: str = model_path
        self._model: LayoutLMv3ForTokenClassification | None = None

    def _get_model(self) -> LayoutLMv3ForTokenClassification:
        """Load ``hantian/layoutreader`` on first use and cache it."""
        if self._model is None:
            model_path = ensure_dir(self._model_path)
            self._model = LayoutLMv3ForTokenClassification.from_pretrained(
                pretrained_model_name_or_path="hantian/layoutreader",
                cache_dir=model_path,
                # Stay offline once the model is already in the cache directory.
                local_files_only=os.path.exists(os.path.join(model_path, "models--hantian--layoutreader")),
            )
        return self._model

    def sort(self, layouts: list[Layout], size: tuple[int, int]) -> list[Layout]:
        """Return ``layouts`` in reading order and renumber fragment orders.

        Args:
            layouts: Layouts of one page; their fragments are mutated.
            size: ``(width, height)`` of the page the rects refer to.

        Returns:
            A new, ordered list. The input list is returned unchanged when
            the page has a zero dimension or when there are too many boxes
            for the model.
        """
        width, height = size
        if width == 0 or height == 0:
            return layouts
        if not layouts:
            # Nothing to order: do not load or run the model for an empty page.
            return []

        bbox_list = self._order_and_get_bbox_list(
            layouts=layouts,
            width=width,
            height=height,
        )
        if bbox_list is None:
            return layouts

        return self._sort_layouts_and_fragments(layouts, bbox_list)

    def _order_and_get_bbox_list(
        self,
        layouts: list[Layout],
        width: int,
        height: int,
    ) -> list[_BBox] | None:
        """Collect one box per fragment / virtual line and rank them.

        Returns:
            The boxes in model order with ``order`` filled in, or ``None``
            when there are more than 200 boxes.
        """
        line_height = self._line_height(layouts)
        bbox_list: list[_BBox] = []

        for i, layout in enumerate(layouts):
            if layout.cls == LayoutClass.PLAIN_TEXT and \
               len(layout.fragments) > 0:
                for j, fragment in enumerate(layout.fragments):
                    bbox_list.append(_BBox(
                        layout_index=i,
                        fragment_index=j,
                        virtual=False,
                        order=0,
                        value=fragment.rect.wrapper,
                    ))
            else:
                # Everything else is represented by lines synthesised from
                # the layout rectangle.
                bbox_list.extend(
                    self._generate_virtual_lines(
                        layout=layout,
                        layout_index=i,
                        line_height=line_height,
                        width=width,
                        height=height,
                    ),
                )

        if len(bbox_list) > 200:
            # Same box limit as MinerU:
            # https://github.com/opendatalab/MinerU/blob/980f5c8cd70f22f8c0c9b7b40eaff6f4804e6524/magic_pdf/pdf_parse_union_core_v2.py#L522
            return None

        layoutreader_size = 1000.0
        x_scale = layoutreader_size / float(width)
        y_scale = layoutreader_size / float(height)

        for bbox in bbox_list:
            # Clamp to the page, then map onto the 0..1000 grid.
            x0, y0, x1, y1 = self._squeeze(bbox.value, width, height)
            x0 = round(x0 * x_scale)
            y0 = round(y0 * y_scale)
            x1 = round(x1 * x_scale)
            y1 = round(y1 * y_scale)
            bbox.value = (x0, y0, x1, y1)

        # Sorting is required: layoutreader cannot recover the right order
        # when it is fed boxes in arbitrary order.
        bbox_list.sort(key=lambda b: b.value)
        model = self._get_model()

        with torch.no_grad():
            inputs = boxes2inputs([list(bbox.value) for bbox in bbox_list])
            inputs = prepare_inputs(inputs, model)
            logits = model(**inputs).logits.cpu().squeeze(0)
            orders = parse_logits(logits, len(bbox_list))

        sorted_bbox_list = [bbox_list[i] for i in orders]
        for i, bbox in enumerate(sorted_bbox_list):
            bbox.order = i

        return sorted_bbox_list

    def _sort_layouts_and_fragments(self, layouts: list[Layout], bbox_list: list[_BBox]):
        """Order layouts by the median rank of their boxes.

        Fragment orders are copied from the real boxes, fragment lists that
        consist only of real boxes are sorted, and finally all fragments are
        renumbered consecutively across the ordered layouts.
        """
        boxes_by_layout: list[list[_BBox]] = [[] for _ in range(len(layouts))]
        for bbox in bbox_list:
            boxes_by_layout[bbox.layout_index].append(bbox)

        layouts_with_median_order: list[tuple[Layout, float]] = []
        for layout, layout_boxes in zip(layouts, boxes_by_layout):
            # Every layout owns at least one box (virtual lines guarantee
            # it), so the median is always defined.
            median_order = self._median([b.order for b in layout_boxes])
            layouts_with_median_order.append((layout, median_order))

            for bbox in layout_boxes:
                if not bbox.virtual:
                    layout.fragments[bbox.fragment_index].order = bbox.order
            if all(not bbox.virtual for bbox in layout_boxes):
                layout.fragments.sort(key=lambda f: f.order)

        layouts_with_median_order.sort(key=lambda x: x[1])
        layouts = [layout for layout, _ in layouts_with_median_order]
        next_fragment_order: int = 0

        for layout in layouts:
            for fragment in layout.fragments:
                fragment.order = next_fragment_order
                next_fragment_order += 1

        return layouts

    def _line_height(self, layouts: list[Layout]) -> float:
        """Average fragment height, or 10.0 when no usable height exists."""
        total: float = 0.0
        count: int = 0
        for layout in layouts:
            for fragment in layout.fragments:
                _, height = fragment.rect.size
                total += height
                count += 1
        if count == 0 or total <= 0:
            # A non-positive average would lead to a division by zero when
            # virtual lines are generated; use the default instead.
            return 10.0
        return total / float(count)

    def _generate_virtual_lines(
        self,
        layout: Layout,
        layout_index: int,
        line_height: float,
        width: int,
        height: int,
    ) -> Generator[_BBox, None, None]:
        """Yield synthetic line boxes covering ``layout.rect``.

        Port of MinerU's line-insertion heuristic:
        https://github.com/opendatalab/MinerU/blob/980f5c8cd70f22f8c0c9b7b40eaff6f4804e6524/magic_pdf/pdf_parse_union_core_v2.py#L451-L490

        Args:
            layout: Layout whose rect is split.
            layout_index: Stored on every yielded box.
            line_height: Typical text line height on the page.
            width: Page width.
            height: Page height.
        """
        x0, y0, x1, y1 = layout.rect.wrapper
        layout_height = y1 - y0
        layout_width = x1 - x0

        def whole_layout() -> _BBox:
            return _BBox(
                layout_index=layout_index,
                fragment_index=0,
                virtual=True,
                order=0,
                value=(x0, y0, x1, y1),
            )

        # Too short to split (or no usable line height): one box.
        if line_height <= 0 or layout_height <= line_height * 2:
            yield whole_layout()
            return

        lines = int(layout_height / line_height)

        # Only a tall layout of medium width (taller than a quarter page AND
        # between a quarter and half the page width) keeps the fine,
        # line-by-line split. This condition is the negation of that case.
        if layout_height <= height * 0.25 or \
           layout_width >= width * 0.5 or \
           layout_width <= width * 0.25:
            if layout_width > width * 0.4:
                lines = 3
            elif layout_width <= width * 0.25:
                # Long and thin (or degenerate zero width): do not split.
                if layout_width <= 0 or layout_height / layout_width > 1.2:
                    yield whole_layout()
                    return
                # Not long and thin: still split into two lines.
                lines = 2

        lines = max(1, lines)
        step = layout_height / lines
        current_y = y0

        for i in range(lines):
            yield _BBox(
                layout_index=layout_index,
                fragment_index=i,
                virtual=True,
                order=0,
                value=(x0, current_y, x1, current_y + step),
            )
            current_y += step

    def _median(self, numbers: list[int]) -> float:
        """Median of a non-empty list, as a float."""
        ordered = sorted(numbers)
        middle, is_odd = divmod(len(ordered), 2)
        if is_odd:
            return float(ordered[middle])
        # Even count: mean of the two central values.
        return float((ordered[middle - 1] + ordered[middle]) / 2)

    def _squeeze(
        self,
        bbox: tuple[float, float, float, float],
        width: int,
        height: int,
    ) -> tuple[float, float, float, float]:
        """Clamp an ``(x0, y0, x1, y1)`` tuple to the page rectangle."""
        x0, y0, x1, y1 = bbox
        x0 = self._squeeze_value(x0, width)
        x1 = self._squeeze_value(x1, width)
        y0 = self._squeeze_value(y0, height)
        y1 = self._squeeze_value(y1, height)
        return x0, y0, x1, y1

    def _squeeze_value(self, position: float, size: int) -> float:
        """Clamp ``position`` to the closed range ``[0, size]``."""
        if position < 0:
            position = 0.0
        if position > size:
            position = float(size)
        return position
|
doc_page_extractor/ocr.py
CHANGED
|
@@ -58,7 +58,6 @@ class OCR:
|
|
|
58
58
|
self._text_system: TextSystem | None = None
|
|
59
59
|
|
|
60
60
|
def search_fragments(self, image: np.ndarray) -> Generator[OCRFragment, None, None]:
|
|
61
|
-
index: int = 0
|
|
62
61
|
for box, res in self._ocr(image):
|
|
63
62
|
text, rank = res
|
|
64
63
|
if is_space_text(text):
|
|
@@ -74,12 +73,11 @@ class OCR:
|
|
|
74
73
|
continue
|
|
75
74
|
|
|
76
75
|
yield OCRFragment(
|
|
77
|
-
order=
|
|
76
|
+
order=0,
|
|
78
77
|
text=text,
|
|
79
78
|
rank=rank,
|
|
80
79
|
rect=rect,
|
|
81
80
|
)
|
|
82
|
-
index += 1
|
|
83
81
|
|
|
84
82
|
def _ocr(self, image: np.ndarray) -> Generator[tuple[list[list[float]], tuple[str, float]], None, None]:
|
|
85
83
|
text_system = self._get_text_system()
|
doc_page_extractor/overlap.py
CHANGED
|
@@ -60,7 +60,7 @@ class _OverlapMatrixContext:
|
|
|
60
60
|
rate >= _INCLUDES_MIN_RATE:
|
|
61
61
|
yield i
|
|
62
62
|
|
|
63
|
-
def
|
|
63
|
+
def merge_fragments_as_line(origin_fragments: list[OCRFragment]) -> list[OCRFragment]:
|
|
64
64
|
fragments: list[OCRFragment] = []
|
|
65
65
|
for group in _split_fragments_into_groups(origin_fragments):
|
|
66
66
|
if len(group) == 1:
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from .pix2s import Pix2Struct, Pix2StructTensorRT
|
|
2
|
+
from .internvl import InternVL, InternVL_LMDeploy
|
|
3
|
+
|
|
4
|
+
from transformers import AutoConfig
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Registry of the available table-recognition backends, keyed by the name
# that build_model() resolves and then looks up here.
__ALL_MODELS__ = {
    'Pix2Struct': Pix2Struct,
    'Pix2StructTensorRT': Pix2StructTensorRT,
    'InternVL': InternVL,
    'InternVL_LMDeploy': InternVL_LMDeploy,
}
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_model_name(model_path, cache_dir=None, local_files_only=None):
    """Return the model family ('Pix2Struct' or 'InternVL') of a checkpoint.

    The family is derived from the first entry of the checkpoint config's
    ``architectures`` list.

    Args:
        model_path: Checkpoint name or path handed to ``AutoConfig``.
        cache_dir: Optional cache directory for the config lookup. Forwarded
            only when given, so existing callers behave as before.
        local_files_only: Optional offline switch for the config lookup.
            Forwarded only when given.

    Raises:
        ValueError: If the config names no architecture, or an unsupported one.
    """
    config_kwargs = {
        # NOTE(review): trust_remote_code executes Python code shipped with
        # the checkpoint; only point this at checkpoints you trust.
        'trust_remote_code': True,
    }
    if cache_dir is not None:
        config_kwargs['cache_dir'] = cache_dir
    if local_files_only is not None:
        config_kwargs['local_files_only'] = local_files_only

    model_config = AutoConfig.from_pretrained(model_path, **config_kwargs)

    # ``architectures`` may be missing, None or empty; report that with the
    # same ValueError as an unknown architecture instead of crashing.
    architectures = getattr(model_config, 'architectures', None) or []
    architecture = architectures[0] if architectures else None

    if architecture is not None and 'Pix2Struct' in architecture:
        model_name = 'Pix2Struct'
    elif architecture is not None and 'InternVL' in architecture:
        model_name = 'InternVL'
    else:
        raise ValueError(f"Unsupported model type: {architecture}")

    return model_name


def build_model(
    model_ckpt='U4R/StructTable-InternVL2-1B',
    cache_dir=None,
    local_files_only=None,
    **kwargs,
):
    """Instantiate the table-recognition backend matching ``model_ckpt``.

    Args:
        model_ckpt: Checkpoint name or path.
        cache_dir: Cache directory, used for both the config lookup and the
            model constructor.
        local_files_only: Offline switch, used for both the config lookup
            and the model constructor.
        **kwargs: Forwarded unchanged to the model constructor. A truthy
            ``lmdeploy`` selects the LMDeploy variant of InternVL; a truthy
            ``tensorrt_path`` selects the TensorRT variant of Pix2Struct.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If the checkpoint's architecture is not supported.
    """
    # The config lookup must honour the same cache/offline settings as the
    # model itself; otherwise it reads the default cache and may go online
    # even though local_files_only was requested.
    model_name = get_model_name(
        model_ckpt,
        cache_dir=cache_dir,
        local_files_only=local_files_only,
    )
    if model_name == 'InternVL' and kwargs.get('lmdeploy', False):
        model_name = 'InternVL_LMDeploy'
    elif model_name == 'Pix2Struct' and kwargs.get('tensorrt_path', None):
        model_name = 'Pix2StructTensorRT'

    model = __ALL_MODELS__[model_name](
        model_ckpt,
        cache_dir=cache_dir,
        local_files_only=local_files_only,
        **kwargs
    )
    return model
|