magic-pdf 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (101)
  1. magic_pdf/data/batch_build_dataset.py +156 -0
  2. magic_pdf/data/dataset.py +44 -24
  3. magic_pdf/data/utils.py +108 -9
  4. magic_pdf/dict2md/ocr_mkcontent.py +4 -3
  5. magic_pdf/libs/pdf_image_tools.py +11 -6
  6. magic_pdf/libs/performance_stats.py +12 -1
  7. magic_pdf/libs/version.py +1 -1
  8. magic_pdf/model/batch_analyze.py +175 -201
  9. magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
  10. magic_pdf/model/pdf_extract_kit.py +5 -38
  11. magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
  12. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
  13. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
  14. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
  15. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
  16. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
  17. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
  18. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
  19. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
  20. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
  21. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
  22. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
  25. magic_pdf/model/sub_modules/model_init.py +50 -37
  26. magic_pdf/model/sub_modules/model_utils.py +17 -11
  27. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
  29. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
  32. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
  33. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
  34. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
  35. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
  36. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
  37. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
  38. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
  39. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
  40. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
  41. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
  42. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
  43. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
  44. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
  45. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
  46. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
  47. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
  48. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
  49. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
  50. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
  51. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
  52. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
  53. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
  54. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
  55. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
  56. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
  57. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
  58. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
  59. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
  60. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
  61. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
  62. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
  63. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
  64. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
  65. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
  66. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
  67. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
  68. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
  69. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
  70. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
  71. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
  72. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
  73. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
  74. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
  75. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
  76. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
  77. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
  78. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
  79. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
  80. magic_pdf/pdf_parse_union_core_v2.py +112 -74
  81. magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
  82. magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
  83. magic_pdf/resources/model_config/model_configs.yaml +1 -1
  84. magic_pdf/tools/cli.py +30 -12
  85. magic_pdf/tools/common.py +90 -12
  86. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +50 -40
  87. magic_pdf-1.3.0.dist-info/RECORD +202 -0
  88. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
  89. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
  90. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
  91. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
  92. magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
  93. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
  94. magic_pdf-1.2.2.dist-info/RECORD +0 -147
  95. /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
  96. /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
  97. /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
  98. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
  99. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
  100. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
  101. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py (new file)
@@ -0,0 +1,456 @@

import torch
import torch.nn.functional as F
from torch import nn

from ..backbones.det_mobilenet_v3 import SEModule
from ..necks.intracl import IntraCLBlock


def hard_swish(x, inplace=True):
    return x * F.relu6(x + 3.0, inplace=inplace) / 6.0


class DSConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding,
        stride=1,
        groups=None,
        if_act=True,
        act="relu",
        **kwargs
    ):
        super(DSConv, self).__init__()
        if groups == None:
            groups = in_channels
        self.if_act = if_act
        self.act = act
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )

        self.bn1 = nn.BatchNorm2d(in_channels)

        self.conv2 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=int(in_channels * 4),
            kernel_size=1,
            stride=1,
            bias=False,
        )

        self.bn2 = nn.BatchNorm2d(int(in_channels * 4))

        self.conv3 = nn.Conv2d(
            in_channels=int(in_channels * 4),
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self._c = [in_channels, out_channels]
        if in_channels != out_channels:
            self.conv_end = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                bias=False,
            )

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        if self.if_act:
            if self.act == "relu":
                x = F.relu(x)
            elif self.act == "hardswish":
                x = hard_swish(x)
            else:
                print(
                    "The activation function({}) is selected incorrectly.".format(
                        self.act
                    )
                )
                exit()

        x = self.conv3(x)
        if self._c[0] != self._c[1]:
            x = x + self.conv_end(inputs)
        return x


class DBFPN(nn.Module):
    def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
        super(DBFPN, self).__init__()
        self.out_channels = out_channels
        self.use_asf = use_asf

        self.in2_conv = nn.Conv2d(
            in_channels=in_channels[0],
            out_channels=self.out_channels,
            kernel_size=1,
            bias=False,
        )
        self.in3_conv = nn.Conv2d(
            in_channels=in_channels[1],
            out_channels=self.out_channels,
            kernel_size=1,
            bias=False,
        )
        self.in4_conv = nn.Conv2d(
            in_channels=in_channels[2],
            out_channels=self.out_channels,
            kernel_size=1,
            bias=False,
        )
        self.in5_conv = nn.Conv2d(
            in_channels=in_channels[3],
            out_channels=self.out_channels,
            kernel_size=1,
            bias=False,
        )
        self.p5_conv = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.p4_conv = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.p3_conv = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.p2_conv = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            bias=False,
        )

        if self.use_asf is True:
            self.asf = ASFBlock(self.out_channels, self.out_channels // 4)

    def forward(self, x):
        c2, c3, c4, c5 = x

        in5 = self.in5_conv(c5)
        in4 = self.in4_conv(c4)
        in3 = self.in3_conv(c3)
        in2 = self.in2_conv(c2)

        out4 = in4 + F.interpolate(
            in5,
            scale_factor=2,
            mode="nearest",
        )  # align_mode=1)  # 1/16
        out3 = in3 + F.interpolate(
            out4,
            scale_factor=2,
            mode="nearest",
        )  # align_mode=1)  # 1/8
        out2 = in2 + F.interpolate(
            out3,
            scale_factor=2,
            mode="nearest",
        )  # align_mode=1)  # 1/4

        p5 = self.p5_conv(in5)
        p4 = self.p4_conv(out4)
        p3 = self.p3_conv(out3)
        p2 = self.p2_conv(out2)
        p5 = F.interpolate(
            p5,
            scale_factor=8,
            mode="nearest",
        )  # align_mode=1)
        p4 = F.interpolate(
            p4,
            scale_factor=4,
            mode="nearest",
        )  # align_mode=1)
        p3 = F.interpolate(
            p3,
            scale_factor=2,
            mode="nearest",
        )  # align_mode=1)

        fuse = torch.cat([p5, p4, p3, p2], dim=1)

        if self.use_asf is True:
            fuse = self.asf(fuse, [p5, p4, p3, p2])

        return fuse


class RSELayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
        super(RSELayer, self).__init__()
        self.out_channels = out_channels
        self.in_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=kernel_size,
            padding=int(kernel_size // 2),
            bias=False,
        )
        self.se_block = SEModule(self.out_channels)
        self.shortcut = shortcut

    def forward(self, ins):
        x = self.in_conv(ins)
        if self.shortcut:
            out = x + self.se_block(x)
        else:
            out = self.se_block(x)
        return out


class RSEFPN(nn.Module):
    def __init__(self, in_channels, out_channels, shortcut=True, **kwargs):
        super(RSEFPN, self).__init__()
        self.out_channels = out_channels
        self.ins_conv = nn.ModuleList()
        self.inp_conv = nn.ModuleList()
        self.intracl = False
        if "intracl" in kwargs.keys() and kwargs["intracl"] is True:
            self.intracl = kwargs["intracl"]
            self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
            self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
            self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
            self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)

        for i in range(len(in_channels)):
            self.ins_conv.append(
                RSELayer(in_channels[i], out_channels, kernel_size=1, shortcut=shortcut)
            )
            self.inp_conv.append(
                RSELayer(
                    out_channels, out_channels // 4, kernel_size=3, shortcut=shortcut
                )
            )

    def forward(self, x):
        c2, c3, c4, c5 = x

        in5 = self.ins_conv[3](c5)
        in4 = self.ins_conv[2](c4)
        in3 = self.ins_conv[1](c3)
        in2 = self.ins_conv[0](c2)

        out4 = in4 + F.interpolate(in5, scale_factor=2, mode="nearest")  # 1/16
        out3 = in3 + F.interpolate(out4, scale_factor=2, mode="nearest")  # 1/8
        out2 = in2 + F.interpolate(out3, scale_factor=2, mode="nearest")  # 1/4

        p5 = self.inp_conv[3](in5)
        p4 = self.inp_conv[2](out4)
        p3 = self.inp_conv[1](out3)
        p2 = self.inp_conv[0](out2)

        if self.intracl is True:
            p5 = self.incl4(p5)
            p4 = self.incl3(p4)
            p3 = self.incl2(p3)
            p2 = self.incl1(p2)

        p5 = F.interpolate(p5, scale_factor=8, mode="nearest")
        p4 = F.interpolate(p4, scale_factor=4, mode="nearest")
        p3 = F.interpolate(p3, scale_factor=2, mode="nearest")

        fuse = torch.cat([p5, p4, p3, p2], dim=1)
        return fuse


class LKPAN(nn.Module):
    def __init__(self, in_channels, out_channels, mode="large", **kwargs):
        super(LKPAN, self).__init__()
        self.out_channels = out_channels

        self.ins_conv = nn.ModuleList()
        self.inp_conv = nn.ModuleList()
        # pan head
        self.pan_head_conv = nn.ModuleList()
        self.pan_lat_conv = nn.ModuleList()

        if mode.lower() == "lite":
            p_layer = DSConv
        elif mode.lower() == "large":
            p_layer = nn.Conv2d
        else:
            raise ValueError(
                "mode can only be one of ['lite', 'large'], but received {}".format(
                    mode
                )
            )

        for i in range(len(in_channels)):
            self.ins_conv.append(
                nn.Conv2d(
                    in_channels=in_channels[i],
                    out_channels=self.out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )

            self.inp_conv.append(
                p_layer(
                    in_channels=self.out_channels,
                    out_channels=self.out_channels // 4,
                    kernel_size=9,
                    padding=4,
                    bias=False,
                )
            )

            if i > 0:
                self.pan_head_conv.append(
                    nn.Conv2d(
                        in_channels=self.out_channels // 4,
                        out_channels=self.out_channels // 4,
                        kernel_size=3,
                        padding=1,
                        stride=2,
                        bias=False,
                    )
                )
            self.pan_lat_conv.append(
                p_layer(
                    in_channels=self.out_channels // 4,
                    out_channels=self.out_channels // 4,
                    kernel_size=9,
                    padding=4,
                    bias=False,
                )
            )
        self.intracl = False
        if "intracl" in kwargs.keys() and kwargs["intracl"] is True:
            self.intracl = kwargs["intracl"]
            self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
            self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
            self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
            self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)

    def forward(self, x):
        c2, c3, c4, c5 = x

        in5 = self.ins_conv[3](c5)
        in4 = self.ins_conv[2](c4)
        in3 = self.ins_conv[1](c3)
        in2 = self.ins_conv[0](c2)

        out4 = in4 + F.interpolate(in5, scale_factor=2, mode="nearest")  # 1/16
        out3 = in3 + F.interpolate(out4, scale_factor=2, mode="nearest")  # 1/8
        out2 = in2 + F.interpolate(out3, scale_factor=2, mode="nearest")  # 1/4

        f5 = self.inp_conv[3](in5)
        f4 = self.inp_conv[2](out4)
        f3 = self.inp_conv[1](out3)
        f2 = self.inp_conv[0](out2)

        pan3 = f3 + self.pan_head_conv[0](f2)
        pan4 = f4 + self.pan_head_conv[1](pan3)
        pan5 = f5 + self.pan_head_conv[2](pan4)

        p2 = self.pan_lat_conv[0](f2)
        p3 = self.pan_lat_conv[1](pan3)
        p4 = self.pan_lat_conv[2](pan4)
        p5 = self.pan_lat_conv[3](pan5)

        if self.intracl is True:
            p5 = self.incl4(p5)
            p4 = self.incl3(p4)
            p3 = self.incl2(p3)
            p2 = self.incl1(p2)

        p5 = F.interpolate(p5, scale_factor=8, mode="nearest")
        p4 = F.interpolate(p4, scale_factor=4, mode="nearest")
        p3 = F.interpolate(p3, scale_factor=2, mode="nearest")

        fuse = torch.cat([p5, p4, p3, p2], dim=1)
        return fuse


class ASFBlock(nn.Module):
    """
    This code is refered from:
    https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
    """

    def __init__(self, in_channels, inter_channels, out_features_num=4):
        """
        Adaptive Scale Fusion (ASF) block of DBNet++
        Args:
            in_channels: the number of channels in the input data
            inter_channels: the number of middle channels
            out_features_num: the number of fused stages
        """
        super(ASFBlock, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.out_features_num = out_features_num
        self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1)

        self.spatial_scale = nn.Sequential(
            # Nx1xHxW
            nn.Conv2d(
                in_channels=1,
                out_channels=1,
                kernel_size=3,
                bias=False,
                padding=1,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=1,
                out_channels=1,
                kernel_size=1,
                bias=False,
            ),
            nn.Sigmoid(),
        )

        self.channel_scale = nn.Sequential(
            nn.Conv2d(
                in_channels=inter_channels,
                out_channels=out_features_num,
                kernel_size=1,
                bias=False,
            ),
            nn.Sigmoid(),
        )

    def forward(self, fuse_features, features_list):
        fuse_features = self.conv(fuse_features)
        spatial_x = torch.mean(fuse_features, dim=1, keepdim=True)
        attention_scores = self.spatial_scale(spatial_x) + fuse_features
        attention_scores = self.channel_scale(attention_scores)
        assert len(features_list) == self.out_features_num

        out_list = []
        for i in range(self.out_features_num):
            out_list.append(attention_scores[:, i : i + 1] * features_list[i])
        return torch.cat(out_list, dim=1)
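
For orientation, here is a minimal sketch of how one of the neck modules above could be wired to a text-detection backbone. The import path is the one listed in this wheel; the channel list [16, 24, 56, 480], the 640x640 input, and the 1/4-1/32 strides are illustrative assumptions only (the real values come from the package's arch_config.yaml), not values taken from this diff.

    import torch
    from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.modeling.necks.db_fpn import RSEFPN

    # Hypothetical backbone output channels at strides 4, 8, 16, 32.
    in_channels = [16, 24, 56, 480]
    neck = RSEFPN(in_channels=in_channels, out_channels=96, shortcut=True)

    n, h, w = 1, 640, 640
    feats = [
        torch.randn(n, c, h // s, w // s)
        for c, s in zip(in_channels, (4, 8, 16, 32))
    ]
    fuse = neck(feats)    # every level is reduced to 96 // 4 channels, upsampled to 1/4 scale, then concatenated
    print(fuse.shape)     # torch.Size([1, 96, 160, 160])

DBFPN and LKPAN take the same four-level input and produce the same fused 1/4-scale map; they differ in how the levels are mixed (plain 1x1/3x3 convs, or a large-kernel path-aggregation head), and both can optionally apply the ASF or IntraCL refinements defined in this file.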
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py (new file)
@@ -0,0 +1,117 @@

from torch import nn


class IntraCLBlock(nn.Module):
    def __init__(self, in_channels=96, reduce_factor=4):
        super(IntraCLBlock, self).__init__()
        self.channels = in_channels
        self.rf = reduce_factor
        self.conv1x1_reduce_channel = nn.Conv2d(
            self.channels, self.channels // self.rf, kernel_size=1, stride=1, padding=0
        )
        self.conv1x1_return_channel = nn.Conv2d(
            self.channels // self.rf, self.channels, kernel_size=1, stride=1, padding=0
        )

        self.v_layer_7x1 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(7, 1),
            stride=(1, 1),
            padding=(3, 0),
        )
        self.v_layer_5x1 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(5, 1),
            stride=(1, 1),
            padding=(2, 0),
        )
        self.v_layer_3x1 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(3, 1),
            stride=(1, 1),
            padding=(1, 0),
        )

        self.q_layer_1x7 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(1, 7),
            stride=(1, 1),
            padding=(0, 3),
        )
        self.q_layer_1x5 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(1, 5),
            stride=(1, 1),
            padding=(0, 2),
        )
        self.q_layer_1x3 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(1, 3),
            stride=(1, 1),
            padding=(0, 1),
        )

        # base
        self.c_layer_7x7 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(7, 7),
            stride=(1, 1),
            padding=(3, 3),
        )
        self.c_layer_5x5 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(2, 2),
        )
        self.c_layer_3x3 = nn.Conv2d(
            self.channels // self.rf,
            self.channels // self.rf,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )

        self.bn = nn.BatchNorm2d(self.channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x_new = self.conv1x1_reduce_channel(x)

        x_7_c = self.c_layer_7x7(x_new)
        x_7_v = self.v_layer_7x1(x_new)
        x_7_q = self.q_layer_1x7(x_new)
        x_7 = x_7_c + x_7_v + x_7_q

        x_5_c = self.c_layer_5x5(x_7)
        x_5_v = self.v_layer_5x1(x_7)
        x_5_q = self.q_layer_1x5(x_7)
        x_5 = x_5_c + x_5_v + x_5_q

        x_3_c = self.c_layer_3x3(x_5)
        x_3_v = self.v_layer_3x1(x_5)
        x_3_q = self.q_layer_1x3(x_5)
        x_3 = x_3_c + x_3_v + x_3_q

        x_relation = self.conv1x1_return_channel(x_3)

        x_relation = self.bn(x_relation)
        x_relation = self.relu(x_relation)

        return x + x_relation


def build_intraclblock_list(num_block):
    IntraCLBlock_list = nn.ModuleList()
    for i in range(num_block):
        IntraCLBlock_list.append(IntraCLBlock())

    return IntraCLBlock_list
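
A small usage sketch for the block above, under the same illustrative assumptions as the db_fpn.py example: 24 channels matches what RSEFPN/LKPAN pass in (out_channels // 4 with out_channels=96, reduce_factor=2); the tensor size is arbitrary.

    import torch
    from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.modeling.necks.intracl import (
        IntraCLBlock,
        build_intraclblock_list,
    )

    # IntraCLBlock is residual: it reduces channels, mixes horizontal (1xk), vertical (kx1)
    # and square (kxk) context at three kernel sizes, then adds the result back to the input,
    # so the output shape equals the input shape and it can be dropped into any FPN branch.
    block = IntraCLBlock(in_channels=24, reduce_factor=2)
    x = torch.randn(1, 24, 160, 160)   # e.g. one 1/4-scale FPN branch
    y = block(x)
    assert y.shape == x.shape

    # build_intraclblock_list uses the class defaults (96 channels, reduce_factor=4).
    blocks = build_intraclblock_list(num_block=2)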