magic-pdf 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/data/batch_build_dataset.py +156 -0
- magic_pdf/data/dataset.py +44 -24
- magic_pdf/data/utils.py +108 -9
- magic_pdf/dict2md/ocr_mkcontent.py +4 -3
- magic_pdf/libs/pdf_image_tools.py +11 -6
- magic_pdf/libs/performance_stats.py +12 -1
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/batch_analyze.py +175 -201
- magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
- magic_pdf/model/pdf_extract_kit.py +5 -38
- magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
- magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
- magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
- magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
- magic_pdf/model/sub_modules/model_init.py +50 -37
- magic_pdf/model/sub_modules/model_utils.py +17 -11
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
- magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
- magic_pdf/pdf_parse_union_core_v2.py +112 -74
- magic_pdf/post_proc/para_split_v3.py +16 -13
- magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
- magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
- magic_pdf/resources/model_config/model_configs.yaml +1 -1
- magic_pdf/tools/cli.py +30 -12
- magic_pdf/tools/common.py +90 -12
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +51 -41
- magic_pdf-1.3.0.dist-info/RECORD +202 -0
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
- magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
- magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
- magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
- magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
- magic_pdf-1.2.1.dist-info/RECORD +0 -147
- /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
- /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
- /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py
ADDED
@@ -0,0 +1,62 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["build_backbone"]
+
+
+def build_backbone(config, model_type):
+    if model_type == "det":
+        from .det_mobilenet_v3 import MobileNetV3
+        from .rec_hgnet import PPHGNet_small
+        from .rec_lcnetv3 import PPLCNetV3
+
+        support_dict = [
+            "MobileNetV3",
+            "ResNet",
+            "ResNet_vd",
+            "ResNet_SAST",
+            "PPLCNetV3",
+            "PPHGNet_small",
+        ]
+    elif model_type == "rec" or model_type == "cls":
+        from .rec_hgnet import PPHGNet_small
+        from .rec_lcnetv3 import PPLCNetV3
+        from .rec_mobilenet_v3 import MobileNetV3
+        from .rec_svtrnet import SVTRNet
+        from .rec_mv1_enhance import MobileNetV1Enhance
+
+        support_dict = [
+            "MobileNetV1Enhance",
+            "MobileNetV3",
+            "ResNet",
+            "ResNetFPN",
+            "MTB",
+            "ResNet31",
+            "SVTRNet",
+            "ViTSTR",
+            "DenseNet",
+            "PPLCNetV3",
+            "PPHGNet_small",
+        ]
+    else:
+        raise NotImplementedError
+
+    module_name = config.pop("name")
+    assert module_name in support_dict, Exception(
+        "when model typs is {}, backbone only support {}".format(
+            model_type, support_dict
+        )
+    )
+    module_class = eval(module_name)(**config)
+    return module_class
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py
ADDED
@@ -0,0 +1,269 @@
+from torch import nn
+
+from ..common import Activation
+
+
+def make_divisible(v, divisor=8, min_value=None):
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class ConvBNLayer(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride,
+        padding,
+        groups=1,
+        if_act=True,
+        act=None,
+        name=None,
+    ):
+        super(ConvBNLayer, self).__init__()
+        self.if_act = if_act
+        self.conv = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            groups=groups,
+            bias=False,
+        )
+
+        self.bn = nn.BatchNorm2d(
+            out_channels,
+        )
+        if self.if_act:
+            self.act = Activation(act_type=act, inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        if self.if_act:
+            x = self.act(x)
+        return x
+
+
+class SEModule(nn.Module):
+    def __init__(self, in_channels, reduction=4, name=""):
+        super(SEModule, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=in_channels // reduction,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+        )
+        self.relu1 = Activation(act_type="relu", inplace=True)
+        self.conv2 = nn.Conv2d(
+            in_channels=in_channels // reduction,
+            out_channels=in_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+        )
+        self.hard_sigmoid = Activation(act_type="hard_sigmoid", inplace=True)
+
+    def forward(self, inputs):
+        outputs = self.avg_pool(inputs)
+        outputs = self.conv1(outputs)
+        outputs = self.relu1(outputs)
+        outputs = self.conv2(outputs)
+        outputs = self.hard_sigmoid(outputs)
+        outputs = inputs * outputs
+        return outputs
+
+
+class ResidualUnit(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        mid_channels,
+        out_channels,
+        kernel_size,
+        stride,
+        use_se,
+        act=None,
+        name="",
+    ):
+        super(ResidualUnit, self).__init__()
+        self.if_shortcut = stride == 1 and in_channels == out_channels
+        self.if_se = use_se
+
+        self.expand_conv = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=mid_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            if_act=True,
+            act=act,
+            name=name + "_expand",
+        )
+        self.bottleneck_conv = ConvBNLayer(
+            in_channels=mid_channels,
+            out_channels=mid_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=int((kernel_size - 1) // 2),
+            groups=mid_channels,
+            if_act=True,
+            act=act,
+            name=name + "_depthwise",
+        )
+        if self.if_se:
+            self.mid_se = SEModule(mid_channels, name=name + "_se")
+        self.linear_conv = ConvBNLayer(
+            in_channels=mid_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            if_act=False,
+            act=None,
+            name=name + "_linear",
+        )
+
+    def forward(self, inputs):
+        x = self.expand_conv(inputs)
+        x = self.bottleneck_conv(x)
+        if self.if_se:
+            x = self.mid_se(x)
+        x = self.linear_conv(x)
+        if self.if_shortcut:
+            x = inputs + x
+        return x
+
+
+class MobileNetV3(nn.Module):
+    def __init__(
+        self, in_channels=3, model_name="large", scale=0.5, disable_se=False, **kwargs
+    ):
+        """
+        the MobilenetV3 backbone network for detection module.
+        Args:
+            params(dict): the super parameters for build network
+        """
+        super(MobileNetV3, self).__init__()
+
+        self.disable_se = disable_se
+
+        if model_name == "large":
+            cfg = [
+                # k, exp, c, se, nl, s,
+                [3, 16, 16, False, "relu", 1],
+                [3, 64, 24, False, "relu", 2],
+                [3, 72, 24, False, "relu", 1],
+                [5, 72, 40, True, "relu", 2],
+                [5, 120, 40, True, "relu", 1],
+                [5, 120, 40, True, "relu", 1],
+                [3, 240, 80, False, "hard_swish", 2],
+                [3, 200, 80, False, "hard_swish", 1],
+                [3, 184, 80, False, "hard_swish", 1],
+                [3, 184, 80, False, "hard_swish", 1],
+                [3, 480, 112, True, "hard_swish", 1],
+                [3, 672, 112, True, "hard_swish", 1],
+                [5, 672, 160, True, "hard_swish", 2],
+                [5, 960, 160, True, "hard_swish", 1],
+                [5, 960, 160, True, "hard_swish", 1],
+            ]
+            cls_ch_squeeze = 960
+        elif model_name == "small":
+            cfg = [
+                # k, exp, c, se, nl, s,
+                [3, 16, 16, True, "relu", 2],
+                [3, 72, 24, False, "relu", 2],
+                [3, 88, 24, False, "relu", 1],
+                [5, 96, 40, True, "hard_swish", 2],
+                [5, 240, 40, True, "hard_swish", 1],
+                [5, 240, 40, True, "hard_swish", 1],
+                [5, 120, 48, True, "hard_swish", 1],
+                [5, 144, 48, True, "hard_swish", 1],
+                [5, 288, 96, True, "hard_swish", 2],
+                [5, 576, 96, True, "hard_swish", 1],
+                [5, 576, 96, True, "hard_swish", 1],
+            ]
+            cls_ch_squeeze = 576
+        else:
+            raise NotImplementedError(
+                "mode[" + model_name + "_model] is not implemented!"
+            )
+
+        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
+        assert (
+            scale in supported_scale
+        ), "supported scale are {} but input scale is {}".format(supported_scale, scale)
+        inplanes = 16
+        # conv1
+        self.conv = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=make_divisible(inplanes * scale),
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            groups=1,
+            if_act=True,
+            act="hard_swish",
+            name="conv1",
+        )
+
+        self.stages = nn.ModuleList()
+        self.out_channels = []
+        block_list = []
+        i = 0
+        inplanes = make_divisible(inplanes * scale)
+        for k, exp, c, se, nl, s in cfg:
+            se = se and not self.disable_se
+            if s == 2 and i > 2:
+                self.out_channels.append(inplanes)
+                self.stages.append(nn.Sequential(*block_list))
+                block_list = []
+            block_list.append(
+                ResidualUnit(
+                    in_channels=inplanes,
+                    mid_channels=make_divisible(scale * exp),
+                    out_channels=make_divisible(scale * c),
+                    kernel_size=k,
+                    stride=s,
+                    use_se=se,
+                    act=nl,
+                    name="conv" + str(i + 2),
+                )
+            )
+            inplanes = make_divisible(scale * c)
+            i += 1
+        block_list.append(
+            ConvBNLayer(
+                in_channels=inplanes,
+                out_channels=make_divisible(scale * cls_ch_squeeze),
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                groups=1,
+                if_act=True,
+                act="hard_swish",
+                name="conv_last",
+            )
+        )
+        self.stages.append(nn.Sequential(*block_list))
+        self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
+        # for i, stage in enumerate(self.stages):
+        #     self.add_sublayer(sublayer=stage, name="stage{}".format(i))
+
+    def forward(self, x):
+        x = self.conv(x)
+        out_list = []
+        for stage in self.stages:
+            x = stage(x)
+            out_list.append(x)
+        return out_list
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py
ADDED
@@ -0,0 +1,290 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class ConvBNAct(nn.Module):
+    def __init__(
+        self, in_channels, out_channels, kernel_size, stride, groups=1, use_act=True
+    ):
+        super().__init__()
+        self.use_act = use_act
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding=(kernel_size - 1) // 2,
+            groups=groups,
+            bias=False,
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+        if self.use_act:
+            self.act = nn.ReLU()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        if self.use_act:
+            x = self.act(x)
+        return x
+
+
+class ESEModule(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.conv = nn.Conv2d(
+            in_channels=channels,
+            out_channels=channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        identity = x
+        x = self.avg_pool(x)
+        x = self.conv(x)
+        x = self.sigmoid(x)
+        return x * identity
+
+
+class HG_Block(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        mid_channels,
+        out_channels,
+        layer_num,
+        identity=False,
+    ):
+        super().__init__()
+        self.identity = identity
+
+        self.layers = nn.ModuleList()
+        self.layers.append(
+            ConvBNAct(
+                in_channels=in_channels,
+                out_channels=mid_channels,
+                kernel_size=3,
+                stride=1,
+            )
+        )
+        for _ in range(layer_num - 1):
+            self.layers.append(
+                ConvBNAct(
+                    in_channels=mid_channels,
+                    out_channels=mid_channels,
+                    kernel_size=3,
+                    stride=1,
+                )
+            )
+
+        # feature aggregation
+        total_channels = in_channels + layer_num * mid_channels
+        self.aggregation_conv = ConvBNAct(
+            in_channels=total_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            stride=1,
+        )
+        self.att = ESEModule(out_channels)
+
+    def forward(self, x):
+        identity = x
+        output = []
+        output.append(x)
+        for layer in self.layers:
+            x = layer(x)
+            output.append(x)
+        x = torch.cat(output, dim=1)
+        x = self.aggregation_conv(x)
+        x = self.att(x)
+        if self.identity:
+            x += identity
+        return x
+
+
+class HG_Stage(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        mid_channels,
+        out_channels,
+        block_num,
+        layer_num,
+        downsample=True,
+        stride=[2, 1],
+    ):
+        super().__init__()
+        self.downsample = downsample
+        if downsample:
+            self.downsample = ConvBNAct(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                kernel_size=3,
+                stride=stride,
+                groups=in_channels,
+                use_act=False,
+            )
+
+        blocks_list = []
+        blocks_list.append(
+            HG_Block(in_channels, mid_channels, out_channels, layer_num, identity=False)
+        )
+        for _ in range(block_num - 1):
+            blocks_list.append(
+                HG_Block(
+                    out_channels, mid_channels, out_channels, layer_num, identity=True
+                )
+            )
+        self.blocks = nn.Sequential(*blocks_list)
+
+    def forward(self, x):
+        if self.downsample:
+            x = self.downsample(x)
+        x = self.blocks(x)
+        return x
+
+
+class PPHGNet(nn.Module):
+    """
+    PPHGNet
+    Args:
+        stem_channels: list. Stem channel list of PPHGNet.
+        stage_config: dict. The configuration of each stage of PPHGNet. such as the number of channels, stride, etc.
+        layer_num: int. Number of layers of HG_Block.
+        use_last_conv: boolean. Whether to use a 1x1 convolutional layer before the classification layer.
+        class_expand: int=2048. Number of channels for the last 1x1 convolutional layer.
+        dropout_prob: float. Parameters of dropout, 0.0 means dropout is not used.
+        class_num: int=1000. The number of classes.
+    Returns:
+        model: nn.Layer. Specific PPHGNet model depends on args.
+    """
+
+    def __init__(
+        self,
+        stem_channels,
+        stage_config,
+        layer_num,
+        in_channels=3,
+        det=False,
+        out_indices=None,
+    ):
+        super().__init__()
+        self.det = det
+        self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
+
+        # stem
+        stem_channels.insert(0, in_channels)
+        self.stem = nn.Sequential(
+            *[
+                ConvBNAct(
+                    in_channels=stem_channels[i],
+                    out_channels=stem_channels[i + 1],
+                    kernel_size=3,
+                    stride=2 if i == 0 else 1,
+                )
+                for i in range(len(stem_channels) - 1)
+            ]
+        )
+
+        if self.det:
+            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        # stages
+        self.stages = nn.ModuleList()
+        self.out_channels = []
+        for block_id, k in enumerate(stage_config):
+            (
+                in_channels,
+                mid_channels,
+                out_channels,
+                block_num,
+                downsample,
+                stride,
+            ) = stage_config[k]
+            self.stages.append(
+                HG_Stage(
+                    in_channels,
+                    mid_channels,
+                    out_channels,
+                    block_num,
+                    layer_num,
+                    downsample,
+                    stride,
+                )
+            )
+            if block_id in self.out_indices:
+                self.out_channels.append(out_channels)
+
+        if not self.det:
+            self.out_channels = stage_config["stage4"][2]
+
+        self._init_weights()
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.zeros_(m.bias)
+
+    def forward(self, x):
+        x = self.stem(x)
+        if self.det:
+            x = self.pool(x)
+
+        out = []
+        for i, stage in enumerate(self.stages):
+            x = stage(x)
+            if self.det and i in self.out_indices:
+                out.append(x)
+        if self.det:
+            return out
+
+        if self.training:
+            x = F.adaptive_avg_pool2d(x, [1, 40])
+        else:
+            x = F.avg_pool2d(x, [3, 2])
+        return x
+
+
+def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs):
+    """
+    PPHGNet_small
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `PPHGNet_small` model depends on args.
+    """
+    stage_config_det = {
+        # in_channels, mid_channels, out_channels, blocks, downsample
+        "stage1": [128, 128, 256, 1, False, 2],
+        "stage2": [256, 160, 512, 1, True, 2],
+        "stage3": [512, 192, 768, 2, True, 2],
+        "stage4": [768, 224, 1024, 1, True, 2],
+    }
+
+    stage_config_rec = {
+        # in_channels, mid_channels, out_channels, blocks, downsample
+        "stage1": [128, 128, 256, 1, True, [2, 1]],
+        "stage2": [256, 160, 512, 1, True, [1, 2]],
+        "stage3": [512, 192, 768, 2, True, [2, 1]],
+        "stage4": [768, 224, 1024, 1, True, [2, 1]],
+    }
+
+    model = PPHGNet(
+        stem_channels=[64, 64, 128],
+        stage_config=stage_config_det if det else stage_config_rec,
+        layer_num=6,
+        det=det,
+        **kwargs
+    )
+    return model