magic-pdf 0.5.12__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/cli/magicpdf.py +23 -8
- magic_pdf/libs/config_reader.py +10 -0
- magic_pdf/libs/language.py +3 -3
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/__init__.py +1 -0
- magic_pdf/model/doc_analyze_by_custom_model.py +38 -15
- magic_pdf/model/model_list.py +1 -0
- magic_pdf/model/pdf_extract_kit.py +196 -0
- magic_pdf/model/pek_sub_modules/__init__.py +0 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py +0 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py +179 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py +671 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py +476 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py +7 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py +2 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py +171 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py +124 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py +136 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py +284 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py +213 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py +7 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +24 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +60 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +1282 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +32 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +34 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py +150 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py +163 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py +1236 -0
- magic_pdf/model/pek_sub_modules/post_process.py +36 -0
- magic_pdf/model/pek_sub_modules/self_modify.py +260 -0
- magic_pdf/model/pp_structure_v2.py +7 -0
- magic_pdf/pipe/AbsPipe.py +8 -14
- magic_pdf/pipe/OCRPipe.py +12 -8
- magic_pdf/pipe/TXTPipe.py +12 -8
- magic_pdf/pipe/UNIPipe.py +9 -7
- magic_pdf/resources/model_config/UniMERNet/demo.yaml +46 -0
- magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +351 -0
- magic_pdf/resources/model_config/model_configs.yaml +9 -0
- {magic_pdf-0.5.12.dist-info → magic_pdf-0.6.0.dist-info}/METADATA +68 -34
- {magic_pdf-0.5.12.dist-info → magic_pdf-0.6.0.dist-info}/RECORD +45 -19
- magic_pdf/model/360_layout_analysis.py +0 -8
- {magic_pdf-0.5.12.dist-info → magic_pdf-0.6.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.5.12.dist-info → magic_pdf-0.6.0.dist-info}/WHEEL +0 -0
- {magic_pdf-0.5.12.dist-info → magic_pdf-0.6.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.5.12.dist-info → magic_pdf-0.6.0.dist-info}/top_level.txt +0 -0
magic_pdf/cli/magicpdf.py
CHANGED
@@ -28,18 +28,20 @@ from loguru import logger
|
|
28
28
|
from pathlib import Path
|
29
29
|
from magic_pdf.libs.version import __version__
|
30
30
|
|
31
|
-
from magic_pdf.libs.MakeContentConfig import DropMode
|
31
|
+
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
|
32
32
|
from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_span_bbox
|
33
33
|
from magic_pdf.pipe.UNIPipe import UNIPipe
|
34
34
|
from magic_pdf.pipe.OCRPipe import OCRPipe
|
35
35
|
from magic_pdf.pipe.TXTPipe import TXTPipe
|
36
|
-
from magic_pdf.libs.config_reader import get_s3_config
|
37
36
|
from magic_pdf.libs.path_utils import (
|
38
37
|
parse_s3path,
|
39
38
|
parse_s3_range_params,
|
40
39
|
remove_non_official_s3_args,
|
41
40
|
)
|
42
|
-
from magic_pdf.libs.config_reader import
|
41
|
+
from magic_pdf.libs.config_reader import (
|
42
|
+
get_local_dir,
|
43
|
+
get_s3_config,
|
44
|
+
)
|
43
45
|
from magic_pdf.rw.S3ReaderWriter import S3ReaderWriter
|
44
46
|
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
|
45
47
|
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
|
@@ -81,10 +83,12 @@ def do_parse(
|
|
81
83
|
f_dump_model_json=True,
|
82
84
|
f_dump_orig_pdf=True,
|
83
85
|
f_dump_content_list=True,
|
86
|
+
f_make_md_mode=MakeMode.MM_MD,
|
84
87
|
):
|
85
88
|
orig_model_list = copy.deepcopy(model_list)
|
86
89
|
|
87
90
|
local_image_dir, local_md_dir = prepare_env(pdf_file_name, parse_method)
|
91
|
+
logger.info(f"local output dir is {local_md_dir}")
|
88
92
|
image_writer, md_writer = DiskReaderWriter(local_image_dir), DiskReaderWriter(local_md_dir)
|
89
93
|
image_dir = str(os.path.basename(local_image_dir))
|
90
94
|
|
@@ -105,6 +109,7 @@ def do_parse(
|
|
105
109
|
if len(model_list) == 0:
|
106
110
|
if model_config.__use_inside_model__:
|
107
111
|
pipe.pipe_analyze()
|
112
|
+
orig_model_list = copy.deepcopy(pipe.model_list)
|
108
113
|
else:
|
109
114
|
logger.error("need model list input")
|
110
115
|
exit(1)
|
@@ -116,7 +121,7 @@ def do_parse(
|
|
116
121
|
if f_draw_span_bbox:
|
117
122
|
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir)
|
118
123
|
|
119
|
-
md_content = pipe.pipe_mk_markdown(image_dir, drop_mode=DropMode.NONE)
|
124
|
+
md_content = pipe.pipe_mk_markdown(image_dir, drop_mode=DropMode.NONE, md_make_mode=f_make_md_mode)
|
120
125
|
if f_dump_md:
|
121
126
|
"""写markdown"""
|
122
127
|
md_writer.write(
|
@@ -175,8 +180,10 @@ def cli():
|
|
175
180
|
default="auto",
|
176
181
|
)
|
177
182
|
@click.option("--inside_model", type=click.BOOL, default=False, help="使用内置模型测试")
|
178
|
-
|
183
|
+
@click.option("--model_mode", type=click.STRING, default="full", help="内置模型选择。lite: 快速解析,精度较低,full: 高精度解析,速度较慢")
|
184
|
+
def json_command(json, method, inside_model, model_mode):
|
179
185
|
model_config.__use_inside_model__ = inside_model
|
186
|
+
model_config.__model_mode__ = model_mode
|
180
187
|
|
181
188
|
if not json.startswith("s3://"):
|
182
189
|
logger.error("usage: magic-pdf json-command --json s3://some_bucket/some_path")
|
@@ -226,8 +233,10 @@ def json_command(json, method, inside_model):
|
|
226
233
|
default="auto",
|
227
234
|
)
|
228
235
|
@click.option("--inside_model", type=click.BOOL, default=False, help="使用内置模型测试")
|
229
|
-
|
236
|
+
@click.option("--model_mode", type=click.STRING, default="full", help="内置模型选择。lite: 快速解析,精度较低,full: 高精度解析,速度较慢")
|
237
|
+
def local_json_command(local_json, method, inside_model, model_mode):
|
230
238
|
model_config.__use_inside_model__ = inside_model
|
239
|
+
model_config.__model_mode__ = model_mode
|
231
240
|
|
232
241
|
def read_s3_path(s3path):
|
233
242
|
bucket, key = parse_s3path(s3path)
|
@@ -278,8 +287,10 @@ def local_json_command(local_json, method, inside_model):
|
|
278
287
|
default="auto",
|
279
288
|
)
|
280
289
|
@click.option("--inside_model", type=click.BOOL, default=False, help="使用内置模型测试")
|
281
|
-
|
290
|
+
@click.option("--model_mode", type=click.STRING, default="full", help="内置模型选择。lite: 快速解析,精度较低,full: 高精度解析,速度较慢")
|
291
|
+
def pdf_command(pdf, model, method, inside_model, model_mode):
|
282
292
|
model_config.__use_inside_model__ = inside_model
|
293
|
+
model_config.__model_mode__ = model_mode
|
283
294
|
|
284
295
|
def read_fn(path):
|
285
296
|
disk_rw = DiskReaderWriter(os.path.dirname(path))
|
@@ -290,7 +301,11 @@ def pdf_command(pdf, model, method, inside_model):
|
|
290
301
|
def get_model_json(model_path):
|
291
302
|
# 这里处理pdf和模型相关的逻辑
|
292
303
|
if model_path is None:
|
293
|
-
|
304
|
+
file_name_without_extension, extension = os.path.splitext(pdf)
|
305
|
+
if extension == ".pdf":
|
306
|
+
model_path = file_name_without_extension + ".json"
|
307
|
+
else:
|
308
|
+
raise Exception("pdf_path input error")
|
294
309
|
if not os.path.exists(model_path):
|
295
310
|
logger.warning(
|
296
311
|
f"not found json {model_path} existed"
|
magic_pdf/libs/config_reader.py
CHANGED
@@ -59,5 +59,15 @@ def get_local_dir():
|
|
59
59
|
return config.get("temp-output-dir", "/tmp")
|
60
60
|
|
61
61
|
|
62
|
+
def get_local_models_dir():
    """Return the configured model-weights directory.

    Reads the ``models-dir`` key from the config returned by ``read_config``;
    falls back to ``/tmp/models`` when the key is absent.
    """
    return read_config().get("models-dir", "/tmp/models")
|
65
|
+
|
66
|
+
|
67
|
+
def get_device():
    """Return the configured inference device.

    Reads the ``device-mode`` key from the config returned by ``read_config``;
    falls back to ``"cpu"`` when the key is absent.
    """
    return read_config().get("device-mode", "cpu")
|
70
|
+
|
71
|
+
|
62
72
|
if __name__ == "__main__":
|
63
73
|
ak, sk, endpoint = get_s3_config("llm-raw")
|
magic_pdf/libs/language.py
CHANGED
@@ -1,15 +1,15 @@
|
|
1
1
|
import unicodedata
|
2
|
-
from fast_langdetect import
|
2
|
+
from fast_langdetect import detect_language
|
3
3
|
|
4
4
|
|
5
5
|
def detect_lang(text: str) -> str:
|
6
6
|
if len(text) == 0:
|
7
7
|
return ""
|
8
8
|
try:
|
9
|
-
lang_upper =
|
9
|
+
lang_upper = detect_language(text)
|
10
10
|
except:
|
11
11
|
html_no_ctrl_chars = ''.join([l for l in text if unicodedata.category(l)[0] not in ['C', ]])
|
12
|
-
lang_upper =
|
12
|
+
lang_upper = detect_language(html_no_ctrl_chars)
|
13
13
|
try:
|
14
14
|
lang = lang_upper.lower()
|
15
15
|
except:
|
magic_pdf/libs/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.
|
1
|
+
__version__ = "0.6.0"
|
magic_pdf/model/doc_analyze_by_custom_model.py
CHANGED
@@ -1,6 +1,10 @@
|
|
1
|
+
import time
|
2
|
+
|
1
3
|
import fitz
|
2
4
|
import numpy as np
|
3
5
|
from loguru import logger
|
6
|
+
|
7
|
+
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
|
4
8
|
from magic_pdf.model.model_list import MODEL
|
5
9
|
import magic_pdf.model as model_config
|
6
10
|
|
@@ -21,10 +25,11 @@ def remove_duplicates_dicts(lst):
|
|
21
25
|
|
22
26
|
def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
|
23
27
|
try:
|
24
|
-
import cv2
|
25
28
|
from PIL import Image
|
26
29
|
except ImportError:
|
27
|
-
logger.error("
|
30
|
+
logger.error("Pillow not installed, please install by pip.")
|
31
|
+
exit(1)
|
32
|
+
|
28
33
|
images = []
|
29
34
|
with fitz.open("pdf", pdf_bytes) as doc:
|
30
35
|
for index in range(0, doc.page_count):
|
@@ -32,32 +37,49 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
|
|
32
37
|
mat = fitz.Matrix(dpi / 72, dpi / 72)
|
33
38
|
pm = page.get_pixmap(matrix=mat, alpha=False)
|
34
39
|
|
35
|
-
# if width or height >
|
36
|
-
|
37
|
-
|
40
|
+
# if width or height > 3000 pixels, don't enlarge the image
|
41
|
+
if pm.width > 3000 or pm.height > 3000:
|
42
|
+
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
|
38
43
|
|
39
|
-
img = Image.frombytes("RGB",
|
40
|
-
img =
|
44
|
+
img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
|
45
|
+
img = np.array(img)
|
41
46
|
img_dict = {"img": img, "width": pm.width, "height": pm.height}
|
42
47
|
images.append(img_dict)
|
43
48
|
return images
|
44
49
|
|
45
50
|
|
46
|
-
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False
|
51
|
+
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
|
52
|
+
model = None
|
53
|
+
|
54
|
+
if model_config.__model_mode__ == "lite":
|
55
|
+
model = MODEL.Paddle
|
56
|
+
elif model_config.__model_mode__ == "full":
|
57
|
+
model = MODEL.PEK
|
47
58
|
|
48
59
|
if model_config.__use_inside_model__:
|
49
|
-
|
60
|
+
model_init_start = time.time()
|
61
|
+
if model == MODEL.Paddle:
|
62
|
+
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
|
63
|
+
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
|
64
|
+
elif model == MODEL.PEK:
|
65
|
+
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
|
66
|
+
# 从配置文件读取model-dir和device
|
67
|
+
local_models_dir = get_local_models_dir()
|
68
|
+
device = get_device()
|
69
|
+
custom_model = CustomPEKModel(ocr=ocr, show_log=show_log, models_dir=local_models_dir, device=device)
|
70
|
+
else:
|
71
|
+
logger.error("Not allow model_name!")
|
72
|
+
exit(1)
|
73
|
+
model_init_cost = time.time() - model_init_start
|
74
|
+
logger.info(f"model init cost: {model_init_cost}")
|
50
75
|
else:
|
51
76
|
logger.error("use_inside_model is False, not allow to use inside model")
|
52
77
|
exit(1)
|
53
78
|
|
54
79
|
images = load_images_from_pdf(pdf_bytes)
|
55
|
-
|
56
|
-
if model == MODEL.Paddle:
|
57
|
-
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
|
58
|
-
else:
|
59
|
-
pass
|
80
|
+
|
60
81
|
model_json = []
|
82
|
+
doc_analyze_start = time.time()
|
61
83
|
for index, img_dict in enumerate(images):
|
62
84
|
img = img_dict["img"]
|
63
85
|
page_width = img_dict["width"]
|
@@ -65,7 +87,8 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, mod
|
|
65
87
|
result = custom_model(img)
|
66
88
|
page_info = {"page_no": index, "height": page_height, "width": page_width}
|
67
89
|
page_dict = {"layout_dets": result, "page_info": page_info}
|
68
|
-
|
69
90
|
model_json.append(page_dict)
|
91
|
+
doc_analyze_cost = time.time() - doc_analyze_start
|
92
|
+
logger.info(f"doc analyze cost: {doc_analyze_cost}")
|
70
93
|
|
71
94
|
return model_json
|
magic_pdf/model/pdf_extract_kit.py
CHANGED
@@ -0,0 +1,196 @@
|
|
1
|
+
import os
|
2
|
+
import cv2
|
3
|
+
import yaml
|
4
|
+
import time
|
5
|
+
import argparse
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
from loguru import logger
|
9
|
+
|
10
|
+
from paddleocr import draw_ocr
|
11
|
+
from PIL import Image
|
12
|
+
from torchvision import transforms
|
13
|
+
from torch.utils.data import Dataset, DataLoader
|
14
|
+
from ultralytics import YOLO
|
15
|
+
from unimernet.common.config import Config
|
16
|
+
import unimernet.tasks as tasks
|
17
|
+
from unimernet.processors import load_processor
|
18
|
+
|
19
|
+
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
|
20
|
+
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
|
21
|
+
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
|
22
|
+
|
23
|
+
|
24
|
+
def mfd_model_init(weight):
    """Load the formula-detection (MFD) model as a ``YOLO`` instance from ``weight``."""
    return YOLO(weight)
|
27
|
+
|
28
|
+
|
29
|
+
def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
    """Build the formula-recognition (MFR) model and its evaluation image processor.

    Args:
        weight_dir: directory containing ``pytorch_model.bin``; it is also used
            as the model name and as the tokenizer path.
        cfg_path: path of the UniMERNet YAML config.
        _device_: device the built model is moved to.

    Returns:
        tuple: ``(model, vis_processor)``.
    """
    cfg = Config(argparse.Namespace(cfg_path=cfg_path, options=None))

    # Redirect weights, model name and tokenizer to the local weight directory.
    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
    cfg.config.model.model_config.model_name = weight_dir
    cfg.config.model.tokenizer_config.path = weight_dir

    model = tasks.setup_task(cfg).build_model(cfg).to(_device_)
    eval_processor_cfg = cfg.config.datasets.formula_rec_eval.vis_processor.eval
    return model, load_processor('formula_image_eval', eval_processor_cfg)
|
40
|
+
|
41
|
+
|
42
|
+
def layout_model_init(weight, config_file, device):
    """Create the LayoutLMv3 layout predictor from ``weight`` and ``config_file`` on ``device``."""
    return Layoutlmv3_Predictor(weight, config_file, device)
|
45
|
+
|
46
|
+
|
47
|
+
class MathDataset(Dataset):
    """Dataset of formula crops fed to the formula-recognition model.

    Items may be file paths (opened with ``Image.open``) or already-loaded
    image objects, which are used as-is.
    """

    def __init__(self, image_paths, transform=None):
        """
        Args:
            image_paths: sequence of image file paths and/or image objects.
            transform: optional callable applied to each loaded image.
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return item ``idx``, loading it from disk if it is a path and applying the transform if set."""
        item = self.image_paths[idx]
        # Strings are treated as paths; anything else is assumed to already be an image.
        raw_image = Image.open(item) if isinstance(item, str) else item
        # No transform configured: hand back the loaded image unchanged.
        if not self.transform:
            return raw_image
        return self.transform(raw_image)
|
64
|
+
|
65
|
+
|
66
|
+
class CustomPEKModel:
    """PDF-Extract-Kit document-analysis pipeline.

    Combines a LayoutLMv3 layout predictor, an optional formula detector
    (``mfd_model_init``) plus formula recognizer (``mfr_model_init``), and an
    optional ``ModifiedPaddleOCR`` text recognizer. Calling an instance on a
    page image returns the list of detections for that page.
    """

    def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
        """Read ``model_configs.yaml`` and initialise the enabled models.

        Args:
            ocr: run OCR on text-like layout regions in ``__call__``.
            show_log: forwarded to ``ModifiedPaddleOCR``.
            **kwargs: optional overrides. ``apply_layout``, ``apply_formula`` and
                ``device`` default to the ``config`` section of
                ``model_configs.yaml``; ``models_dir`` defaults to
                ``<package root>/resources/models``.

        Raises:
            AssertionError: if layout detection is disabled.
        """
        # This file lives in <package root>/model/; resolve the package root from it.
        current_file_path = os.path.abspath(__file__)
        current_dir = os.path.dirname(current_file_path)
        root_dir = os.path.dirname(current_dir)
        model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
        config_path = os.path.join(model_config_dir, 'model_configs.yaml')
        # Packaged resource; read as UTF-8 independent of the platform locale.
        with open(config_path, "r", encoding="utf-8") as f:
            self.configs = yaml.load(f, Loader=yaml.FullLoader)

        # Which analysis stages are enabled.
        self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
        self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
        self.apply_ocr = ocr
        logger.info(
            "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
                self.apply_layout, self.apply_formula, self.apply_ocr
            )
        )
        assert self.apply_layout, "DocAnalysis must contain layout model."

        self.device = kwargs.get("device", self.configs["config"]["device"])
        logger.info("using device: {}".format(self.device))
        models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))

        if self.apply_formula:
            # Formula detection model.
            self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))

            # Formula recognition model and its image preprocessing.
            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
            mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
            self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
            self.mfr_transform = transforms.Compose([mfr_vis_processors, ])

        # Layout model (always required).
        self.layout_model = Layoutlmv3_Predictor(
            str(os.path.join(models_dir, self.configs['weights']['layout'])),
            str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
            device=self.device
        )

        if self.apply_ocr:
            self.ocr_model = ModifiedPaddleOCR(show_log=show_log)

        logger.info('DocAnalysis init done!')

    def __call__(self, image):
        """Analyse one page image.

        Args:
            image: page image as a numpy array (it is passed to
                ``Image.fromarray``); presumably RGB as produced by
                ``load_images_from_pdf`` — TODO confirm.

        Returns:
            list: the layout model's detections. Formula boxes
            (``category_id`` 13 + detector class, with ``latex``) are appended
            when formula analysis is enabled. OCR spans (``category_id`` 15,
            with ``text``) are appended when OCR is enabled.
        """
        # Layout detection.
        layout_start = time.time()
        layout_res = self.layout_model(image, ignore_catids=[])
        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f"layout detection cost: {layout_cost}")

        # The formula models only exist when apply_formula was enabled in __init__.
        if self.apply_formula:
            self._run_formula_analysis(image, layout_res)

        if self.apply_ocr:
            self._run_ocr(image, layout_res)

        return layout_res

    def _run_formula_analysis(self, image, layout_res):
        """Detect formulas, recognise their LaTeX and append them to ``layout_res`` in place."""
        latex_filling_list = []
        mf_image_list = []
        # Convert once; every formula crop is cut from the same page image.
        pil_img = Image.fromarray(image)

        # Formula detection.
        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                'category_id': 13 + int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 2),
                'latex': '',
            }
            layout_res.append(new_item)
            latex_filling_list.append(new_item)
            mf_image_list.append(get_croped_image(pil_img, [xmin, ymin, xmax, ymax]))

        # Formula recognition in batches; the LaTeX is written back into the
        # dicts appended above (same order as mf_image_list).
        mfr_start = time.time()
        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            output = self.mfr_model.generate({'image': mf_img})
            mfr_res.extend(output['pred_str'])
        for res, latex in zip(latex_filling_list, mfr_res):
            res['latex'] = latex_rm_whitespace(latex)
        mfr_cost = round(time.time() - mfr_start, 2)
        logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")

    def _run_ocr(self, image, layout_res):
        """OCR text-like layout regions and append the recognised spans to ``layout_res`` in place."""
        ocr_start = time.time()
        pil_img = Image.fromarray(image)

        # Formula boxes (categories 13/14) are handed to the OCR model as mfd_res.
        single_page_mfdetrec_res = []
        for res in layout_res:
            if int(res['category_id']) in [13, 14]:
                xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
                xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
                single_page_mfdetrec_res.append({
                    "bbox": [xmin, ymin, xmax, ymax],
                })

        # Layout categories whose regions are OCR'd.
        ocr_category_ids = (0, 1, 2, 4, 6, 7)
        # Iterate a snapshot: OCR spans are appended to layout_res inside the loop.
        for res in list(layout_res):
            if int(res['category_id']) not in ocr_category_ids:
                continue
            xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
            xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
            crop_box = (xmin, ymin, xmax, ymax)
            # Paste the region onto a white page-sized canvas so OCR boxes keep page coordinates.
            cropped_img = Image.new('RGB', pil_img.size, 'white')
            cropped_img.paste(pil_img.crop(crop_box), crop_box)
            cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
            ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
            if ocr_res:
                for box_ocr_res in ocr_res:
                    p1, p2, p3, p4 = box_ocr_res[0]
                    text, score = box_ocr_res[1]
                    layout_res.append({
                        'category_id': 15,
                        'poly': p1 + p2 + p3 + p4,
                        'score': round(score, 2),
                        'text': text,
                    })
        ocr_cost = round(time.time() - ocr_start, 2)
        logger.info(f"ocr cost: {ocr_cost}")
