doc-page-extractor 0.1.1__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff shows the changes between two package versions as publicly released to one of the supported registries; it reflects their contents exactly as published.
- doc_page_extractor/__init__.py +5 -14
- doc_page_extractor/check_env.py +40 -0
- doc_page_extractor/extractor.py +87 -212
- doc_page_extractor/model.py +97 -0
- doc_page_extractor/parser.py +51 -0
- doc_page_extractor/plot.py +52 -79
- doc_page_extractor/redacter.py +111 -0
- doc_page_extractor-1.0.2.dist-info/METADATA +120 -0
- doc_page_extractor-1.0.2.dist-info/RECORD +11 -0
- {doc_page_extractor-0.1.1.dist-info → doc_page_extractor-1.0.2.dist-info}/WHEEL +1 -2
- doc_page_extractor-1.0.2.dist-info/licenses/LICENSE +21 -0
- doc_page_extractor/clipper.py +0 -119
- doc_page_extractor/downloader.py +0 -16
- doc_page_extractor/latex.py +0 -57
- doc_page_extractor/layout_order.py +0 -240
- doc_page_extractor/layoutreader.py +0 -126
- doc_page_extractor/ocr.py +0 -175
- doc_page_extractor/ocr_corrector.py +0 -126
- doc_page_extractor/onnxocr/__init__.py +0 -1
- doc_page_extractor/onnxocr/cls_postprocess.py +0 -26
- doc_page_extractor/onnxocr/db_postprocess.py +0 -246
- doc_page_extractor/onnxocr/imaug.py +0 -32
- doc_page_extractor/onnxocr/operators.py +0 -187
- doc_page_extractor/onnxocr/predict_base.py +0 -52
- doc_page_extractor/onnxocr/predict_cls.py +0 -89
- doc_page_extractor/onnxocr/predict_det.py +0 -120
- doc_page_extractor/onnxocr/predict_rec.py +0 -321
- doc_page_extractor/onnxocr/predict_system.py +0 -97
- doc_page_extractor/onnxocr/rec_postprocess.py +0 -896
- doc_page_extractor/onnxocr/utils.py +0 -71
- doc_page_extractor/overlap.py +0 -167
- doc_page_extractor/raw_optimizer.py +0 -104
- doc_page_extractor/rectangle.py +0 -72
- doc_page_extractor/rotation.py +0 -158
- doc_page_extractor/struct_eqtable/__init__.py +0 -49
- doc_page_extractor/struct_eqtable/internvl/__init__.py +0 -2
- doc_page_extractor/struct_eqtable/internvl/conversation.py +0 -394
- doc_page_extractor/struct_eqtable/internvl/internvl.py +0 -198
- doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +0 -81
- doc_page_extractor/struct_eqtable/pix2s/__init__.py +0 -3
- doc_page_extractor/struct_eqtable/pix2s/pix2s.py +0 -76
- doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +0 -1047
- doc_page_extractor/table.py +0 -71
- doc_page_extractor/types.py +0 -67
- doc_page_extractor/utils.py +0 -32
- doc_page_extractor-0.1.1.dist-info/METADATA +0 -84
- doc_page_extractor-0.1.1.dist-info/RECORD +0 -44
- doc_page_extractor-0.1.1.dist-info/licenses/LICENSE +0 -661
- doc_page_extractor-0.1.1.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -0
- tests/test_history_bus.py +0 -55
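
Taken together, the 1.0.2 release replaces the DocLayout-YOLO + OCR pipeline of 0.1.1 with a single DeepSeek-OCR model. A minimal usage sketch of the new public API, inferred from the diffed sources below rather than from the package's own documentation, so treat the details as assumptions:

    from PIL import Image

    from doc_page_extractor import PageExtractor, plot

    extractor = PageExtractor(model_path="./models", local_only=False)
    extractor.download_models()  # optional: prefetch "deepseek-ai/DeepSeek-OCR" weights
    extractor.load_models()      # optional: load eagerly instead of on first use

    image = Image.open("page.png").convert("RGB")

    # extract() is a generator that yields one (image, layouts) pair per stage.
    for page, layouts in extractor.extract(image, size="gundam", stages=1):
        for layout in layouts:
            print(layout.ref, layout.det, layout.text)
        plot(page, layouts).save("annotated.png")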
doc_page_extractor/__init__.py
CHANGED
@@ -1,15 +1,6 @@
-from .extractor import …
-from .…
+from .extractor import Layout, PageExtractor
+from .model import DeepSeekOCRSize
 from .plot import plot
-
-
-…
-    OCRFragment,
-    LayoutClass,
-    TableLayoutParsedFormat,
-    Layout,
-    BaseLayout,
-    PlainLayout,
-    FormulaLayout,
-    TableLayout,
-)
+
+__version__ = "1.0.0"
+__all__ = ["DeepSeekOCRSize", "Layout", "PageExtractor", "plot"]
doc_page_extractor/check_env.py
ADDED
@@ -0,0 +1,40 @@
+import warnings
+
+import torch
+
+_env_checked = False
+
+
+def check_env() -> None:
+    global _env_checked
+    if _env_checked:
+        return
+    _env_checked = True
+
+    if torch.cuda.is_available():
+        return
+    warnings.warn(
+        """
+━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+CUDA is not available!
+━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+
+This package requires CUDA to run, but torch.cuda.is_available() returned False.
+
+Possible causes:
+1. You installed CPU-only PyTorch. Reinstall with CUDA support:
+   pip uninstall torch torchvision
+   pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
+
+2. Your NVIDIA GPU driver is outdated. Update it from:
+   https://www.nvidia.com/download/index.aspx
+
+3. You don't have a CUDA-compatible GPU.
+
+To verify your setup, run: nvidia-smi
+
+━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+""".strip(),
+        RuntimeWarning,
+        stacklevel=2,
+    )
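
check_env() warns once per process and never raises, so a missing GPU only surfaces as a RuntimeWarning. Callers who want a hard failure instead can escalate it with the standard warnings filter; a sketch, not part of the package:

    import warnings

    # Promote RuntimeWarning (the category check_env() uses) to an exception.
    warnings.filterwarnings("error", category=RuntimeWarning)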
doc_page_extractor/extractor.py
CHANGED
@@ -1,213 +1,88 @@
-import …
-
-from …
+import tempfile
+from dataclasses import dataclass
+from os import PathLike
 from pathlib import Path
-from …
-… (3 removed lines not shown)
-from .…
-from .…
-from .…
-from .…
-… (14 removed lines not shown)
-)
-… (61 removed lines not shown)
-            extracted_image=image,
-            adjusted_image=raw_optimizer.adjusted_image,
-        )
-
-    def _yolo_extract_layouts(self, source: Image) -> Generator[Layout, None, None]:
-        # about source parameter to see:
-        # https://github.com/opendatalab/DocLayout-YOLO/blob/7c4be36bc61f11b67cf4a44ee47f3c41e9800a91/doclayout_yolo/data/build.py#L157-L175
-        det_res = self._get_yolo().predict(
-            source=source,
-            imgsz=1024,
-            conf=0.2,
-            device=self._device  # Device to use (e.g., "cuda" or "cpu")
-        )
-        boxes = det_res[0].__dict__["boxes"]
-
-        for cls_id, rect in zip(boxes.cls, boxes.xyxy):
-            cls_id = cls_id.item()
-            cls = LayoutClass(round(cls_id))
-
-            x1, y1, x2, y2 = rect
-            x1 = x1.item()
-            y1 = y1.item()
-            x2 = x2.item()
-            y2 = y2.item()
-            rect = Rectangle(
-                lt=(x1, y1),
-                rt=(x2, y1),
-                lb=(x1, y2),
-                rb=(x2, y2),
-            )
-            if rect.is_valid:
-                if cls == LayoutClass.TABLE:
-                    yield TableLayout(cls=cls, rect=rect, fragments=[], parsed=None)
-                elif cls == LayoutClass.ISOLATE_FORMULA:
-                    yield FormulaLayout(cls=cls, rect=rect, fragments=[], latex=None)
-                else:
-                    yield PlainLayout(cls=cls, rect=rect, fragments=[])
-
-    def _layouts_matched_by_fragments(self, fragments: list[OCRFragment], layouts: list[Layout]):
-        layouts_group = self._split_layouts_by_group(layouts)
-        for fragment in fragments:
-            for sub_layouts in layouts_group:
-                layout = self._find_matched_layout(fragment, sub_layouts)
-                if layout is not None:
-                    layout.fragments.append(fragment)
-                    break
-        return layouts
-
-    def _correct_fragments_by_ocr_layouts(self, source: Image, layouts: list[Layout]):
-        for layout in layouts:
-            correct_fragments(self._ocr, source, layout)
-
-    def _parse_table_and_formula_layouts(self, layouts: list[Layout], raw_optimizer: RawOptimizer):
-        for layout in layouts:
-            if isinstance(layout, FormulaLayout) and self._extract_formula:
-                image = clip_from_image(raw_optimizer.image, layout.rect)
-                layout.latex = self._latex.extract(image)
-            elif isinstance(layout, TableLayout) and self._extract_table_format is not None:
-                image = clip_from_image(raw_optimizer.image, layout.rect)
-                parsed = self._table.predict(image, self._extract_table_format)
-                if parsed is not None:
-                    layout.parsed = (parsed, self._extract_table_format)
-
-    def _split_layouts_by_group(self, layouts: list[Layout]):
-        texts_layouts: list[Layout] = []
-        abandon_layouts: list[Layout] = []
-
-        for layout in layouts:
-            cls = layout.cls
-            if cls == LayoutClass.TITLE or \
-               cls == LayoutClass.PLAIN_TEXT or \
-               cls == LayoutClass.FIGURE_CAPTION or \
-               cls == LayoutClass.TABLE_CAPTION or \
-               cls == LayoutClass.TABLE_FOOTNOTE or \
-               cls == LayoutClass.FORMULA_CAPTION:
-                texts_layouts.append(layout)
-            elif cls == LayoutClass.ABANDON:
-                abandon_layouts.append(layout)
-
-        return texts_layouts, abandon_layouts
-
-    def _find_matched_layout(self, fragment: OCRFragment, layouts: list[Layout]) -> Layout | None:
-        fragment_area = fragment.rect.area
-        primary_layouts: list[(Layout, float)] = []
-
-        if fragment_area == 0.0:
-            return None
-
-        for layout in layouts:
-            area = intersection_area(fragment.rect, layout.rect)
-            if area / fragment_area > 0.85:
-                primary_layouts.append((layout, layout.rect.area))
-
-        min_area: float = float("inf")
-        min_layout: Layout | None = None
-
-        for layout, area in primary_layouts:
-            if area < min_area:
-                min_area = area
-                min_layout = layout
-
-        return min_layout
-
-    def _get_yolo(self) -> YOLOv10:
-        if self._yolo is None:
-            base_path = os.path.join(self._model_dir_path, "yolo")
-            os.makedirs(base_path, exist_ok=True)
-            yolo_model_url = "https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/resolve/main/models/Layout/YOLO/doclayout_yolo_ft.pt"
-            yolo_model_name = "doclayout_yolo_ft.pt"
-            yolo_model_path = Path(os.path.join(base_path, yolo_model_name))
-            if not yolo_model_path.exists():
-                download(yolo_model_url, yolo_model_path)
-            self._yolo = YOLOv10(str(yolo_model_path))
-        return self._yolo
-
-    def _should_keep_layout(self, layout: Layout) -> bool:
-        if len(layout.fragments) > 0:
-            return True
-        cls = layout.cls
-        return (
-            cls == LayoutClass.FIGURE or
-            cls == LayoutClass.TABLE or
-            cls == LayoutClass.ISOLATE_FORMULA
-        )
-
+from typing import Generator, cast
+
+from PIL import Image
+
+from .check_env import check_env
+from .model import DeepSeekOCRModel, DeepSeekOCRSize
+from .parser import ParsedItemKind, parse_ocr_response
+from .redacter import background_color, redact
+
+
+@dataclass
+class Layout:
+    ref: str
+    det: tuple[int, int, int, int]
+    text: str | None
+
+
+class PageExtractor:
+    def __init__(
+        self,
+        model_path: PathLike | None = None,
+        local_only: bool = False,
+    ) -> None:
+        self._model: DeepSeekOCRModel = DeepSeekOCRModel(
+            model_path=Path(model_path) if model_path else None,
+            local_only=local_only,
+        )
+
+    def download_models(self) -> None:
+        self._model.download()
+
+    def load_models(self) -> None:
+        self._model.load()
+
+    def extract(
+        self, image: Image.Image, size: DeepSeekOCRSize, stages: int = 1
+    ) -> Generator[tuple[Image.Image, list[Layout]], None, None]:
+        check_env()
+        assert stages >= 1, "stages must be at least 1"
+        with tempfile.TemporaryDirectory() as temp_path:
+            fill_color: tuple[int, int, int] | None = None
+            for i in range(stages):
+                response = self._model.generate(
+                    image=image,
+                    prompt="<image>\n<|grounding|>Convert the document to markdown.",
+                    temp_path=temp_path,
+                    size=size,
+                )
+                layouts: list[Layout] = []
+                for ref, det, text in self._parse_response(image, response):
+                    layouts.append(Layout(ref, det, text))
+                yield image, layouts
+                if i < stages - 1:
+                    if fill_color is None:
+                        fill_color = background_color(image)
+                    image = redact(
+                        image=image.copy(),
+                        fill_color=fill_color,
+                        rectangles=(layout.det for layout in layouts),
+                    )
+
+    def _parse_response(
+        self, image: Image.Image, response: str
+    ) -> Generator[tuple[str, tuple[int, int, int, int], str | None], None, None]:
+        width, height = image.size
+        det: tuple[int, int, int, int] | None = None
+        ref: str | None = None
+
+        for kind, content in parse_ocr_response(response, width, height):
+            if kind == ParsedItemKind.TEXT:
+                if det is not None and ref is not None:
+                    yield ref, det, cast(str, content)
+                    det = None
+                    ref = None
+                if det is not None and ref is not None:
+                    yield ref, det, None
+                    det = None
+                    ref = None
+            elif kind == ParsedItemKind.DET:
+                det = cast(tuple[int, int, int, int], content)
+            elif kind == ParsedItemKind.REF:
+                ref = cast(str, content)
+        if det is not None and ref is not None:
+            yield ref, det, None
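
The stages loop is the notable addition: after each pass, every detected box is painted over with the page's background color (via redacter.redact) before the model runs again, so later stages can pick up regions the first pass missed. A sketch of a two-stage run, reusing `extractor` and `image` from the sketch near the top of this diff:

    # Each pass yields its own layout list; the second pass sees a page
    # with the first pass's boxes redacted.
    passes = [layouts for _, layouts in extractor.extract(image, size="base", stages=2)]
    first_pass, second_pass = passes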
doc_page_extractor/model.py
ADDED
@@ -0,0 +1,97 @@
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Literal
+
+import torch
+from huggingface_hub import snapshot_download
+from PIL import Image
+from transformers import AutoModel, AutoTokenizer
+
+DeepSeekOCRSize = Literal["tiny", "small", "base", "large", "gundam"]
+
+
+@dataclass
+class _SizeConfig:
+    base_size: int
+    image_size: int
+    crop_mode: bool
+
+
+_SIZE_CONFIGS: dict[DeepSeekOCRSize, _SizeConfig] = {
+    "tiny": _SizeConfig(base_size=512, image_size=512, crop_mode=False),
+    "small": _SizeConfig(base_size=640, image_size=640, crop_mode=False),
+    "base": _SizeConfig(base_size=1024, image_size=1024, crop_mode=False),
+    "large": _SizeConfig(base_size=1280, image_size=1280, crop_mode=False),
+    "gundam": _SizeConfig(base_size=1024, image_size=640, crop_mode=True),
+}
+
+_ATTN_IMPLEMENTATION: str
+try:
+    import flash_attn  # type: ignore # pylint: disable=unused-import
+
+    _ATTN_IMPLEMENTATION = "flash_attention_2"
+except ImportError:
+    _ATTN_IMPLEMENTATION = "eager"
+
+_Models = tuple[Any, Any]
+
+
+class DeepSeekOCRModel:
+    def __init__(self, model_path: Path | None, local_only: bool) -> None:
+        self._model_name = "deepseek-ai/DeepSeek-OCR"
+        self._cache_dir = str(model_path) if model_path else None
+        self._local_only = local_only
+        self._models: _Models | None = None
+
+    def download(self) -> None:
+        snapshot_download(
+            repo_id=self._model_name,
+            repo_type="model",
+            cache_dir=self._cache_dir,
+        )
+
+    def load(self) -> None:
+        self._ensure_models()
+
+    def _ensure_models(self) -> _Models:
+        if self._models is None:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self._model_name,
+                trust_remote_code=True,
+                cache_dir=self._cache_dir,
+                local_files_only=self._local_only,
+            )
+            model = AutoModel.from_pretrained(
+                pretrained_model_name_or_path=self._model_name,
+                _attn_implementation=_ATTN_IMPLEMENTATION,
+                trust_remote_code=True,
+                use_safetensors=True,
+                cache_dir=self._cache_dir,
+                local_files_only=self._local_only,
+            )
+            model = model.cuda().to(torch.bfloat16)
+            self._models = (tokenizer, model)
+
+        return self._models
+
+    def generate(
+        self, image: Image.Image, prompt: str, temp_path: str, size: DeepSeekOCRSize
+    ) -> str:
+        tokenizer, model = self._ensure_models()
+        config = _SIZE_CONFIGS[size]
+        temp_image_path = os.path.join(temp_path, "temp_image.png")
+        image.save(temp_image_path)
+        text_result = model.infer(
+            tokenizer,
+            prompt=prompt,
+            image_file=temp_image_path,
+            output_path=temp_path,
+            base_size=config.base_size,
+            image_size=config.image_size,
+            crop_mode=config.crop_mode,
+            save_results=True,
+            test_compress=True,
+            eval_mode=True,
+        )
+        return text_result
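
model.infer(...) is DeepSeek-OCR's own remote-code entry point (hence trust_remote_code=True), not a standard transformers method, and the size presets simply trade resolution for speed, with "gundam" enabling crop mode for large pages. The wrapper can also be driven directly when raw model output is wanted without PageExtractor's parsing; a sketch under the same assumptions (CUDA available, weights reachable):

    import tempfile
    from pathlib import Path

    from PIL import Image

    from doc_page_extractor.model import DeepSeekOCRModel

    model = DeepSeekOCRModel(model_path=Path("./models"), local_only=False)
    image = Image.open("page.png").convert("RGB")

    with tempfile.TemporaryDirectory() as temp_path:
        raw = model.generate(
            image=image,
            prompt="<image>\n<|grounding|>Convert the document to markdown.",
            temp_path=temp_path,
            size="base",
        )
    # `raw` is markdown interleaved with <|ref|>…<|/ref|> and <|det|>…<|/det|> tags.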
doc_page_extractor/parser.py
ADDED
@@ -0,0 +1,51 @@
+import re
+from enum import Enum, auto
+from typing import Generator
+
+_TAG_PATTERN = re.compile(r"<\|(det|ref)\|>(.+?)<\|/\1\|>")
+_DET_COORDS_PATTERN = re.compile(r"\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]")
+
+
+class ParsedItemKind(Enum):
+    DET = auto()
+    REF = auto()
+    TEXT = auto()
+
+
+ParsedItem = (
+    tuple[ParsedItemKind.DET, tuple[int, int, int, int]]
+    | tuple[ParsedItemKind.REF, str]
+    | tuple[ParsedItemKind.TEXT, str]
+)
+
+
+def parse_ocr_response(
+    response: str, width: int, height: int
+) -> Generator[ParsedItem, None, None]:
+    last_end: int = 0
+    for matched in _TAG_PATTERN.finditer(response):
+        if matched.start() > last_end:
+            plain_text = response[last_end : matched.start()]
+            if plain_text:
+                yield ParsedItemKind.TEXT, plain_text
+        tag_type = matched.group(1)
+        content = matched.group(2)
+        if tag_type == "det":
+            coords_match = _DET_COORDS_PATTERN.search(content)
+            if coords_match:
+                x1_norm, y1_norm, x2_norm, y2_norm = [
+                    int(c) for c in coords_match.groups()
+                ]
+                x1 = round(x1_norm / 1000 * width)
+                y1 = round(y1_norm / 1000 * height)
+                x2 = round(x2_norm / 1000 * width)
+                y2 = round(y2_norm / 1000 * height)
+                yield ParsedItemKind.DET, (x1, y1, x2, y2)
+        elif tag_type == "ref":
+            yield ParsedItemKind.REF, content
+        last_end = matched.end()
+
+    if last_end < len(response):
+        plain_text = response[last_end:]
+        if plain_text:
+            yield ParsedItemKind.TEXT, plain_text
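
The parser assumes DeepSeek-OCR's grounding format: box coordinates arrive normalized to a 0-1000 grid and are rescaled to pixels with round(c / 1000 * dimension). A worked example on a hand-written response string (illustrative input, not real model output):

    from doc_page_extractor.parser import parse_ocr_response

    response = "<|ref|>title<|/ref|><|det|>[[100, 50, 900, 120]]<|/det|># Hello"
    for kind, content in parse_ocr_response(response, width=1000, height=2000):
        print(kind.name, content)
    # REF title
    # DET (100, 100, 900, 240)    x scaled by 1000/1000, y by 2000/1000
    # TEXT # Hello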
doc_page_extractor/plot.py
CHANGED
@@ -1,91 +1,64 @@
-from typing import Iterable
+from typing import Iterable, cast
+
 from PIL import ImageDraw
-from PIL.ImageFont import load_default, FreeTypeFont
 from PIL.Image import Image
-from .…
-
+from PIL.ImageFont import FreeTypeFont, load_default
+
+from .extractor import Layout
 
-_FRAGMENT_COLOR = (0x49, 0xCF, 0xCB)
+_FRAGMENT_COLOR = (0x49, 0xCF, 0xCB)  # Light Green
 _Color = tuple[int, int, int]
 
-def plot(image: Image, layouts: Iterable[Layout]) -> None:
-    layout_font = load_default(size=35)
-    fragment_font = load_default(size=25)
-    draw = ImageDraw.Draw(image, mode="RGBA")
 
-… (3 removed lines not shown)
-        text = str(object=number)
-        width = len(text) * font.size
-        offset = round(font.size * 0.15)
+def plot(image: Image, layouts: Iterable[Layout]) -> Image:
+    layout_font = cast(FreeTypeFont, load_default(size=35))
+    draw = ImageDraw.Draw(image, mode="RGBA")
 
-… (7 removed lines not shown)
+    def _draw_text(
+        position: tuple[int, int],
+        text: str,
+        font: FreeTypeFont,
+        bold: bool,
+        color: _Color,
+    ) -> None:
+        nonlocal draw
+        x, y = position
+        bbox = font.getbbox(text)
+        text_width = bbox[2] - bbox[0]
+        offset = round(font.size * 0.15)
 
-… (6 removed lines not shown)
+        for dx, dy in _generate_delta(bold):
+            draw.text(
+                xy=(x + dx - text_width - offset, y + dy),
+                text=text,
+                font=font,
+                fill=color,
+            )
 
-… (7 removed lines not shown)
-        _draw_number(
-            position=fragment.rect.lt,
-            number=fragment.order + 1,
-            font=fragment_font,
-            bold=False,
-            color=_FRAGMENT_COLOR,
-        )
+    for layout in layouts:
+        x1, y1, x2, y2 = layout.det
+        draw.polygon(
+            xy=[(x1, y1), (x2, y1), (x2, y2), (x1, y2)],
+            outline=_FRAGMENT_COLOR,
+            width=5,
+        )
 
-… (8 removed lines not shown)
+    for i, layout in enumerate(layouts):
+        x1, y1, _, _ = layout.det
+        _draw_text(
+            position=(x1, y1),
+            text=f"{i + 1}. {layout.ref.strip()}",
+            font=layout_font,
+            bold=True,
+            color=_FRAGMENT_COLOR,
+        )
+    return image
 
-def _generate_delta(bold: bool):
-    if bold:
-        for dx in range(-1, 2):
-            for dy in range(-1, 2):
-                yield dx, dy
-    else:
-        yield 0, 0
 
-def …
-… (6 removed lines not shown)
-        return (0xC0, 0xBB, 0xA9)  # Gray
-    elif cls == LayoutClass.FIGURE:
-        return (0x5B, 0x91, 0x3C)  # Dark Green
-    elif cls == LayoutClass.FIGURE_CAPTION:
-        return (0x77, 0xB3, 0x54)  # Green
-    elif cls == LayoutClass.TABLE:
-        return (0x44, 0x17, 0x52)  # Dark Purple
-    elif cls == LayoutClass.TABLE_CAPTION:
-        return (0x81, 0x75, 0xA0)  # Purple
-    elif cls == LayoutClass.TABLE_FOOTNOTE:
-        return (0xEF, 0xB6, 0xC9)  # Pink Purple
-    elif cls == LayoutClass.ISOLATE_FORMULA:
-        return (0xFA, 0x38, 0x27)  # Red
-    elif cls == LayoutClass.FORMULA_CAPTION:
-        return (0xFF, 0x9D, 0x24)  # Orange
+def _generate_delta(bold: bool):
+    if bold:
+        for dx in range(-1, 2):
+            for dy in range(-1, 2):
+                yield dx, dy
+    else:
+        yield 0, 0