doc-page-extractor 0.0.9__tar.gz → 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of doc-page-extractor might be problematic. Click here for more details.

Files changed (53) hide show
  1. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/PKG-INFO +6 -2
  2. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/README.md +3 -1
  3. doc_page_extractor-0.1.0/doc_page_extractor/__init__.py +15 -0
  4. doc_page_extractor-0.1.0/doc_page_extractor/extractor.py +212 -0
  5. doc_page_extractor-0.1.0/doc_page_extractor/latex.py +57 -0
  6. doc_page_extractor-0.1.0/doc_page_extractor/layout_order.py +240 -0
  7. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/ocr.py +1 -3
  8. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/overlap.py +1 -1
  9. doc_page_extractor-0.1.0/doc_page_extractor/struct_eqtable/__init__.py +49 -0
  10. doc_page_extractor-0.1.0/doc_page_extractor/struct_eqtable/internvl/__init__.py +2 -0
  11. doc_page_extractor-0.1.0/doc_page_extractor/struct_eqtable/internvl/conversation.py +394 -0
  12. doc_page_extractor-0.1.0/doc_page_extractor/struct_eqtable/internvl/internvl.py +198 -0
  13. doc_page_extractor-0.1.0/doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +81 -0
  14. doc_page_extractor-0.1.0/doc_page_extractor/struct_eqtable/pix2s/__init__.py +3 -0
  15. doc_page_extractor-0.1.0/doc_page_extractor/struct_eqtable/pix2s/pix2s.py +76 -0
  16. doc_page_extractor-0.1.0/doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +1047 -0
  17. doc_page_extractor-0.1.0/doc_page_extractor/table.py +71 -0
  18. doc_page_extractor-0.1.0/doc_page_extractor/types.py +67 -0
  19. doc_page_extractor-0.1.0/doc_page_extractor/utils.py +32 -0
  20. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor.egg-info/PKG-INFO +6 -2
  21. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor.egg-info/SOURCES.txt +11 -0
  22. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor.egg-info/requires.txt +3 -1
  23. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/setup.py +7 -2
  24. doc_page_extractor-0.0.9/doc_page_extractor/__init__.py +0 -5
  25. doc_page_extractor-0.0.9/doc_page_extractor/extractor.py +0 -306
  26. doc_page_extractor-0.0.9/doc_page_extractor/types.py +0 -36
  27. doc_page_extractor-0.0.9/doc_page_extractor/utils.py +0 -10
  28. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/LICENSE +0 -0
  29. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/clipper.py +0 -0
  30. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/downloader.py +0 -0
  31. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/layoutreader.py +0 -0
  32. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/ocr_corrector.py +0 -0
  33. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/onnxocr/__init__.py +0 -0
  34. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/onnxocr/cls_postprocess.py +0 -0
  35. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/onnxocr/db_postprocess.py +0 -0
  36. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/onnxocr/imaug.py +0 -0
  37. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/onnxocr/operators.py +0 -0
  38. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/onnxocr/predict_base.py +0 -0
  39. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/onnxocr/predict_cls.py +0 -0
  40. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/onnxocr/predict_det.py +0 -0
  41. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/onnxocr/predict_rec.py +0 -0
  42. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/onnxocr/predict_system.py +0 -0
  43. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/onnxocr/rec_postprocess.py +0 -0
  44. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/onnxocr/utils.py +0 -0
  45. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/plot.py +0 -0
  46. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/raw_optimizer.py +0 -0
  47. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/rectangle.py +0 -0
  48. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor/rotation.py +0 -0
  49. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor.egg-info/dependency_links.txt +0 -0
  50. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/doc_page_extractor.egg-info/top_level.txt +0 -0
  51. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/setup.cfg +0 -0
  52. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/tests/__init__.py +0 -0
  53. {doc_page_extractor-0.0.9 → doc_page_extractor-0.1.0}/tests/test_history_bus.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: doc-page-extractor
3
- Version: 0.0.9
3
+ Version: 0.1.0
4
4
  Summary: doc page extractor can identify text and format in images and return structured data.
5
5
  Home-page: https://github.com/Moskize91/doc-page-extractor
6
6
  Author: Tao Zeyu
