magic-pdf 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. magic_pdf/data/batch_build_dataset.py +156 -0
  2. magic_pdf/data/dataset.py +44 -24
  3. magic_pdf/data/utils.py +108 -9
  4. magic_pdf/dict2md/ocr_mkcontent.py +4 -3
  5. magic_pdf/libs/pdf_image_tools.py +11 -6
  6. magic_pdf/libs/performance_stats.py +12 -1
  7. magic_pdf/libs/version.py +1 -1
  8. magic_pdf/model/batch_analyze.py +175 -201
  9. magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
  10. magic_pdf/model/pdf_extract_kit.py +5 -38
  11. magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
  12. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
  13. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
  14. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
  15. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
  16. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
  17. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
  18. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
  19. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
  20. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
  21. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
  22. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
  25. magic_pdf/model/sub_modules/model_init.py +50 -37
  26. magic_pdf/model/sub_modules/model_utils.py +17 -11
  27. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
  29. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
  32. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
  33. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
  34. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
  35. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
  36. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
  37. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
  38. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
  39. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
  40. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
  41. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
  42. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
  43. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
  44. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
  45. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
  46. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
  47. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
  48. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
  49. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
  50. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
  51. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
  52. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
  53. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
  54. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
  55. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
  56. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
  57. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
  58. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
  59. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
  60. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
  61. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
  62. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
  63. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
  64. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
  65. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
  66. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
  67. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
  68. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
  69. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
  70. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
  71. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
  72. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
  73. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
  74. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
  75. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
  76. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
  77. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
  78. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
  79. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
  80. magic_pdf/pdf_parse_union_core_v2.py +112 -74
  81. magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
  82. magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
  83. magic_pdf/resources/model_config/model_configs.yaml +1 -1
  84. magic_pdf/tools/cli.py +30 -12
  85. magic_pdf/tools/common.py +90 -12
  86. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +50 -40
  87. magic_pdf-1.3.0.dist-info/RECORD +202 -0
  88. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
  89. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
  90. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
  91. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
  92. magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
  93. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
  94. magic_pdf-1.2.2.dist-info/RECORD +0 -147
  95. /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
  96. /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
  97. /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
  98. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
  99. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
  100. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
  101. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py
@@ -0,0 +1,62 @@
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ __all__ = ["build_backbone"]
+
+
+ def build_backbone(config, model_type):
+     if model_type == "det":
+         from .det_mobilenet_v3 import MobileNetV3
+         from .rec_hgnet import PPHGNet_small
+         from .rec_lcnetv3 import PPLCNetV3
+
+         support_dict = [
+             "MobileNetV3",
+             "ResNet",
+             "ResNet_vd",
+             "ResNet_SAST",
+             "PPLCNetV3",
+             "PPHGNet_small",
+         ]
+     elif model_type == "rec" or model_type == "cls":
+         from .rec_hgnet import PPHGNet_small
+         from .rec_lcnetv3 import PPLCNetV3
+         from .rec_mobilenet_v3 import MobileNetV3
+         from .rec_svtrnet import SVTRNet
+         from .rec_mv1_enhance import MobileNetV1Enhance
+
+         support_dict = [
+             "MobileNetV1Enhance",
+             "MobileNetV3",
+             "ResNet",
+             "ResNetFPN",
+             "MTB",
+             "ResNet31",
+             "SVTRNet",
+             "ViTSTR",
+             "DenseNet",
+             "PPLCNetV3",
+             "PPHGNet_small",
+         ]
+     else:
+         raise NotImplementedError
+
+     module_name = config.pop("name")
+     assert module_name in support_dict, Exception(
+         "when model typs is {}, backbone only support {}".format(
+             model_type, support_dict
+         )
+     )
+     module_class = eval(module_name)(**config)
+     return module_class
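
For orientation, a minimal usage sketch of the factory above (not part of the diff). It assumes the wheel is installed alongside torch and uses illustrative config values; the shipped backbone configs presumably live in pytorchocr/utils/resources/arch_config.yaml.

from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.modeling.backbones import build_backbone

# Illustrative detection config: "name" selects the backbone class, the rest become its kwargs.
det_cfg = {"name": "MobileNetV3", "model_name": "large", "scale": 0.5}
backbone = build_backbone(det_cfg, model_type="det")  # pops "name", then instantiates MobileNetV3(**det_cfg)
print(backbone.out_channels)  # per-stage channel counts consumed by the detection neck
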
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py
@@ -0,0 +1,269 @@
+ from torch import nn
+
+ from ..common import Activation
+
+
+ def make_divisible(v, divisor=8, min_value=None):
+     if min_value is None:
+         min_value = divisor
+     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+     if new_v < 0.9 * v:
+         new_v += divisor
+     return new_v
+
+
+ class ConvBNLayer(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride,
+         padding,
+         groups=1,
+         if_act=True,
+         act=None,
+         name=None,
+     ):
+         super(ConvBNLayer, self).__init__()
+         self.if_act = if_act
+         self.conv = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             groups=groups,
+             bias=False,
+         )
+
+         self.bn = nn.BatchNorm2d(
+             out_channels,
+         )
+         if self.if_act:
+             self.act = Activation(act_type=act, inplace=True)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         if self.if_act:
+             x = self.act(x)
+         return x
+
+
+ class SEModule(nn.Module):
+     def __init__(self, in_channels, reduction=4, name=""):
+         super(SEModule, self).__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.conv1 = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=in_channels // reduction,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             bias=True,
+         )
+         self.relu1 = Activation(act_type="relu", inplace=True)
+         self.conv2 = nn.Conv2d(
+             in_channels=in_channels // reduction,
+             out_channels=in_channels,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             bias=True,
+         )
+         self.hard_sigmoid = Activation(act_type="hard_sigmoid", inplace=True)
+
+     def forward(self, inputs):
+         outputs = self.avg_pool(inputs)
+         outputs = self.conv1(outputs)
+         outputs = self.relu1(outputs)
+         outputs = self.conv2(outputs)
+         outputs = self.hard_sigmoid(outputs)
+         outputs = inputs * outputs
+         return outputs
+
+
+ class ResidualUnit(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         mid_channels,
+         out_channels,
+         kernel_size,
+         stride,
+         use_se,
+         act=None,
+         name="",
+     ):
+         super(ResidualUnit, self).__init__()
+         self.if_shortcut = stride == 1 and in_channels == out_channels
+         self.if_se = use_se
+
+         self.expand_conv = ConvBNLayer(
+             in_channels=in_channels,
+             out_channels=mid_channels,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             if_act=True,
+             act=act,
+             name=name + "_expand",
+         )
+         self.bottleneck_conv = ConvBNLayer(
+             in_channels=mid_channels,
+             out_channels=mid_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=int((kernel_size - 1) // 2),
+             groups=mid_channels,
+             if_act=True,
+             act=act,
+             name=name + "_depthwise",
+         )
+         if self.if_se:
+             self.mid_se = SEModule(mid_channels, name=name + "_se")
+         self.linear_conv = ConvBNLayer(
+             in_channels=mid_channels,
+             out_channels=out_channels,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             if_act=False,
+             act=None,
+             name=name + "_linear",
+         )
+
+     def forward(self, inputs):
+         x = self.expand_conv(inputs)
+         x = self.bottleneck_conv(x)
+         if self.if_se:
+             x = self.mid_se(x)
+         x = self.linear_conv(x)
+         if self.if_shortcut:
+             x = inputs + x
+         return x
+
+
+ class MobileNetV3(nn.Module):
+     def __init__(
+         self, in_channels=3, model_name="large", scale=0.5, disable_se=False, **kwargs
+     ):
+         """
+         the MobilenetV3 backbone network for detection module.
+         Args:
+             params(dict): the super parameters for build network
+         """
+         super(MobileNetV3, self).__init__()
+
+         self.disable_se = disable_se
+
+         if model_name == "large":
+             cfg = [
+                 # k, exp, c, se, nl, s,
+                 [3, 16, 16, False, "relu", 1],
+                 [3, 64, 24, False, "relu", 2],
+                 [3, 72, 24, False, "relu", 1],
+                 [5, 72, 40, True, "relu", 2],
+                 [5, 120, 40, True, "relu", 1],
+                 [5, 120, 40, True, "relu", 1],
+                 [3, 240, 80, False, "hard_swish", 2],
+                 [3, 200, 80, False, "hard_swish", 1],
+                 [3, 184, 80, False, "hard_swish", 1],
+                 [3, 184, 80, False, "hard_swish", 1],
+                 [3, 480, 112, True, "hard_swish", 1],
+                 [3, 672, 112, True, "hard_swish", 1],
+                 [5, 672, 160, True, "hard_swish", 2],
+                 [5, 960, 160, True, "hard_swish", 1],
+                 [5, 960, 160, True, "hard_swish", 1],
+             ]
+             cls_ch_squeeze = 960
+         elif model_name == "small":
+             cfg = [
+                 # k, exp, c, se, nl, s,
+                 [3, 16, 16, True, "relu", 2],
+                 [3, 72, 24, False, "relu", 2],
+                 [3, 88, 24, False, "relu", 1],
+                 [5, 96, 40, True, "hard_swish", 2],
+                 [5, 240, 40, True, "hard_swish", 1],
+                 [5, 240, 40, True, "hard_swish", 1],
+                 [5, 120, 48, True, "hard_swish", 1],
+                 [5, 144, 48, True, "hard_swish", 1],
+                 [5, 288, 96, True, "hard_swish", 2],
+                 [5, 576, 96, True, "hard_swish", 1],
+                 [5, 576, 96, True, "hard_swish", 1],
+             ]
+             cls_ch_squeeze = 576
+         else:
+             raise NotImplementedError(
+                 "mode[" + model_name + "_model] is not implemented!"
+             )
+
+         supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
+         assert (
+             scale in supported_scale
+         ), "supported scale are {} but input scale is {}".format(supported_scale, scale)
+         inplanes = 16
+         # conv1
+         self.conv = ConvBNLayer(
+             in_channels=in_channels,
+             out_channels=make_divisible(inplanes * scale),
+             kernel_size=3,
+             stride=2,
+             padding=1,
+             groups=1,
+             if_act=True,
+             act="hard_swish",
+             name="conv1",
+         )
+
+         self.stages = nn.ModuleList()
+         self.out_channels = []
+         block_list = []
+         i = 0
+         inplanes = make_divisible(inplanes * scale)
+         for k, exp, c, se, nl, s in cfg:
+             se = se and not self.disable_se
+             if s == 2 and i > 2:
+                 self.out_channels.append(inplanes)
+                 self.stages.append(nn.Sequential(*block_list))
+                 block_list = []
+             block_list.append(
+                 ResidualUnit(
+                     in_channels=inplanes,
+                     mid_channels=make_divisible(scale * exp),
+                     out_channels=make_divisible(scale * c),
+                     kernel_size=k,
+                     stride=s,
+                     use_se=se,
+                     act=nl,
+                     name="conv" + str(i + 2),
+                 )
+             )
+             inplanes = make_divisible(scale * c)
+             i += 1
+         block_list.append(
+             ConvBNLayer(
+                 in_channels=inplanes,
+                 out_channels=make_divisible(scale * cls_ch_squeeze),
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 groups=1,
+                 if_act=True,
+                 act="hard_swish",
+                 name="conv_last",
+             )
+         )
+         self.stages.append(nn.Sequential(*block_list))
+         self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
+         # for i, stage in enumerate(self.stages):
+         #     self.add_sublayer(sublayer=stage, name="stage{}".format(i))
+
+     def forward(self, x):
+         x = self.conv(x)
+         out_list = []
+         for stage in self.stages:
+             x = stage(x)
+             out_list.append(x)
+         return out_list
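
A quick smoke test of the detection backbone above, as a hedged sketch (not part of the diff; shapes assume model_name="large", scale=0.5 and a 640x640 input):

import torch

from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.modeling.backbones.det_mobilenet_v3 import MobileNetV3

model = MobileNetV3(in_channels=3, model_name="large", scale=0.5)
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 640, 640))
# Four feature maps at strides 4/8/16/32; channel counts are recorded in model.out_channels
# and feed a multi-scale neck such as the DB FPN added in necks/db_fpn.py.
print(model.out_channels, [tuple(f.shape) for f in feats])
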
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py
@@ -0,0 +1,290 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ class ConvBNAct(nn.Module):
+     def __init__(
+         self, in_channels, out_channels, kernel_size, stride, groups=1, use_act=True
+     ):
+         super().__init__()
+         self.use_act = use_act
+         self.conv = nn.Conv2d(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride,
+             padding=(kernel_size - 1) // 2,
+             groups=groups,
+             bias=False,
+         )
+         self.bn = nn.BatchNorm2d(out_channels)
+         if self.use_act:
+             self.act = nn.ReLU()
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         if self.use_act:
+             x = self.act(x)
+         return x
+
+
+ class ESEModule(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.conv = nn.Conv2d(
+             in_channels=channels,
+             out_channels=channels,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+         )
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         identity = x
+         x = self.avg_pool(x)
+         x = self.conv(x)
+         x = self.sigmoid(x)
+         return x * identity
+
+
+ class HG_Block(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         mid_channels,
+         out_channels,
+         layer_num,
+         identity=False,
+     ):
+         super().__init__()
+         self.identity = identity
+
+         self.layers = nn.ModuleList()
+         self.layers.append(
+             ConvBNAct(
+                 in_channels=in_channels,
+                 out_channels=mid_channels,
+                 kernel_size=3,
+                 stride=1,
+             )
+         )
+         for _ in range(layer_num - 1):
+             self.layers.append(
+                 ConvBNAct(
+                     in_channels=mid_channels,
+                     out_channels=mid_channels,
+                     kernel_size=3,
+                     stride=1,
+                 )
+             )
+
+         # feature aggregation
+         total_channels = in_channels + layer_num * mid_channels
+         self.aggregation_conv = ConvBNAct(
+             in_channels=total_channels,
+             out_channels=out_channels,
+             kernel_size=1,
+             stride=1,
+         )
+         self.att = ESEModule(out_channels)
+
+     def forward(self, x):
+         identity = x
+         output = []
+         output.append(x)
+         for layer in self.layers:
+             x = layer(x)
+             output.append(x)
+         x = torch.cat(output, dim=1)
+         x = self.aggregation_conv(x)
+         x = self.att(x)
+         if self.identity:
+             x += identity
+         return x
+
+
+ class HG_Stage(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         mid_channels,
+         out_channels,
+         block_num,
+         layer_num,
+         downsample=True,
+         stride=[2, 1],
+     ):
+         super().__init__()
+         self.downsample = downsample
+         if downsample:
+             self.downsample = ConvBNAct(
+                 in_channels=in_channels,
+                 out_channels=in_channels,
+                 kernel_size=3,
+                 stride=stride,
+                 groups=in_channels,
+                 use_act=False,
+             )
+
+         blocks_list = []
+         blocks_list.append(
+             HG_Block(in_channels, mid_channels, out_channels, layer_num, identity=False)
+         )
+         for _ in range(block_num - 1):
+             blocks_list.append(
+                 HG_Block(
+                     out_channels, mid_channels, out_channels, layer_num, identity=True
+                 )
+             )
+         self.blocks = nn.Sequential(*blocks_list)
+
+     def forward(self, x):
+         if self.downsample:
+             x = self.downsample(x)
+         x = self.blocks(x)
+         return x
+
+
+ class PPHGNet(nn.Module):
+     """
+     PPHGNet
+     Args:
+         stem_channels: list. Stem channel list of PPHGNet.
+         stage_config: dict. The configuration of each stage of PPHGNet. such as the number of channels, stride, etc.
+         layer_num: int. Number of layers of HG_Block.
+         use_last_conv: boolean. Whether to use a 1x1 convolutional layer before the classification layer.
+         class_expand: int=2048. Number of channels for the last 1x1 convolutional layer.
+         dropout_prob: float. Parameters of dropout, 0.0 means dropout is not used.
+         class_num: int=1000. The number of classes.
+     Returns:
+         model: nn.Layer. Specific PPHGNet model depends on args.
+     """
+
+     def __init__(
+         self,
+         stem_channels,
+         stage_config,
+         layer_num,
+         in_channels=3,
+         det=False,
+         out_indices=None,
+     ):
+         super().__init__()
+         self.det = det
+         self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
+
+         # stem
+         stem_channels.insert(0, in_channels)
+         self.stem = nn.Sequential(
+             *[
+                 ConvBNAct(
+                     in_channels=stem_channels[i],
+                     out_channels=stem_channels[i + 1],
+                     kernel_size=3,
+                     stride=2 if i == 0 else 1,
+                 )
+                 for i in range(len(stem_channels) - 1)
+             ]
+         )
+
+         if self.det:
+             self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         # stages
+         self.stages = nn.ModuleList()
+         self.out_channels = []
+         for block_id, k in enumerate(stage_config):
+             (
+                 in_channels,
+                 mid_channels,
+                 out_channels,
+                 block_num,
+                 downsample,
+                 stride,
+             ) = stage_config[k]
+             self.stages.append(
+                 HG_Stage(
+                     in_channels,
+                     mid_channels,
+                     out_channels,
+                     block_num,
+                     layer_num,
+                     downsample,
+                     stride,
+                 )
+             )
+             if block_id in self.out_indices:
+                 self.out_channels.append(out_channels)
+
+         if not self.det:
+             self.out_channels = stage_config["stage4"][2]
+
+         self._init_weights()
+
+     def _init_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight)
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.ones_(m.weight)
+                 nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.Linear):
+                 nn.init.zeros_(m.bias)
+
+     def forward(self, x):
+         x = self.stem(x)
+         if self.det:
+             x = self.pool(x)
+
+         out = []
+         for i, stage in enumerate(self.stages):
+             x = stage(x)
+             if self.det and i in self.out_indices:
+                 out.append(x)
+         if self.det:
+             return out
+
+         if self.training:
+             x = F.adaptive_avg_pool2d(x, [1, 40])
+         else:
+             x = F.avg_pool2d(x, [3, 2])
+         return x
+
+
+ def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs):
+     """
+     PPHGNet_small
+     Args:
+         pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                     If str, means the path of the pretrained model.
+         use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+     Returns:
+         model: nn.Layer. Specific `PPHGNet_small` model depends on args.
+     """
+     stage_config_det = {
+         # in_channels, mid_channels, out_channels, blocks, downsample
+         "stage1": [128, 128, 256, 1, False, 2],
+         "stage2": [256, 160, 512, 1, True, 2],
+         "stage3": [512, 192, 768, 2, True, 2],
+         "stage4": [768, 224, 1024, 1, True, 2],
+     }
+
+     stage_config_rec = {
+         # in_channels, mid_channels, out_channels, blocks, downsample
+         "stage1": [128, 128, 256, 1, True, [2, 1]],
+         "stage2": [256, 160, 512, 1, True, [1, 2]],
+         "stage3": [512, 192, 768, 2, True, [2, 1]],
+         "stage4": [768, 224, 1024, 1, True, [2, 1]],
+     }
+
+     model = PPHGNet(
+         stem_channels=[64, 64, 128],
+         stage_config=stage_config_det if det else stage_config_rec,
+         layer_num=6,
+         det=det,
+         **kwargs
+     )
+     return model
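
Similarly, a hedged sketch of driving PPHGNet_small (not part of the diff; input sizes are illustrative). With det=True it returns the multi-scale feature list for text detection; with det=False it returns the pooled sequence feature used for recognition:

import torch

from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.modeling.backbones.rec_hgnet import PPHGNet_small

det_backbone = PPHGNet_small(det=True)
det_backbone.eval()
with torch.no_grad():
    feats = det_backbone(torch.randn(1, 3, 640, 640))  # list of four feature maps (out_indices 0..3)

rec_backbone = PPHGNet_small(det=False)
rec_backbone.eval()
with torch.no_grad():
    seq = rec_backbone(torch.randn(1, 3, 48, 320))  # pooled to a height-1, width-40 feature map in eval mode
print([tuple(f.shape) for f in feats], tuple(seq.shape))
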