mineru 2.5.3__py3-none-any.whl → 2.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mineru/backend/pipeline/model_init.py +25 -3
- mineru/backend/pipeline/model_json_to_middle_json.py +2 -2
- mineru/backend/pipeline/model_list.py +0 -1
- mineru/backend/utils.py +24 -0
- mineru/backend/vlm/model_output_to_middle_json.py +2 -2
- mineru/backend/vlm/{custom_logits_processors.py → utils.py} +36 -2
- mineru/backend/vlm/vlm_analyze.py +43 -50
- mineru/backend/vlm/vlm_magic_model.py +155 -1
- mineru/cli/common.py +26 -23
- mineru/cli/fast_api.py +2 -8
- mineru/cli/gradio_app.py +104 -13
- mineru/cli/models_download.py +1 -0
- mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py +152 -0
- mineru/model/mfr/pp_formulanet_plus_m/processors.py +657 -0
- mineru/model/mfr/unimernet/unimernet_hf/modeling_unimernet.py +1 -326
- mineru/model/mfr/utils.py +338 -0
- mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py +103 -16
- mineru/model/table/rec/unet_table/main.py +1 -1
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/imaug/operators.py +5 -5
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/__init__.py +2 -1
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_lcnetv3.py +7 -7
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_pphgnetv2.py +2 -2
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/__init__.py +2 -0
- mineru/model/utils/pytorchocr/modeling/heads/rec_ppformulanet_head.py +1383 -0
- mineru/model/utils/pytorchocr/modeling/heads/rec_unimernet_head.py +2631 -0
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/rec_postprocess.py +25 -28
- mineru/model/utils/pytorchocr/utils/__init__.py +0 -0
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/arch_config.yaml +130 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_arabic_dict.txt +747 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_cyrillic_dict.txt +850 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_devanagari_dict.txt +568 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_ta_dict.txt +513 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_te_dict.txt +540 -0
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/models_config.yml +15 -15
- mineru/model/utils/pytorchocr/utils/resources/pp_formulanet_arch_config.yaml +24 -0
- mineru/model/utils/tools/infer/__init__.py +1 -0
- mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_det.py +6 -3
- mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_rec.py +16 -25
- mineru/model/vlm_vllm_model/server.py +4 -1
- mineru/resources/header.html +2 -2
- mineru/utils/enum_class.py +1 -0
- mineru/utils/guess_suffix_or_lang.py +9 -1
- mineru/utils/llm_aided.py +4 -2
- mineru/utils/ocr_utils.py +16 -0
- mineru/utils/table_merge.py +102 -13
- mineru/version.py +1 -1
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/METADATA +33 -6
- mineru-2.6.0.dist-info/RECORD +195 -0
- mineru-2.5.3.dist-info/RECORD +0 -181
- /mineru/model/{ocr/paddleocr2pytorch/pytorchocr → mfr/pp_formulanet_plus_m}/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch/tools/infer → utils}/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch/pytorchocr/modeling → utils/pytorchocr}/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/base_ocr_v20.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/imaug/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch/pytorchocr/utils → utils/pytorchocr/modeling}/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/architectures/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/architectures/base_model.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/det_mobilenet_v3.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_donut_swin.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_hgnet.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_mv1_enhance.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_svtrnet.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/common.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/cls_head.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/det_db_head.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/rec_ctc_head.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/rec_multi_head.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/db_fpn.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/intracl.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/rnn.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/cls_postprocess.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/db_postprocess.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/arabic_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/cyrillic_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/devanagari_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/en_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/japan_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ka_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/korean_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/latin_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv4_doc_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_el_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_en_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_eslav_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_korean_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_latin_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_th_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ta_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/te_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_cls.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_system.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/pytorchocr_utility.py +0 -0
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/WHEEL +0 -0
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/entry_points.txt +0 -0
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/licenses/LICENSE.md +0 -0
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/top_level.txt +0 -0
|
@@ -12,27 +12,114 @@ from loguru import logger
|
|
|
12
12
|
from mineru.utils.config_reader import get_device
|
|
13
13
|
from mineru.utils.enum_class import ModelPath
|
|
14
14
|
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
|
15
|
-
from
|
|
16
|
-
from .tools.infer.predict_system import TextSystem
|
|
17
|
-
from .tools.infer import pytorchocr_utility as utility
|
|
15
|
+
from mineru.utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
|
|
16
|
+
from mineru.model.utils.tools.infer.predict_system import TextSystem
|
|
17
|
+
from mineru.model.utils.tools.infer import pytorchocr_utility as utility
|
|
18
18
|
import argparse
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
latin_lang = [
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
22
|
+
"af",
|
|
23
|
+
"az",
|
|
24
|
+
"bs",
|
|
25
|
+
"cs",
|
|
26
|
+
"cy",
|
|
27
|
+
"da",
|
|
28
|
+
"de",
|
|
29
|
+
"es",
|
|
30
|
+
"et",
|
|
31
|
+
"fr",
|
|
32
|
+
"ga",
|
|
33
|
+
"hr",
|
|
34
|
+
"hu",
|
|
35
|
+
"id",
|
|
36
|
+
"is",
|
|
37
|
+
"it",
|
|
38
|
+
"ku",
|
|
39
|
+
"la",
|
|
40
|
+
"lt",
|
|
41
|
+
"lv",
|
|
42
|
+
"mi",
|
|
43
|
+
"ms",
|
|
44
|
+
"mt",
|
|
45
|
+
"nl",
|
|
46
|
+
"no",
|
|
47
|
+
"oc",
|
|
48
|
+
"pi",
|
|
49
|
+
"pl",
|
|
50
|
+
"pt",
|
|
51
|
+
"ro",
|
|
52
|
+
"rs_latin",
|
|
53
|
+
"sk",
|
|
54
|
+
"sl",
|
|
55
|
+
"sq",
|
|
56
|
+
"sv",
|
|
57
|
+
"sw",
|
|
58
|
+
"tl",
|
|
59
|
+
"tr",
|
|
60
|
+
"uz",
|
|
61
|
+
"vi",
|
|
62
|
+
"french",
|
|
63
|
+
"german",
|
|
64
|
+
"fi",
|
|
65
|
+
"eu",
|
|
66
|
+
"gl",
|
|
67
|
+
"lb",
|
|
68
|
+
"rm",
|
|
69
|
+
"ca",
|
|
70
|
+
"qu",
|
|
26
71
|
]
|
|
27
|
-
arabic_lang = [
|
|
72
|
+
arabic_lang = ["ar", "fa", "ug", "ur", "ps", "ku", "sd", "bal"]
|
|
28
73
|
cyrillic_lang = [
|
|
29
|
-
|
|
30
|
-
|
|
74
|
+
"ru",
|
|
75
|
+
"rs_cyrillic",
|
|
76
|
+
"be",
|
|
77
|
+
"bg",
|
|
78
|
+
"uk",
|
|
79
|
+
"mn",
|
|
80
|
+
"abq",
|
|
81
|
+
"ady",
|
|
82
|
+
"kbd",
|
|
83
|
+
"ava",
|
|
84
|
+
"dar",
|
|
85
|
+
"inh",
|
|
86
|
+
"che",
|
|
87
|
+
"lbe",
|
|
88
|
+
"lez",
|
|
89
|
+
"tab",
|
|
90
|
+
"kk",
|
|
91
|
+
"ky",
|
|
92
|
+
"tg",
|
|
93
|
+
"mk",
|
|
94
|
+
"tt",
|
|
95
|
+
"cv",
|
|
96
|
+
"ba",
|
|
97
|
+
"mhr",
|
|
98
|
+
"mo",
|
|
99
|
+
"udm",
|
|
100
|
+
"kv",
|
|
101
|
+
"os",
|
|
102
|
+
"bua",
|
|
103
|
+
"xal",
|
|
104
|
+
"tyv",
|
|
105
|
+
"sah",
|
|
106
|
+
"kaa",
|
|
31
107
|
]
|
|
32
108
|
east_slavic_lang = ["ru", "be", "uk"]
|
|
33
109
|
devanagari_lang = [
|
|
34
|
-
|
|
35
|
-
|
|
110
|
+
"hi",
|
|
111
|
+
"mr",
|
|
112
|
+
"ne",
|
|
113
|
+
"bh",
|
|
114
|
+
"mai",
|
|
115
|
+
"ang",
|
|
116
|
+
"bho",
|
|
117
|
+
"mah",
|
|
118
|
+
"sck",
|
|
119
|
+
"new",
|
|
120
|
+
"gom",
|
|
121
|
+
"sa",
|
|
122
|
+
"bgc",
|
|
36
123
|
]
|
|
37
124
|
|
|
38
125
|
|
|
@@ -47,7 +134,7 @@ def get_model_params(lang, config):
|
|
|
47
134
|
raise Exception (f'Language {lang} not supported')
|
|
48
135
|
|
|
49
136
|
|
|
50
|
-
root_dir = Path(__file__).resolve().parent
|
|
137
|
+
root_dir = os.path.join(Path(__file__).resolve().parent.parent.parent, 'utils')
|
|
51
138
|
|
|
52
139
|
|
|
53
140
|
class PytorchPaddleOCR(TextSystem):
|
|
@@ -65,14 +152,14 @@ class PytorchPaddleOCR(TextSystem):
|
|
|
65
152
|
|
|
66
153
|
if self.lang in latin_lang:
|
|
67
154
|
self.lang = 'latin'
|
|
155
|
+
elif self.lang in east_slavic_lang:
|
|
156
|
+
self.lang = 'east_slavic'
|
|
68
157
|
elif self.lang in arabic_lang:
|
|
69
158
|
self.lang = 'arabic'
|
|
70
159
|
elif self.lang in cyrillic_lang:
|
|
71
160
|
self.lang = 'cyrillic'
|
|
72
161
|
elif self.lang in devanagari_lang:
|
|
73
162
|
self.lang = 'devanagari'
|
|
74
|
-
elif self.lang in east_slavic_lang:
|
|
75
|
-
self.lang = 'east_slavic'
|
|
76
163
|
else:
|
|
77
164
|
pass
|
|
78
165
|
|
|
@@ -89,7 +176,7 @@ class PytorchPaddleOCR(TextSystem):
|
|
|
89
176
|
kwargs['det_model_path'] = det_model_path
|
|
90
177
|
kwargs['rec_model_path'] = rec_model_path
|
|
91
178
|
kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
|
|
92
|
-
kwargs['rec_batch_num'] =
|
|
179
|
+
kwargs['rec_batch_num'] = 6
|
|
93
180
|
|
|
94
181
|
kwargs['device'] = device
|
|
95
182
|
|
|
@@ -184,7 +184,7 @@ class WiredTableRecognition:
|
|
|
184
184
|
continue
|
|
185
185
|
# 从img中截取对应的区域
|
|
186
186
|
x1, y1, x2, y2 = int(box[0][0])+1, int(box[0][1])+1, int(box[2][0])-1, int(box[2][1])-1
|
|
187
|
-
if x1 >= x2 or y1 >= y2:
|
|
187
|
+
if x1 >= x2 or y1 >= y2 or x1 < 0 or y1 < 0:
|
|
188
188
|
# logger.warning(f"Invalid box coordinates: {x1, y1, x2, y2}")
|
|
189
189
|
continue
|
|
190
190
|
# 判断长宽比
|
|
@@ -23,6 +23,7 @@ import sys
|
|
|
23
23
|
import six
|
|
24
24
|
import cv2
|
|
25
25
|
import numpy as np
|
|
26
|
+
from PIL import Image
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
class DecodeImage(object):
|
|
@@ -104,16 +105,15 @@ class NormalizeImage(object):
|
|
|
104
105
|
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
|
|
105
106
|
self.mean = np.array(mean).reshape(shape).astype('float32')
|
|
106
107
|
self.std = np.array(std).reshape(shape).astype('float32')
|
|
108
|
+
self.scale = self.scale / self.std
|
|
109
|
+
self.mean = self.mean / self.std
|
|
110
|
+
|
|
107
111
|
|
|
108
112
|
def __call__(self, data):
|
|
109
113
|
img = data['image']
|
|
110
|
-
from PIL import Image
|
|
111
114
|
if isinstance(img, Image.Image):
|
|
112
115
|
img = np.array(img)
|
|
113
|
-
|
|
114
|
-
np.ndarray), "invalid input 'img' in NormalizeImage"
|
|
115
|
-
data['image'] = (
|
|
116
|
-
img.astype('float32') * self.scale - self.mean) / self.std
|
|
116
|
+
data['image'] = img.astype('float32') * self.scale - self.mean
|
|
117
117
|
return data
|
|
118
118
|
|
|
119
119
|
|
|
@@ -37,7 +37,7 @@ def build_backbone(config, model_type):
|
|
|
37
37
|
from .rec_mobilenet_v3 import MobileNetV3
|
|
38
38
|
from .rec_svtrnet import SVTRNet
|
|
39
39
|
from .rec_mv1_enhance import MobileNetV1Enhance
|
|
40
|
-
from .rec_pphgnetv2 import PPHGNetV2_B4
|
|
40
|
+
from .rec_pphgnetv2 import PPHGNetV2_B4, PPHGNetV2_B6_Formula
|
|
41
41
|
support_dict = [
|
|
42
42
|
"MobileNetV1Enhance",
|
|
43
43
|
"MobileNetV3",
|
|
@@ -51,6 +51,7 @@ def build_backbone(config, model_type):
|
|
|
51
51
|
"PPLCNetV3",
|
|
52
52
|
"PPHGNet_small",
|
|
53
53
|
"PPHGNetV2_B4",
|
|
54
|
+
"PPHGNetV2_B6_Formula"
|
|
54
55
|
]
|
|
55
56
|
else:
|
|
56
57
|
raise NotImplementedError
|
|
@@ -245,18 +245,18 @@ class LearnableRepLayer(nn.Module):
|
|
|
245
245
|
return 0, 0
|
|
246
246
|
elif isinstance(branch, ConvBNLayer):
|
|
247
247
|
kernel = branch.conv.weight
|
|
248
|
-
running_mean = branch.bn.
|
|
249
|
-
running_var = branch.bn.
|
|
248
|
+
running_mean = branch.bn.running_mean
|
|
249
|
+
running_var = branch.bn.running_var
|
|
250
250
|
gamma = branch.bn.weight
|
|
251
251
|
beta = branch.bn.bias
|
|
252
|
-
eps = branch.bn.
|
|
252
|
+
eps = branch.bn.eps
|
|
253
253
|
else:
|
|
254
254
|
assert isinstance(branch, nn.BatchNorm2d)
|
|
255
255
|
if not hasattr(self, "id_tensor"):
|
|
256
256
|
input_dim = self.in_channels // self.groups
|
|
257
257
|
kernel_value = torch.zeros(
|
|
258
258
|
(self.in_channels, input_dim, self.kernel_size, self.kernel_size),
|
|
259
|
-
dtype=branch.weight.dtype,
|
|
259
|
+
dtype=branch.weight.dtype, device=branch.weight.device,
|
|
260
260
|
)
|
|
261
261
|
for i in range(self.in_channels):
|
|
262
262
|
kernel_value[
|
|
@@ -264,11 +264,11 @@ class LearnableRepLayer(nn.Module):
|
|
|
264
264
|
] = 1
|
|
265
265
|
self.id_tensor = kernel_value
|
|
266
266
|
kernel = self.id_tensor
|
|
267
|
-
running_mean = branch.
|
|
268
|
-
running_var = branch.
|
|
267
|
+
running_mean = branch.running_mean
|
|
268
|
+
running_var = branch.running_var
|
|
269
269
|
gamma = branch.weight
|
|
270
270
|
beta = branch.bias
|
|
271
|
-
eps = branch.
|
|
271
|
+
eps = branch.eps
|
|
272
272
|
std = (running_var + eps).sqrt()
|
|
273
273
|
t = (gamma / std).reshape((-1, 1, 1, 1))
|
|
274
274
|
return kernel * t, beta - running_mean * gamma / std
|
|
@@ -1626,8 +1626,8 @@ class PPHGNetV2_B6_Formula(nn.Module):
|
|
|
1626
1626
|
pixel_values = torch.repeat_interleave(pixel_values, repeats=3, dim=1)
|
|
1627
1627
|
pphgnet_b6_output = self.pphgnet_b6(pixel_values)
|
|
1628
1628
|
b, c, h, w = pphgnet_b6_output.shape
|
|
1629
|
-
pphgnet_b6_output = pphgnet_b6_output.reshape([b, c, h * w]).
|
|
1630
|
-
|
|
1629
|
+
pphgnet_b6_output = pphgnet_b6_output.reshape([b, c, h * w]).permute(
|
|
1630
|
+
0, 2, 1
|
|
1631
1631
|
)
|
|
1632
1632
|
pphgnet_b6_output = DonutSwinModelOutput(
|
|
1633
1633
|
last_hidden_state=pphgnet_b6_output,
|
|
@@ -22,6 +22,7 @@ def build_head(config, **kwargs):
|
|
|
22
22
|
# rec head
|
|
23
23
|
from .rec_ctc_head import CTCHead
|
|
24
24
|
from .rec_multi_head import MultiHead
|
|
25
|
+
from .rec_ppformulanet_head import PPFormulaNet_Head
|
|
25
26
|
|
|
26
27
|
# cls head
|
|
27
28
|
from .cls_head import ClsHead
|
|
@@ -32,6 +33,7 @@ def build_head(config, **kwargs):
|
|
|
32
33
|
"ClsHead",
|
|
33
34
|
"MultiHead",
|
|
34
35
|
"PFHeadLocal",
|
|
36
|
+
"PPFormulaNet_Head",
|
|
35
37
|
]
|
|
36
38
|
|
|
37
39
|
module_name = config.pop("name")
|