magic-pdf 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries, and is provided for informational purposes only. A short stdlib sketch for reproducing this kind of wheel-to-wheel diff locally follows the file list below.
- magic_pdf/data/batch_build_dataset.py +156 -0
- magic_pdf/data/dataset.py +44 -24
- magic_pdf/data/utils.py +108 -9
- magic_pdf/dict2md/ocr_mkcontent.py +4 -3
- magic_pdf/libs/pdf_image_tools.py +11 -6
- magic_pdf/libs/performance_stats.py +12 -1
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/batch_analyze.py +175 -201
- magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
- magic_pdf/model/pdf_extract_kit.py +5 -38
- magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
- magic_pdf/model/sub_modules/model_init.py +50 -37
- magic_pdf/model/sub_modules/model_utils.py +17 -11
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
- magic_pdf/pdf_parse_union_core_v2.py +112 -74
- magic_pdf/post_proc/para_split_v3.py +16 -13
- magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
- magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
- magic_pdf/resources/model_config/model_configs.yaml +1 -1
- magic_pdf/tools/cli.py +30 -12
- magic_pdf/tools/common.py +90 -12
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +51 -41
- magic_pdf-1.3.0.dist-info/RECORD +202 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
- magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
- magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
- magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
- magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
- magic_pdf-1.2.1.dist-info/RECORD +0 -147
- /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
- /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
- /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
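To verify or extend the listing above, here is a minimal, stdlib-only sketch of reproducing a wheel-to-wheel diff like the one below. It assumes both wheel files have already been downloaded into the working directory; everything else uses only `zipfile` and `difflib`.

```python
# Stdlib-only sketch for reproducing a wheel-to-wheel diff like the one below.
# Assumes both wheels were already downloaded into the working directory.
import difflib
import zipfile

OLD = "magic_pdf-1.2.1-py3-none-any.whl"
NEW = "magic_pdf-1.3.0-py3-none-any.whl"

with zipfile.ZipFile(OLD) as old_whl, zipfile.ZipFile(NEW) as new_whl:
    old_names, new_names = set(old_whl.namelist()), set(new_whl.namelist())
    print("added:  ", len(new_names - old_names), "files")
    print("removed:", len(old_names - new_names), "files")

    # Unified diff for one file that exists in both wheels.
    path = "magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py"
    a = old_whl.read(path).decode("utf-8").splitlines(keepends=True)
    b = new_whl.read(path).decode("utf-8").splitlines(keepends=True)
    print("".join(difflib.unified_diff(a, b, fromfile="1.2.1/" + path, tofile="1.3.0/" + path)))
```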
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
@@ -1,14 +1,6 @@
-import argparse
-import os
-import re
-
 import torch
-import unimernet.tasks as tasks
-from PIL import Image
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
-from unimernet.common.config import Config
-from unimernet.processors import load_processor
+from tqdm import tqdm
 
 
 class MathDataset(Dataset):
@@ -20,55 +12,24 @@ class MathDataset(Dataset):
         return len(self.image_paths)
 
     def __getitem__(self, idx):
-
-        if isinstance(self.image_paths[idx], str):
-            raw_image = Image.open(self.image_paths[idx])
-        else:
-            raw_image = self.image_paths[idx]
+        raw_image = self.image_paths[idx]
         if self.transform:
             image = self.transform(raw_image)
             return image
 
 
-def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code."""
-    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
-    letter = "[a-zA-Z]"
-    noletter = "[\W_^\d]"
-    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
-    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
-    news = s
-    while True:
-        s = news
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
-        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
-        if news == s:
-            break
-    return s
-
-
 class UnimernetModel(object):
     def __init__(self, weight_dir, cfg_path, _device_="cpu"):
-        args = argparse.Namespace(cfg_path=cfg_path, options=None)
-        cfg = Config(args)
-        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
-        cfg.config.model.model_config.model_name = weight_dir
-        cfg.config.model.tokenizer_config.path = weight_dir
-        task = tasks.setup_task(cfg)
-        self.model = task.build_model(cfg)
+        from .unimernet_hf import UnimernetModel
+        if _device_.startswith("mps"):
+            self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
+        else:
+            self.model = UnimernetModel.from_pretrained(weight_dir)
         self.device = _device_
         self.model.to(_device_)
+        if not _device_.startswith("cpu"):
+            self.model = self.model.to(dtype=torch.float16)
         self.model.eval()
-        vis_processor = load_processor(
-            "formula_image_eval",
-            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
-        )
-        self.mfr_transform = transforms.Compose(
-            [
-                vis_processor,
-            ]
-        )
 
     def predict(self, mfd_res, image):
         formula_list = []
@@ -84,62 +45,22 @@ class UnimernetModel(object):
                 "latex": "",
             }
             formula_list.append(new_item)
-
-            bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+            bbox_img = image[ymin:ymax, xmin:xmax]
             mf_image_list.append(bbox_img)
 
-        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataset = MathDataset(mf_image_list, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         mfr_res = []
         for mf_img in dataloader:
+            mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+            mfr_res.extend(output["fixed_str"])
         for res, latex in zip(formula_list, mfr_res):
-            res["latex"] = latex_rm_whitespace(latex)
+            res["latex"] = latex
         return formula_list
 
-    # def batch_predict(
-    #     self, images_mfd_res: list, images: list, batch_size: int = 64
-    # ) -> list:
-    #     images_formula_list = []
-    #     mf_image_list = []
-    #     backfill_list = []
-    #     for image_index in range(len(images_mfd_res)):
-    #         mfd_res = images_mfd_res[image_index]
-    #         pil_img = Image.fromarray(images[image_index])
-    #         formula_list = []
-    #
-    #         for xyxy, conf, cla in zip(
-    #             mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
-    #         ):
-    #             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-    #             new_item = {
-    #                 "category_id": 13 + int(cla.item()),
-    #                 "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-    #                 "score": round(float(conf.item()), 2),
-    #                 "latex": "",
-    #             }
-    #             formula_list.append(new_item)
-    #             bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
-    #             mf_image_list.append(bbox_img)
-    #
-    #         images_formula_list.append(formula_list)
-    #         backfill_list += formula_list
-    #
-    #         dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-    #         dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
-    #         mfr_res = []
-    #         for mf_img in dataloader:
-    #             mf_img = mf_img.to(self.device)
-    #             with torch.no_grad():
-    #                 output = self.model.generate({"image": mf_img})
-    #             mfr_res.extend(output["pred_str"])
-    #         for res, latex in zip(backfill_list, mfr_res):
-    #             res["latex"] = latex_rm_whitespace(latex)
-    #         return images_formula_list
-
     def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
         images_formula_list = []
         mf_image_list = []
@@ -149,7 +70,7 @@ class UnimernetModel(object):
         # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
             mfd_res = images_mfd_res[image_index]
-            pil_img = Image.fromarray(images[image_index])
+            np_array_image = images[image_index]
             formula_list = []
 
             for idx, (xyxy, conf, cla) in enumerate(zip(
@@ -163,7 +84,7 @@ class UnimernetModel(object):
                 "latex": "",
             }
             formula_list.append(new_item)
-            bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+            bbox_img = np_array_image[ymin:ymax, xmin:xmax]
             area = (xmax - xmin) * (ymax - ymin)
 
             curr_idx = len(mf_image_list)
@@ -182,22 +103,30 @@ class UnimernetModel(object):
         index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
 
         # Create dataset with sorted images
-        dataset = MathDataset(sorted_images, transform=self.mfr_transform)
+        dataset = MathDataset(sorted_images, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
 
         # Process batches and store results
         mfr_res = []
-        for mf_img in dataloader:
-            mf_img = mf_img.to(self.device)
-            with torch.no_grad():
-                output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+        # for mf_img in dataloader:
+
+        with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
+            for index, mf_img in enumerate(dataloader):
+                mf_img = mf_img.to(dtype=self.model.dtype)
+                mf_img = mf_img.to(self.device)
+                with torch.no_grad():
+                    output = self.model.generate({"image": mf_img})
+                mfr_res.extend(output["fixed_str"])
+
+                # Advance the progress bar by batch_size; the last batch may be smaller.
+                current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
+                pbar.update(current_batch_size)
 
         # Restore original order
         unsorted_results = [""] * len(mfr_res)
         for new_idx, latex in enumerate(mfr_res):
             original_idx = index_mapping[new_idx]
-            unsorted_results[original_idx] = latex_rm_whitespace(latex)
+            unsorted_results[original_idx] = latex
 
         # Fill results back
         for res, latex in zip(backfill_list, unsorted_results):
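A minimal usage sketch of the new loading path shown above. The weight directory is hypothetical, and the claim that `model.transform` maps one numpy crop to one image tensor is inferred from how `MathDataset` and the `DataLoader` use it in this diff, not from the processor's own source.

```python
# Sketch of the 1.3.0 HF-style path that UnimernetModel.__init__ now takes.
# "weights/unimernet" is a hypothetical local directory of UniMERNet weights.
import numpy as np
import torch
from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf import UnimernetModel

device = "cuda" if torch.cuda.is_available() else "cpu"
model = UnimernetModel.from_pretrained("weights/unimernet")
model.to(device)
if device != "cpu":
    model = model.to(dtype=torch.float16)  # fp16 off-CPU, as in __init__ above
model.eval()

crop = np.zeros((64, 256, 3), dtype=np.uint8)  # stand-in for one detected formula crop
# Assumption: the image processor returns a single (C, H, W) tensor per crop,
# as its use through MathDataset/DataLoader implies.
batch = model.transform(crop).unsqueeze(0).to(device, dtype=model.dtype)
with torch.no_grad():
    out = model.generate({"image": batch})
print(out["fixed_str"][0])  # whitespace-normalized LaTeX
```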

magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py
ADDED
@@ -0,0 +1,13 @@
+from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
+from .unimer_mbart import UnimerMBartConfig, UnimerMBartModel, UnimerMBartForCausalLM
+from .modeling_unimernet import UnimernetModel
+
+__all__ = [
+    "UnimerSwinConfig",
+    "UnimerSwinModel",
+    "UnimerSwinImageProcessor",
+    "UnimerMBartConfig",
+    "UnimerMBartModel",
+    "UnimerMBartForCausalLM",
+    "UnimernetModel",
+]

magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py
ADDED
@@ -0,0 +1,189 @@
+import os
+import re
+import warnings
+from typing import Optional
+
+import torch
+from ftfy import fix_text
+
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel
+from transformers import VisionEncoderDecoderConfig, VisionEncoderDecoderModel
+from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import logger as base_model_logger
+
+from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
+from .unimer_mbart import UnimerMBartConfig, UnimerMBartForCausalLM
+
+AutoConfig.register(UnimerSwinConfig.model_type, UnimerSwinConfig)
+AutoConfig.register(UnimerMBartConfig.model_type, UnimerMBartConfig)
+AutoModel.register(UnimerSwinConfig, UnimerSwinModel)
+AutoModelForCausalLM.register(UnimerMBartConfig, UnimerMBartForCausalLM)
+
+
+# TODO: rewrite tokenizer
+class TokenizerWrapper:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.bos_token_id = self.tokenizer.bos_token_id
+        self.eos_token_id = self.tokenizer.eos_token_id
+
+    def __len__(self):
+        return len(self.tokenizer)
+
+    def tokenize(self, text, **kwargs):
+        return self.tokenizer(
+            text,
+            return_token_type_ids=False,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            **kwargs,
+        )
+
+    def token2str(self, tokens) -> list:
+        generated_text = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
+        generated_text = [fix_text(text) for text in generated_text]
+        return generated_text
+
+    def detokenize(self, tokens):
+        toks = [self.tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
+        for b in range(len(toks)):
+            for i in reversed(range(len(toks[b]))):
+                if toks[b][i] is None:
+                    toks[b][i] = ''
+                toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
+                if toks[b][i] in ([self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]):
+                    del toks[b][i]
+        return toks
+
+
+def latex_rm_whitespace(s: str):
+    """Remove unnecessary whitespace from LaTeX code.
+    """
+    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
+    letter = r'[a-zA-Z]'
+    noletter = r'[\W_^\d]'
+    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
+    s = re.sub(text_reg, lambda _: str(names.pop(0)), s)
+    news = s
+    while True:
+        s = news
+        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
+        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
+        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
+        if news == s:
+            break
+    return s
+
+
+class UnimernetModel(VisionEncoderDecoderModel):
+    def __init__(
+        self,
+        config: Optional[PretrainedConfig] = None,
+        encoder: Optional[PreTrainedModel] = None,
+        decoder: Optional[PreTrainedModel] = None,
+    ):
+        # VisionEncoderDecoderModel's checking log has bug, disable for temp.
+        base_model_logger.disabled = True
+        try:
+            super().__init__(config, encoder, decoder)
+        finally:
+            base_model_logger.disabled = False
+
+        if not config or not hasattr(config, "_name_or_path"):
+            raise RuntimeError("config._name_or_path is required by UnimernetModel.")
+
+        model_path = config._name_or_path
+        self.transform = UnimerSwinImageProcessor()
+        self.tokenizer = TokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
+        self._post_check()
+
+    def _post_check(self):
+        tokenizer = self.tokenizer
+
+        if tokenizer.tokenizer.model_max_length != self.config.decoder.max_position_embeddings:
+            warnings.warn(
+                f"decoder.max_position_embeddings={self.config.decoder.max_position_embeddings}," +
+                f" but tokenizer.model_max_length={tokenizer.tokenizer.model_max_length}, will set" +
+                f" tokenizer.model_max_length to {self.config.decoder.max_position_embeddings}.")
+            tokenizer.tokenizer.model_max_length = self.config.decoder.max_position_embeddings
+
+        assert self.config.decoder.vocab_size == len(tokenizer)
+        assert self.config.decoder_start_token_id == tokenizer.bos_token_id
+        assert self.config.pad_token_id == tokenizer.pad_token_id
+
+    @classmethod
+    def from_checkpoint(cls, model_path: str, model_filename: str = "pytorch_model.pth", state_dict_strip_prefix="model.model."):
+        config = VisionEncoderDecoderConfig.from_pretrained(model_path)
+        config._name_or_path = model_path
+        config.encoder = UnimerSwinConfig(**vars(config.encoder))
+        config.decoder = UnimerMBartConfig(**vars(config.decoder))
+
+        encoder = UnimerSwinModel(config.encoder)
+        decoder = UnimerMBartForCausalLM(config.decoder)
+        model = cls(config, encoder, decoder)
+
+        # load model weights
+        model_file_path = os.path.join(model_path, model_filename)
+        checkpoint = torch.load(model_file_path, map_location="cpu", weights_only=True)
+        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
+        if not state_dict:
+            raise RuntimeError("state_dict is empty.")
+        if state_dict_strip_prefix:
+            state_dict = {
+                k[len(state_dict_strip_prefix):] if k.startswith(state_dict_strip_prefix) else k: v
+                for k, v in state_dict.items()
+            }
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+        if len(unexpected_keys) > 0:
+            warnings.warn("Unexpected key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in unexpected_keys)))
+        if len(missing_keys) > 0:
+            raise RuntimeError("Missing key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in missing_keys)))
+        return model
+
+    def forward_bak(self, samples):
+        pixel_values, text = samples["image"], samples["text_input"]
+
+        text_inputs = self.tokenizer.tokenize(text).to(pixel_values.device)
+        decoder_input_ids, decoder_attention_mask = text_inputs["input_ids"], text_inputs["attention_mask"]
+
+        num_channels = pixel_values.shape[1]
+        if num_channels == 1:
+            pixel_values = pixel_values.repeat(1, 3, 1, 1)
+
+        labels = decoder_input_ids * 1
+        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)
+
+        loss = self.model(
+            pixel_values=pixel_values,
+            decoder_input_ids=decoder_input_ids[:, :-1],
+            decoder_attention_mask=decoder_attention_mask[:, :-1],
+            labels=labels[:, 1:],
+        ).loss
+        return {"loss": loss}
+
+    def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95):
+        pixel_values = samples["image"]
+        num_channels = pixel_values.shape[1]
+        if num_channels == 1:
+            pixel_values = pixel_values.repeat(1, 3, 1, 1)
+
+        kwargs = {}
+        if do_sample:
+            kwargs["temperature"] = temperature
+            kwargs["top_p"] = top_p
+
+        outputs = super().generate(
+            pixel_values=pixel_values,
+            max_new_tokens=self.tokenizer.tokenizer.model_max_length,  # required
+            decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
+            do_sample=do_sample,
+            **kwargs,
+        )
+
+        outputs = outputs[:, 1:].cpu().numpy()
+        pred_tokens = self.tokenizer.detokenize(outputs)
+        pred_str = self.tokenizer.token2str(outputs)
+        fixed_str = [latex_rm_whitespace(s) for s in pred_str]
+        return {"pred_ids": outputs, "pred_tokens": pred_tokens, "pred_str": pred_str, "fixed_str": fixed_str}
+
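Since `Unimernet.py` now consumes `fixed_str`, which is `pred_str` passed through `latex_rm_whitespace`, a small worked example of that function may help; the input string is illustrative.

```python
# Worked example of latex_rm_whitespace() as defined above (input is illustrative).
from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf.modeling_unimernet import (
    latex_rm_whitespace,
)

# "\mathrm { d }" is collapsed by the text_reg pass, then the remaining space
# between "}" and "x" is removed by the noletter/letter pass.
print(latex_rm_whitespace(r"\mathrm { d } x"))  # -> \mathrm{d}x
```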

magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py
ADDED
@@ -0,0 +1,163 @@
+# coding=utf-8
+# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""UnimerMBART model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class UnimerMBartConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`MBartModel`]. It is used to instantiate an MBART
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the MBART
+    [facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the MBART model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`MBartModel`] or [`TFMBartModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        qk_squeeze (`int`, *optional*, defaults to 2):
+            Squeeze ratio for query/key's output dimension. See the [UniMERNet paper](https://arxiv.org/abs/2404.15254).
+            Squeeze Attention maps the query and key to a lower-dimensional space without excessive loss of information,
+            thereby accelerating the computation of attention.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models)
+        forced_eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
+
+    Example:
+
+    ```python
+    >>> from transformers import MBartConfig, MBartModel
+
+    >>> # Initializing a MBART facebook/mbart-large-cc25 style configuration
+    >>> configuration = MBartConfig()
+
+    >>> # Initializing a model (with random weights) from the facebook/mbart-large-cc25 style configuration
+    >>> model = MBartModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "unimer-mbart"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+    def __init__(
+        self,
+        vocab_size=50265,
+        max_position_embeddings=1024,
+        encoder_layers=12,
+        encoder_ffn_dim=4096,
+        encoder_attention_heads=16,
+        decoder_layers=12,
+        decoder_ffn_dim=4096,
+        decoder_attention_heads=16,
+        encoder_layerdrop=0.0,
+        decoder_layerdrop=0.0,
+        use_cache=True,
+        is_encoder_decoder=True,
+        activation_function="gelu",
+        d_model=1024,
+        qk_squeeze=2,
+        dropout=0.1,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
+        init_std=0.02,
+        classifier_dropout=0.0,
+        scale_embedding=False,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        forced_eos_token_id=2,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.d_model = d_model
+        self.qk_squeeze = qk_squeeze
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
+        self.encoder_layerdrop = encoder_layerdrop
+        self.decoder_layerdrop = decoder_layerdrop
+        self.classifier_dropout = classifier_dropout
+        self.use_cache = use_cache
+        self.num_hidden_layers = encoder_layers
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            is_encoder_decoder=is_encoder_decoder,
+            forced_eos_token_id=forced_eos_token_id,
+            **kwargs,
+        )
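To close, a small sketch instantiating the configuration above with its defaults, showing the `qk_squeeze` field that distinguishes it from stock MBart and the `attribute_map` aliasing; it assumes magic-pdf 1.3.0 is installed.

```python
from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf.unimer_mbart import (
    UnimerMBartConfig,
)

cfg = UnimerMBartConfig()
print(cfg.model_type)    # unimer-mbart
print(cfg.qk_squeeze)    # 2, the query/key squeeze ratio from the UniMERNet paper
print(cfg.hidden_size)   # 1024, aliased to d_model via attribute_map
```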