@@ -12,8 +12,10 @@ Requires-Dist: pillow<11.0,>=10.3
12
12
  Requires-Dist: pyclipper<2.0,>=1.2.0
13
13
  Requires-Dist: numpy<2.0,>=1.24.0
14
14
  Requires-Dist: shapely<3.0,>=2.0.0
15
- Requires-Dist: transformers<5.0,>=4.48.0
15
+ Requires-Dist: transformers<=4.47,>=4.42.4
16
16
  Requires-Dist: doclayout_yolo>=0.0.3
17
+ Requires-Dist: pix2tex<=0.2.0,>=0.1.4
18
+ Requires-Dist: accelerate<2.0,>=1.6.0
17
19
  Dynamic: author
18
20
  Dynamic: author-email
19
21
  Dynamic: description
@@ -78,3 +80,5 @@ The code of `doc_page_extractor/onnxocr` in this repo comes from [OnnxOCR](https
78
80
  - [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
79
81
  - [OnnxOCR](https://github.com/jingsongliujing/OnnxOCR)
80
82
  - [layoutreader](https://github.com/ppaanngggg/layoutreader)
83
+ - [StructEqTable](https://github.com/Alpha-Innovator/StructEqTable-Deploy)
84
+ - [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)
@@ -52,4 +52,6 @@ The code of `doc_page_extractor/onnxocr` in this repo comes from [OnnxOCR](https
52
52
 
53
53
  - [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
54
54
  - [OnnxOCR](https://github.com/jingsongliujing/OnnxOCR)
55
- - [layoutreader](https://github.com/ppaanngggg/layoutreader)
55
+ - [layoutreader](https://github.com/ppaanngggg/layoutreader)
56
+ - [StructEqTable](https://github.com/Alpha-Innovator/StructEqTable-Deploy)
57
+ - [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)
@@ -0,0 +1,15 @@
1
+ from .extractor import DocExtractor
2
+ from .clipper import clip, clip_from_image
3
+ from .plot import plot
4
+ from .rectangle import Point, Rectangle
5
+ from .types import (
6
+ ExtractedResult,
7
+ OCRFragment,
8
+ LayoutClass,
9
+ TableLayoutParsedFormat,
10
+ Layout,
11
+ BaseLayout,
12
+ PlainLayout,
13
+ FormulaLayout,
14
+ TableLayout,
15
+ )
@@ -0,0 +1,212 @@
1
+ import os
2
+
3
+ from typing import Literal, Generator
4
+ from pathlib import Path
5
+ from PIL.Image import Image
6
+ from doclayout_yolo import YOLOv10
7
+
8
+ from .ocr import OCR
9
+ from .ocr_corrector import correct_fragments
10
+ from .raw_optimizer import RawOptimizer
11
+ from .rectangle import intersection_area, Rectangle
12
+ from .downloader import download
13
+ from .table import Table
14
+ from .latex import LaTeX
15
+ from .layout_order import LayoutOrder
16
+ from .overlap import merge_fragments_as_line, remove_overlap_layouts
17
+ from .clipper import clip_from_image
18
+ from .types import (
19
+ ExtractedResult,
20
+ OCRFragment,
21
+ TableLayoutParsedFormat,
22
+ Layout,
23
+ LayoutClass,
24
+ PlainLayout,
25
+ TableLayout,
26
+ FormulaLayout,
27
+ )
28
+
29
+
30
class DocExtractor:
    """Extract reading-ordered layouts (text, tables, formulas) from a page image.

    The pipeline combines OCR fragments with DocLayout-YOLO layout detection,
    orders the layouts with ``LayoutOrder`` and optionally parses formulas
    (LaTeX) and tables.

    Args:
        model_dir_path: Root directory under which each model keeps its own
            sub-directory (``onnx_ocr``, ``struct_eqtable``, ``latex``,
            ``layoutreader``, ``yolo``).
        device: Device passed to the OCR, table and YOLO models.
        ocr_for_each_layouts: When true, fragments are corrected by running
            ``correct_fragments`` on every detected layout.
        extract_formula: When true, isolated formulas are converted to LaTeX.
        extract_table_format: Target format for table parsing; ``None``
            disables table parsing.
    """

    def __init__(
        self,
        model_dir_path: str,
        device: Literal["cpu", "cuda"] = "cpu",
        ocr_for_each_layouts: bool = True,
        extract_formula: bool = True,
        extract_table_format: TableLayoutParsedFormat | None = None,
    ):
        self._model_dir_path: str = model_dir_path
        self._device: Literal["cpu", "cuda"] = device
        self._ocr_for_each_layouts: bool = ocr_for_each_layouts
        self._extract_formula: bool = extract_formula
        self._extract_table_format: TableLayoutParsedFormat | None = extract_table_format
        # The YOLO model is created on first use, see _get_yolo().
        self._yolo: YOLOv10 | None = None
        self._ocr: OCR = OCR(
            device=device,
            model_dir_path=os.path.join(model_dir_path, "onnx_ocr"),
        )
        self._table: Table = Table(
            device=device,
            model_path=os.path.join(model_dir_path, "struct_eqtable"),
        )
        self._latex: LaTeX = LaTeX(
            model_path=os.path.join(model_dir_path, "latex"),
        )
        self._layout_order: LayoutOrder = LayoutOrder(
            model_path=os.path.join(model_dir_path, "layoutreader"),
        )

    def extract(
        self,
        image: Image,
        adjust_points: bool = False,
    ) -> ExtractedResult:
        """Run the full extraction pipeline on one page image.

        Args:
            image: The page image to analyse.
            adjust_points: Forwarded to ``RawOptimizer``.

        Returns:
            An ``ExtractedResult`` holding the rotation, the ordered layouts,
            the input image and the optimizer's adjusted image.
        """
        raw_optimizer = RawOptimizer(image, adjust_points)
        fragments = list(self._ocr.search_fragments(raw_optimizer.image_np))
        raw_optimizer.receive_raw_fragments(fragments)

        layouts = list(self._yolo_extract_layouts(raw_optimizer.image))
        layouts = self._layouts_matched_by_fragments(fragments, layouts)
        layouts = remove_overlap_layouts(layouts)

        if self._ocr_for_each_layouts:
            self._correct_fragments_by_ocr_layouts(raw_optimizer.image, layouts)

        layouts = self._layout_order.sort(layouts, raw_optimizer.image.size)
        layouts = [layout for layout in layouts if self._should_keep_layout(layout)]

        self._parse_table_and_formula_layouts(layouts, raw_optimizer)

        for layout in layouts:
            layout.fragments = merge_fragments_as_line(layout.fragments)

        raw_optimizer.receive_raw_layouts(layouts)

        return ExtractedResult(
            rotation=raw_optimizer.rotation,
            layouts=layouts,
            extracted_image=image,
            adjusted_image=raw_optimizer.adjusted_image,
        )

    def _yolo_extract_layouts(self, source: Image) -> Generator[Layout, None, None]:
        """Detect layouts with YOLO and yield one typed layout per detected box."""
        # about source parameter to see:
        # https://github.com/opendatalab/DocLayout-YOLO/blob/7c4be36bc61f11b67cf4a44ee47f3c41e9800a91/doclayout_yolo/data/build.py#L157-L175
        det_res = self._get_yolo().predict(
            source=source,
            imgsz=1024,
            conf=0.2,
            device=self._device,  # Device to use (e.g., "cuda" or "cpu")
        )
        boxes = det_res[0].__dict__["boxes"]

        for cls_id, xyxy in zip(boxes.cls, boxes.xyxy):
            cls = LayoutClass(round(cls_id.item()))
            x1, y1, x2, y2 = (coordinate.item() for coordinate in xyxy)
            rect = Rectangle(
                lt=(x1, y1),
                rt=(x2, y1),
                lb=(x1, y2),
                rb=(x2, y2),
            )
            if cls == LayoutClass.TABLE:
                yield TableLayout(cls=cls, rect=rect, fragments=[], parsed=None)
            elif cls == LayoutClass.ISOLATE_FORMULA:
                yield FormulaLayout(cls=cls, rect=rect, fragments=[], latex=None)
            else:
                yield PlainLayout(cls=cls, rect=rect, fragments=[])

    def _layouts_matched_by_fragments(self, fragments: list[OCRFragment], layouts: list[Layout]):
        """Attach every fragment to the layout that contains it, then return ``layouts``.

        Text-like layouts are tried first and abandoned layouts second; a
        fragment that matches no layout is dropped.
        """
        layouts_group = self._split_layouts_by_group(layouts)
        for fragment in fragments:
            for sub_layouts in layouts_group:
                layout = self._find_matched_layout(fragment, sub_layouts)
                if layout is not None:
                    layout.fragments.append(fragment)
                    break
        return layouts

    def _correct_fragments_by_ocr_layouts(self, source: Image, layouts: list[Layout]):
        """Call ``correct_fragments`` on each layout using the shared OCR engine."""
        for layout in layouts:
            correct_fragments(self._ocr, source, layout)

    def _parse_table_and_formula_layouts(self, layouts: list[Layout], raw_optimizer: RawOptimizer):
        """Fill ``latex`` on formula layouts and ``parsed`` on table layouts, if enabled."""
        for layout in layouts:
            if isinstance(layout, FormulaLayout) and self._extract_formula:
                image = clip_from_image(raw_optimizer.image, layout.rect)
                layout.latex = self._latex.extract(image)
            elif isinstance(layout, TableLayout) and self._extract_table_format is not None:
                image = clip_from_image(raw_optimizer.image, layout.rect)
                parsed = self._table.predict(image, self._extract_table_format)
                if parsed is not None:
                    layout.parsed = (parsed, self._extract_table_format)

    def _split_layouts_by_group(self, layouts: list[Layout]):
        """Return ``(text_layouts, abandon_layouts)``; other classes are in neither group."""
        texts_layouts: list[Layout] = []
        abandon_layouts: list[Layout] = []

        for layout in layouts:
            cls = layout.cls
            if cls in (
                LayoutClass.TITLE,
                LayoutClass.PLAIN_TEXT,
                LayoutClass.FIGURE_CAPTION,
                LayoutClass.TABLE_CAPTION,
                LayoutClass.TABLE_FOOTNOTE,
                LayoutClass.FORMULA_CAPTION,
            ):
                texts_layouts.append(layout)
            elif cls == LayoutClass.ABANDON:
                abandon_layouts.append(layout)

        return texts_layouts, abandon_layouts

    def _find_matched_layout(self, fragment: OCRFragment, layouts: list[Layout]) -> Layout | None:
        """Return the smallest layout covering more than 85% of the fragment, or ``None``.

        A fragment with zero area never matches.
        """
        fragment_area = fragment.rect.area
        if fragment_area == 0.0:
            return None

        candidates: list[Layout] = [
            layout
            for layout in layouts
            if intersection_area(fragment.rect, layout.rect) / fragment_area > 0.85
        ]
        if not candidates:
            return None

        # Prefer the tightest enclosing layout; ties keep the first candidate.
        return min(candidates, key=lambda layout: layout.rect.area)

    def _get_yolo(self) -> YOLOv10:
        """Return the YOLO model, downloading its weights on first use if they are absent."""
        if self._yolo is None:
            base_path = os.path.join(self._model_dir_path, "yolo")
            os.makedirs(base_path, exist_ok=True)
            yolo_model_url = "https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/resolve/main/models/Layout/YOLO/doclayout_yolo_ft.pt"
            yolo_model_name = "doclayout_yolo_ft.pt"
            yolo_model_path = Path(os.path.join(base_path, yolo_model_name))
            if not yolo_model_path.exists():
                download(yolo_model_url, yolo_model_path)
            self._yolo = YOLOv10(str(yolo_model_path))
        return self._yolo

    def _should_keep_layout(self, layout: Layout) -> bool:
        """Keep layouts that have fragments, plus figures, tables and isolated formulas."""
        if layout.fragments:
            return True
        return layout.cls in (
            LayoutClass.FIGURE,
            LayoutClass.TABLE,
            LayoutClass.ISOLATE_FORMULA,
        )
212
+
@@ -0,0 +1,57 @@
1
+ import os
2
+ import torch
3
+ import requests
4
+
5
+ from munch import Munch
6
+ from pix2tex.cli import LatexOCR
7
+ from PIL.Image import Image
8
+ from .utils import expand_image
9
+
10
+
11
class LaTeX:
    """Lazily loaded wrapper around pix2tex's ``LatexOCR`` formula recogniser.

    Args:
        model_path: Directory that holds (or will receive) the model files
            listed in ``_MODEL_FILES``.
    """

    # Release of lukas-blecher/LaTeX-OCR the checkpoint files are fetched from.
    _RELEASE_TAG = "v0.0.1"
    # Files that must be present in the model directory before loading.
    _MODEL_FILES = ("weights.pth", "image_resizer.pth")

    def __init__(self, model_path: str):
        self._model_path: str = model_path
        self._model: LatexOCR | None = None

    def extract(self, image: Image) -> str | None:
        """Return the model's output for a formula image."""
        image = expand_image(image, 0.1)  # pad the edges to improve recognition accuracy
        model = self._get_model()
        with torch.no_grad():
            return model(image)

    def _missing_files(self) -> list[str]:
        """Return the names of required model files not present on disk."""
        return [
            file_name
            for file_name in self._MODEL_FILES
            if not os.path.exists(os.path.join(self._model_path, file_name))
        ]

    def _get_model(self) -> LatexOCR:
        """Return the cached model, downloading missing files and loading it on first use."""
        if self._model is None:
            # Check the files rather than the directory: a failed download
            # leaves the directory behind, which must not block a retry.
            if self._missing_files():
                self._download_model()

            self._model = LatexOCR(Munch({
                "config": os.path.join("settings", "config.yaml"),
                # The config path above is relative to pix2tex's own package
                # directory, so pix2tex presumably resolves relative paths
                # there; an absolute checkpoint path is correct either way.
                "checkpoint": os.path.abspath(os.path.join(self._model_path, "weights.pth")),
                "no_cuda": True,
                "no_resize": False,
            }))
        return self._model

    # from https://github.com/lukas-blecher/LaTeX-OCR/blob/5c1ac929bd19a7ecf86d5fb8d94771c8969fcb80/pix2tex/model/checkpoints/get_latest_checkpoint.py#L37-L45
    def _download_model(self):
        """Download every missing model file into the model directory.

        Each file is streamed to a ``.part`` file and renamed once complete,
        so an interrupted download never leaves a truncated file under the
        final name.

        Raises:
            requests.RequestException: If a request fails or returns an error status.
            OSError: If a file cannot be written.
        """
        os.makedirs(self._model_path, exist_ok=True)
        base_url = f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{self._RELEASE_TAG}"

        for file_name in self._MODEL_FILES:
            file_path = os.path.join(self._model_path, file_name)
            if os.path.exists(file_path):
                continue

            partial_path = file_path + ".part"
            try:
                with requests.get(f"{base_url}/{file_name}", stream=True, timeout=15) as response:
                    response.raise_for_status()
                    with open(partial_path, "wb") as file:
                        for chunk in response.iter_content(chunk_size=8192):
                            if chunk:  # skip keep-alive chunks
                                file.write(chunk)
                os.replace(partial_path, file_path)

            except BaseException:
                # Clean up on any interruption (including KeyboardInterrupt), then re-raise.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
                raise
@@ -0,0 +1,240 @@
1
+ import os
2
+ import torch
3
+
4
+ from typing import Generator
5
+ from dataclasses import dataclass
6
+ from transformers import LayoutLMv3ForTokenClassification
7
+
8
+ from .types import Layout, LayoutClass
9
+ from .layoutreader import prepare_inputs, boxes2inputs, parse_logits
10
+ from .utils import ensure_dir
11
+
12
+
13
@dataclass
class _BBox:
    """One box submitted to the layoutreader model."""

    # Index of the owning layout in the list given to LayoutOrder.sort().
    layout_index: int
    # Index of the fragment inside the layout, or of the virtual line.
    fragment_index: int
    # True for a synthetic line generated from a layout rectangle,
    # False for a real OCR fragment.
    virtual: bool
    # Position in the predicted reading order.
    order: int
    # (x0, y0, x1, y1)
    value: tuple[float, float, float, float]


class LayoutOrder:
    """Sort layouts and their fragments into reading order with layoutreader.

    Args:
        model_path: Cache directory for the ``hantian/layoutreader`` model.
    """

    def __init__(self, model_path: str):
        self._model_path: str = model_path
        self._model: LayoutLMv3ForTokenClassification | None = None

    def _get_model(self) -> LayoutLMv3ForTokenClassification:
        """Return the model, loading it into the cache directory on first use."""
        if self._model is None:
            model_path = ensure_dir(self._model_path)
            self._model = LayoutLMv3ForTokenClassification.from_pretrained(
                pretrained_model_name_or_path="hantian/layoutreader",
                cache_dir=model_path,
                # Stay offline once the model is already in the cache.
                local_files_only=os.path.exists(os.path.join(model_path, "models--hantian--layoutreader")),
            )
        return self._model

    def sort(self, layouts: list[Layout], size: tuple[int, int]) -> list[Layout]:
        """Return ``layouts`` in reading order.

        Args:
            layouts: Layouts detected on the page.
            size: ``(width, height)`` of the page image.

        Returns:
            The reordered layouts, or ``layouts`` unchanged when the page has
            a zero dimension or there are too many boxes for the model.
        """
        width, height = size
        if width == 0 or height == 0:
            return layouts

        bbox_list = self._order_and_get_bbox_list(
            layouts=layouts,
            width=width,
            height=height,
        )
        if bbox_list is None:
            return layouts

        return self._sort_layouts_and_fragments(layouts, bbox_list)

    def _order_and_get_bbox_list(
        self,
        layouts: list[Layout],
        width: int,
        height: int,
    ) -> list[_BBox] | None:
        """Collect boxes for all layouts and return them in predicted order.

        Returns ``None`` when more than 200 boxes were collected, and an empty
        list (without loading the model) when there is nothing to order.
        """
        line_height = self._line_height(layouts)
        bbox_list: list[_BBox] = []

        for i, layout in enumerate(layouts):
            if layout.cls == LayoutClass.PLAIN_TEXT and len(layout.fragments) > 0:
                bbox_list.extend(
                    _BBox(
                        layout_index=i,
                        fragment_index=j,
                        virtual=False,
                        order=0,
                        value=fragment.rect.wrapper,
                    )
                    for j, fragment in enumerate(layout.fragments)
                )
            else:
                bbox_list.extend(
                    self._generate_virtual_lines(
                        layout=layout,
                        layout_index=i,
                        line_height=line_height,
                        width=width,
                        height=height,
                    ),
                )

        if not bbox_list:
            return bbox_list

        if len(bbox_list) > 200:
            # https://github.com/opendatalab/MinerU/blob/980f5c8cd70f22f8c0c9b7b40eaff6f4804e6524/magic_pdf/pdf_parse_union_core_v2.py#L522
            return None

        # Boxes are rescaled to a 1000 x 1000 coordinate space for the model.
        layoutreader_size = 1000.0
        x_scale = layoutreader_size / float(width)
        y_scale = layoutreader_size / float(height)

        for bbox in bbox_list:
            x0, y0, x1, y1 = self._squeeze(bbox.value, width, height)
            bbox.value = (
                round(x0 * x_scale),
                round(y0 * y_scale),
                round(x1 * x_scale),
                round(y1 * y_scale),
            )

        # Sorting is required: layoutreader cannot recover the correct order
        # from boxes passed in arbitrary order.
        bbox_list.sort(key=lambda b: b.value)
        model = self._get_model()

        with torch.no_grad():
            inputs = boxes2inputs([list(bbox.value) for bbox in bbox_list])
            inputs = prepare_inputs(inputs, model)
            logits = model(**inputs).logits.cpu().squeeze(0)
            orders = parse_logits(logits, len(bbox_list))

        sorted_bbox_list = [bbox_list[i] for i in orders]
        for i, bbox in enumerate(sorted_bbox_list):
            bbox.order = i

        return sorted_bbox_list

    def _sort_layouts_and_fragments(self, layouts: list[Layout], bbox_list: list[_BBox]):
        """Reorder layouts by the median order of their boxes and renumber fragments."""
        boxes_by_layout: list[list[_BBox]] = [[] for _ in layouts]
        for bbox in bbox_list:
            boxes_by_layout[bbox.layout_index].append(bbox)

        layouts_with_median_order: list[tuple[Layout, float]] = []
        for layout, layout_boxes in zip(layouts, boxes_by_layout):
            # Every layout has at least one box (real or virtual), so the
            # median is always defined.
            median_order = self._median([b.order for b in layout_boxes])
            layouts_with_median_order.append((layout, median_order))

            for bbox in layout_boxes:
                if not bbox.virtual:
                    layout.fragments[bbox.fragment_index].order = bbox.order
            if all(not bbox.virtual for bbox in layout_boxes):
                layout.fragments.sort(key=lambda f: f.order)

        layouts_with_median_order.sort(key=lambda x: x[1])
        layouts = [layout for layout, _ in layouts_with_median_order]

        # Renumber fragments consecutively across the whole page.
        next_fragment_order: int = 0
        for layout in layouts:
            for fragment in layout.fragments:
                fragment.order = next_fragment_order
                next_fragment_order += 1

        return layouts

    def _line_height(self, layouts: list[Layout]) -> float:
        """Return the mean fragment height over all layouts, or 10.0 if there are no fragments."""
        heights = [
            fragment.rect.size[1]
            for layout in layouts
            for fragment in layout.fragments
        ]
        if not heights:
            return 10.0
        total: float = 0.0
        for fragment_height in heights:
            total += fragment_height
        return total / float(len(heights))

    def _generate_virtual_lines(
        self,
        layout: Layout,
        layout_index: int,
        line_height: float,
        width: int,
        height: int,
    ) -> Generator[_BBox, None, None]:
        """Yield synthetic line boxes that stand in for a layout.

        A degenerate layout (no usable line height, at most two lines tall,
        zero width, or narrow and slender) yields a single box covering the
        whole layout; otherwise the layout is cut into equal horizontal strips.

        Args:
            layout: Layout whose ``rect.wrapper`` gives ``(x0, y0, x1, y1)``.
            layout_index: Index stored on every yielded box.
            line_height: Average text line height on the page.
            width: Page width; expected to be positive.
            height: Page height.
        """
        # https://github.com/opendatalab/MinerU/blob/980f5c8cd70f22f8c0c9b7b40eaff6f4804e6524/magic_pdf/pdf_parse_union_core_v2.py#L451-L490
        x0, y0, x1, y1 = layout.rect.wrapper
        layout_height = y1 - y0
        layout_width = x1 - x0

        def whole_layout() -> _BBox:
            return _BBox(
                layout_index=layout_index,
                fragment_index=0,
                virtual=True,
                order=0,
                value=(x0, y0, x1, y1),
            )

        # A non-positive line height would divide by zero below.
        if line_height <= 0 or layout_height <= line_height * 2:
            yield whole_layout()
            return

        lines = int(layout_height / line_height)

        if layout_height <= height * 0.25 or width * 0.25 < layout_width:
            if layout_width > width * 0.4:
                lines = 3
            elif layout_width <= width * 0.25:
                # Slender (or zero-width) layouts are not split.
                if layout_width <= 0 or layout_height / layout_width > 1.2:
                    yield whole_layout()
                    return
                # Layouts that are not slender are still split into two lines.
                lines = 2

        lines = max(1, lines)
        strip_height = layout_height / lines
        current_y = y0

        for i in range(lines):
            yield _BBox(
                layout_index=layout_index,
                fragment_index=i,
                virtual=True,
                order=0,
                value=(x0, current_y, x1, current_y + strip_height),
            )
            current_y += strip_height

    def _median(self, numbers: list[int]) -> float:
        """Return the median of a non-empty list of numbers as a float."""
        sorted_numbers = sorted(numbers)
        n = len(sorted_numbers)

        if n % 2 == 1:
            # Odd count: take the middle element.
            return float(sorted_numbers[n // 2])
        # Even count: average the two middle elements.
        mid1 = sorted_numbers[n // 2 - 1]
        mid2 = sorted_numbers[n // 2]
        return float((mid1 + mid2) / 2)

    def _squeeze(
        self,
        bbox: tuple[float, float, float, float],
        width: int,
        height: int,
    ) -> tuple[float, float, float, float]:
        """Clamp an ``(x0, y0, x1, y1)`` box to the ``width`` x ``height`` page."""
        x0, y0, x1, y1 = bbox
        x0 = self._squeeze_value(x0, width)
        x1 = self._squeeze_value(x1, width)
        y0 = self._squeeze_value(y0, height)
        y1 = self._squeeze_value(y1, height)
        return x0, y0, x1, y1

    def _squeeze_value(self, position: float, size: int) -> float:
        """Clamp ``position`` into ``[0, size]``."""
        if position < 0:
            position = 0.0
        if position > size:
            position = float(size)
        return position
@@ -58,7 +58,6 @@ class OCR:
58
58
  self._text_system: TextSystem | None = None
59
59
 
60
60
  def search_fragments(self, image: np.ndarray) -> Generator[OCRFragment, None, None]:
61
- index: int = 0
62
61
  for box, res in self._ocr(image):
63
62
  text, rank = res
64
63
  if is_space_text(text):
@@ -74,12 +73,11 @@ class OCR:
74
73
  continue
75
74
 
76
75
  yield OCRFragment(
77
- order=index,
76
+ order=0,
78
77
  text=text,
79
78
  rank=rank,
80
79
  rect=rect,
81
80
  )
82
- index += 1
83
81
 
84
82
  def _ocr(self, image: np.ndarray) -> Generator[tuple[list[list[float]], tuple[str, float]], None, None]:
85
83
  text_system = self._get_text_system()
@@ -60,7 +60,7 @@ class _OverlapMatrixContext:
60
60
  rate >= _INCLUDES_MIN_RATE:
61
61
  yield i
62
62
 
63
- def regroup_lines(origin_fragments: list[OCRFragment]) -> list[OCRFragment]:
63
+ def merge_fragments_as_line(origin_fragments: list[OCRFragment]) -> list[OCRFragment]:
64
64
  fragments: list[OCRFragment] = []
65
65
  for group in _split_fragments_into_groups(origin_fragments):
66
66
  if len(group) == 1:
@@ -0,0 +1,49 @@
1
+ from .pix2s import Pix2Struct, Pix2StructTensorRT
2
+ from .internvl import InternVL, InternVL_LMDeploy
3
+
4
+ from transformers import AutoConfig
5
+
6
+
7
# Registry mapping a model name to the class that implements it; build_model()
# indexes this dict with the name it resolves for the checkpoint.
__ALL_MODELS__ = {
    'Pix2Struct': Pix2Struct,
    'Pix2StructTensorRT': Pix2StructTensorRT,
    'InternVL': InternVL,
    'InternVL_LMDeploy': InternVL_LMDeploy,
}
13
+
14
+
15
def get_model_name(model_path):
    """Return the model family name ('Pix2Struct' or 'InternVL') for a checkpoint.

    The family is derived from the first entry of the checkpoint config's
    ``architectures`` list.

    Raises:
        ValueError: If the architecture belongs to neither family.
    """
    # NOTE(review): trust_remote_code=True lets the checkpoint's repository run
    # its own code while the config loads; only use trusted checkpoints.
    config = AutoConfig.from_pretrained(
        model_path,
        trust_remote_code=True,
    )
    architecture = config.architectures[0]

    for family in ('Pix2Struct', 'InternVL'):
        if family in architecture:
            return family

    raise ValueError(f"Unsupported model type: {architecture}")
29
+
30
+
31
def build_model(
    model_ckpt='U4R/StructTable-InternVL2-1B',
    cache_dir=None,
    local_files_only=None,
    **kwargs,
):
    """Instantiate the model class matching a checkpoint.

    Args:
        model_ckpt: Checkpoint name or path.
        cache_dir: Forwarded to the model class.
        local_files_only: Forwarded to the model class.
        **kwargs: Forwarded unchanged to the model class. A truthy ``lmdeploy``
            selects the LMDeploy variant of InternVL, and a truthy
            ``tensorrt_path`` selects the TensorRT variant of Pix2Struct.

    Returns:
        The constructed model instance.
    """
    family = get_model_name(model_ckpt)

    if family == 'InternVL':
        variant = 'InternVL_LMDeploy' if kwargs.get('lmdeploy', False) else family
    elif family == 'Pix2Struct':
        variant = 'Pix2StructTensorRT' if kwargs.get('tensorrt_path', None) else family
    else:
        variant = family

    model_class = __ALL_MODELS__[variant]
    return model_class(
        model_ckpt,
        cache_dir=cache_dir,
        local_files_only=local_files_only,
        **kwargs
    )
@@ -0,0 +1,2 @@
1
+ from .internvl import InternVL
2
+ from .internvl_lmdeploy import InternVL_LMDeploy