|
File without changes
|
File without changes
|
@@ -0,0 +1,179 @@
|
|
1
|
+
# --------------------------------------------------------------------------------
|
2
|
+
# VIT: Multi-Path Vision Transformer for Dense Prediction
|
3
|
+
# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
|
4
|
+
# All Rights Reserved.
|
5
|
+
# Written by Youngwan Lee
|
6
|
+
# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
|
7
|
+
# LICENSE file in the root directory of this source tree.
|
8
|
+
# --------------------------------------------------------------------------------
|
9
|
+
# References:
|
10
|
+
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
11
|
+
# CoaT: https://github.com/mlpc-ucsd/CoaT
|
12
|
+
# --------------------------------------------------------------------------------
|
13
|
+
|
14
|
+
|
15
|
+
import torch
|
16
|
+
|
17
|
+
from detectron2.layers import (
|
18
|
+
ShapeSpec,
|
19
|
+
)
|
20
|
+
from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
|
21
|
+
from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
|
22
|
+
|
23
|
+
from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
|
24
|
+
from .deit import deit_base_patch16, mae_base_patch16
|
25
|
+
from .layoutlmft.models.layoutlmv3 import LayoutLMv3Model
|
26
|
+
from transformers import AutoConfig
|
27
|
+
|
28
|
+
__all__ = [
|
29
|
+
"build_vit_fpn_backbone",
|
30
|
+
]
|
31
|
+
|
32
|
+
|
33
|
+
class VIT_Backbone(Backbone):
    """
    Implement VIT backbone.

    Wraps a BEiT / DiT / DeiT / MAE or LayoutLMv3 network as a detectron2
    ``Backbone`` whose selected layers are returned as named feature maps.
    """

    def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs,
                 config_path=None, image_only=False, cfg=None):
        """
        Args:
            name: backbone identifier, e.g. ``"dit_base_patch16"``, or any name
                containing ``"layoutlmv3"``.
            out_features: names of the layers exposed by ``forward``.
            drop_path: forwarded as ``drop_path_rate`` to BEiT/DiT/DeiT/MAE constructors.
            img_size: forwarded as ``img_size`` to BEiT/DiT/DeiT/MAE constructors.
            pos_type: ``"abs"``, ``"shared_rel"`` or ``"rel"``; selects the
                position-bias flag passed to BEiT/DiT constructors.
            model_kwargs: extra keyword arguments for BEiT/DiT/DeiT/MAE constructors.
            config_path: location passed to ``AutoConfig.from_pretrained`` for LayoutLMv3.
            image_only: forwarded to ``LayoutLMv3Model``.
            cfg: unused; kept for signature compatibility.

        Raises:
            ValueError: if ``name`` is not a supported backbone, or ``pos_type``
                is not supported for a BEiT/DiT backbone.
        """
        super().__init__()
        self._out_features = out_features
        # Stride/channel tables: "base" variants are 768-d, the others 1024-d.
        if 'base' in name:
            self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
            self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
        else:
            self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
            self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}

        # Constructors for the non-LayoutLMv3 backbones, keyed by name.
        model_funcs = {
            'beit_base_patch16': beit_base_patch16,
            'dit_base_patch16': dit_base_patch16,
            'deit_base_patch16': deit_base_patch16,
            'mae_base_patch16': mae_base_patch16,
            'dit_large_patch16': dit_large_patch16,
            'beit_large_patch16': beit_large_patch16,
        }
        # Position-bias flag passed to BEiT/DiT constructors for each pos_type.
        pos_type_kwargs = {
            "abs": {"use_abs_pos_emb": True},
            "shared_rel": {"use_shared_rel_pos_bias": True},
            "rel": {"use_rel_pos_bias": True},
        }
        unsupported_name = f"Unsupported VIT backbone name: {name!r}"

        if 'beit' in name or 'dit' in name:
            if name not in model_funcs:
                raise ValueError(unsupported_name)
            if pos_type not in pos_type_kwargs:
                raise ValueError(
                    f"Unsupported pos_type {pos_type!r}; expected one of {sorted(pos_type_kwargs)}"
                )
            self.backbone = model_funcs[name](img_size=img_size,
                                              out_features=out_features,
                                              drop_path_rate=drop_path,
                                              **pos_type_kwargs[pos_type],
                                              **model_kwargs)
        elif "layoutlmv3" in name:
            config = AutoConfig.from_pretrained(config_path)
            # disable relative bias as DiT
            config.has_spatial_attention_bias = False
            config.has_relative_attention_bias = False
            self.backbone = LayoutLMv3Model(config, detection=True,
                                            out_features=out_features, image_only=image_only)
        else:
            if name not in model_funcs:
                raise ValueError(unsupported_name)
            self.backbone = model_funcs[name](img_size=img_size,
                                              out_features=out_features,
                                              drop_path_rate=drop_path,
                                              **model_kwargs)
        self.name = name

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
                For LayoutLMv3 backbones, a mapping with optional ``input_ids``,
                ``bbox``, ``images`` and ``attention_mask`` entries.

        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        if "layoutlmv3" in self.name:
            return self.backbone.forward(
                input_ids=x["input_ids"] if "input_ids" in x else None,
                bbox=x["bbox"] if "bbox" in x else None,
                images=x["images"] if "images" in x else None,
                attention_mask=x["attention_mask"] if "attention_mask" in x else None,
                # output_hidden_states=True,
            )
        assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        return self.backbone.forward_features(x)

    def output_shape(self):
        """Return a ``ShapeSpec`` (channels, stride) for each name in ``self._out_features``."""
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }
|
123
|
+
|
124
|
+
|
125
|
+
def build_VIT_backbone(cfg):
    """
    Create a VIT instance from config.

    Args:
        cfg: a detectron2 CfgNode; reads ``MODEL.VIT.*``, ``MODEL.CONFIG_PATH``,
            ``MODEL.WEIGHTS`` and ``MODEL.IMAGE_ONLY``.

    Returns:
        A VIT backbone instance.
    """
    vit_cfg = cfg.MODEL.VIT
    name = vit_cfg.NAME
    out_features = vit_cfg.OUT_FEATURES
    drop_path = vit_cfg.DROP_PATH
    img_size = vit_cfg.IMG_SIZE
    pos_type = vit_cfg.POS_TYPE

    # NOTE(review): MODEL_KWARGS is eval()'d as Python source. This is only safe
    # while the config file is trusted; consider ast.literal_eval if the configs
    # only ever hold literal dicts.
    model_kwargs = eval(str(vit_cfg.MODEL_KWARGS).replace("`", ""))

    config_path = None
    if 'layoutlmv3' in name:
        config_path = cfg.MODEL.CONFIG_PATH
        if config_path == '':
            # Derive the config directory from the weights path: strip the file
            # name of pre-trained (pytorch_model.bin) or detection fine-tuned
            # (model_final.pth) checkpoints.
            config_path = (
                cfg.MODEL.WEIGHTS
                .replace('pytorch_model.bin', '')
                .replace('model_final.pth', '')
            )

    return VIT_Backbone(
        name, out_features, drop_path, img_size, pos_type, model_kwargs,
        config_path=config_path, image_only=cfg.MODEL.IMAGE_ONLY, cfg=cfg,
    )
|
155
|
+
|
156
|
+
|
157
|
+
@BACKBONE_REGISTRY.register()
def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Create a VIT w/ FPN backbone.

    Args:
        cfg: a detectron2 CfgNode; reads ``MODEL.FPN.*`` and everything
            ``build_VIT_backbone`` reads.
        input_shape: not used by this builder; part of the registered
            backbone-builder signature.

    Returns:
        backbone (Backbone): an ``FPN`` over the VIT backbone with a
        ``LastLevelMaxPool`` top block.
    """
    bottom_up = build_VIT_backbone(cfg)
    fpn_cfg = cfg.MODEL.FPN
    return FPN(
        bottom_up=bottom_up,
        in_features=fpn_cfg.IN_FEATURES,
        out_channels=fpn_cfg.OUT_CHANNELS,
        norm=fpn_cfg.NORM,
        top_block=LastLevelMaxPool(),
        fuse_type=fpn_cfg.FUSE_TYPE,
    )
|