magic-pdf 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. magic_pdf/data/batch_build_dataset.py +156 -0
  2. magic_pdf/data/dataset.py +56 -25
  3. magic_pdf/data/utils.py +108 -9
  4. magic_pdf/dict2md/ocr_mkcontent.py +4 -3
  5. magic_pdf/libs/pdf_image_tools.py +11 -6
  6. magic_pdf/libs/performance_stats.py +12 -1
  7. magic_pdf/libs/version.py +1 -1
  8. magic_pdf/model/batch_analyze.py +175 -201
  9. magic_pdf/model/doc_analyze_by_custom_model.py +142 -92
  10. magic_pdf/model/pdf_extract_kit.py +5 -38
  11. magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
  12. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
  13. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
  14. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
  15. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
  16. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
  17. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
  18. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
  19. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
  20. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
  21. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
  22. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
  25. magic_pdf/model/sub_modules/model_init.py +50 -37
  26. magic_pdf/model/sub_modules/model_utils.py +18 -12
  27. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
  29. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
  32. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
  33. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
  34. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
  35. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
  36. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
  37. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
  38. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
  39. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
  40. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
  41. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
  42. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
  43. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
  44. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
  45. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
  46. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
  47. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
  48. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
  49. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
  50. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
  51. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
  52. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
  53. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
  54. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
  55. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
  56. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
  57. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
  58. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
  59. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
  60. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
  61. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
  62. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
  63. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
  64. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
  65. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
  66. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
  67. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
  68. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
  69. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
  70. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
  71. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
  72. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
  73. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
  74. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
  75. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
  76. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
  77. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
  78. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
  79. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +15 -19
  80. magic_pdf/pdf_parse_union_core_v2.py +112 -74
  81. magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
  82. magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
  83. magic_pdf/resources/model_config/model_configs.yaml +1 -1
  84. magic_pdf/resources/slanet_plus/slanet-plus.onnx +0 -0
  85. magic_pdf/tools/cli.py +30 -12
  86. magic_pdf/tools/common.py +90 -12
  87. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/METADATA +92 -59
  88. magic_pdf-1.3.1.dist-info/RECORD +203 -0
  89. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/WHEEL +1 -1
  90. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
  91. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
  92. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
  93. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
  94. magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
  95. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
  96. magic_pdf-1.2.2.dist-info/RECORD +0 -147
  97. /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
  98. /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
  99. /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
  100. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/LICENSE.md +0 -0
  101. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/entry_points.txt +0 -0
  102. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,638 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+
5
+ from ..common import Activation
6
+
7
+
8
def drop_path(x, drop_prob=0.0, training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Every sample along dimension 0 is either zeroed out as a whole or kept and
    rescaled by ``1 / (1 - drop_prob)``.

    The original name is misleading as 'Drop Connect' is a different form of
    dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...

    Args:
        x: Input tensor; dimension 0 is treated as the per-sample dimension.
        drop_prob: Probability of dropping a sample. ``None`` and ``0`` both
            disable dropping.
        training: Dropping is only applied when this is true.

    Returns:
        ``x`` itself when dropping is disabled, otherwise a new tensor with
        the same shape as ``x``.
    """
    # ``not drop_prob`` covers 0.0 as well as None, so a module created with
    # the default drop_prob=None no longer fails on ``1 - None``.
    if not drop_prob or not training:
        return x
    keep_prob = 1.0 - float(drop_prob)
    # One random draw per sample, broadcast over all remaining dimensions.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # Create the noise on the same device as x; a default-device tensor cannot
    # be multiplied with an input living on another device.
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    random_tensor = torch.floor(random_tensor)  # binarize
    return x.divide(keep_prob) * random_tensor
21
+
22
+
23
class ConvBNLayer(nn.Module):
    """2-D convolution followed by batch normalization and an activation.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels (also the BatchNorm width).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        bias_attr: Whether the convolution has a bias term.
        groups: Number of convolution groups.
        act: Activation name handed to ``Activation``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=0,
        bias_attr=False,
        groups=1,
        act="gelu",
    ):
        super().__init__()
        conv_options = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "groups": groups,
            "bias": bias_attr,
        }
        self.conv = nn.Conv2d(**conv_options)
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = Activation(act_type=act, inplace=True)

    def forward(self, inputs):
        """Apply convolution, normalization and activation, in that order."""
        return self.act(self.norm(self.conv(inputs)))
53
+
54
+
55
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: Probability of dropping a sample while the module is in
            training mode. ``None`` (the default) disables dropping.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Return ``x`` with whole samples randomly dropped during training."""
        # With the default drop_prob=None there is nothing to drop; returning
        # early avoids evaluating ``1 - None`` in training mode.
        if self.drop_prob is None:
            return x
        return drop_path(x, self.drop_prob, self.training)
64
+
65
+
66
class Identity(nn.Module):
    """Pass-through module: ``forward`` hands back its argument untouched."""

    def forward(self, input):
        """Return ``input`` unchanged."""
        return input
72
+
73
+
74
class Mlp(nn.Module):
    """Two-layer feed-forward network with dropout after each linear layer.

    Args:
        in_features: Width of the input features.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Width of the output; falls back to ``in_features``
            when falsy.
        act_layer: Activation name handed to ``Activation``.
        drop: Dropout probability shared by both dropout applications.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer="gelu",
        drop=0.0,
    ):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = Activation(act_type=act_layer, inplace=True)
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run fc1 -> activation -> dropout -> fc2 -> dropout."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
98
+
99
+
100
class ConvMixer(nn.Module):
    """Token mixer that applies a grouped 2-D convolution over the token grid.

    Args:
        dim: Channel width of the tokens.
        num_heads: Used as the number of convolution groups.
        HW: ``[height, width]`` of the token grid the sequence is folded into.
        local_k: ``[kernel_height, kernel_width]`` of the convolution; padding
            is half the kernel size in each direction.
    """

    # The list defaults are only read, never mutated, so sharing them between
    # instances is harmless; they are kept to leave the signature unchanged.
    def __init__(
        self,
        dim,
        num_heads=8,
        HW=[8, 25],
        local_k=[3, 3],
    ):
        super().__init__()
        self.HW = HW
        self.dim = dim
        self.local_mixer = nn.Conv2d(
            dim,
            dim,
            local_k,
            1,
            [local_k[0] // 2, local_k[1] // 2],
            groups=num_heads,
        )

    def forward(self, x):
        """Mix tokens of a ``(batch, H*W, dim)`` sequence and return the same layout."""
        h = self.HW[0]
        w = self.HW[1]
        # Paddle's ``transpose([0, 2, 1])`` / ``reshape([0, ...])`` have no
        # PyTorch equivalent with that spelling: Tensor.transpose takes two
        # dims and 0 is not a "copy this dim" marker. Use permute and -1.
        x = x.permute(0, 2, 1).reshape(-1, self.dim, h, w)
        x = self.local_mixer(x)
        return x.flatten(2).permute(0, 2, 1)
127
+
128
+
129
class Attention(nn.Module):
    """Multi-head self-attention with an optional local-window mask.

    Args:
        dim: Channel width of the tokens.
        num_heads: Number of attention heads; ``dim // num_heads`` is the
            per-head width.
        mixer: ``"Local"`` adds a window mask to the attention logits; any
            other value leaves the logits unmasked.
        HW: ``[height, width]`` of the token grid, or ``None`` to take the
            sequence length and width from the input at call time.
        local_k: ``[window_height, window_width]`` of the local mask.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        qk_scale: Scale applied to the queries; defaults to
            ``head_dim ** -0.5`` when falsy.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        mixer="Global",
        HW=[8, 25],
        local_k=[7, 11],
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.HW = HW
        if HW is not None:
            H = HW[0]
            W = HW[1]
            self.N = H * W
            self.C = dim
        if mixer == "Local" and HW is not None:
            hk = local_k[0]
            wk = local_k[1]
            # Build the mask on a padded grid so windows near the border can
            # be written without clipping, then crop back to H x W.
            mask = torch.ones(H * W, H + hk - 1, W + wk - 1, dtype=torch.float32)
            for h in range(0, H):
                for w in range(0, W):
                    mask[h * W + w, h : h + hk, w : w + wk] = 0.0
            mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(
                1
            )
            mask_inf = torch.full(
                [H * W, H * W], fill_value=float("-Inf"), dtype=torch.float32
            )
            # 0 inside the window (logit unchanged), -inf outside (zero weight
            # after softmax).
            mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
            # Register as a buffer so the mask follows .to()/.cuda()/.half()
            # together with the parameters. persistent=False keeps it out of
            # state_dict, so existing checkpoints load exactly as before.
            self.register_buffer(
                "mask", mask.unsqueeze(0).unsqueeze(1), persistent=False
            )
        self.mixer = mixer

    def forward(self, x):
        """Attend over a ``(batch, N, C)`` token sequence and return the same shape."""
        if self.HW is not None:
            N = self.N
            C = self.C
        else:
            _, N, C = x.shape
        qkv = self.qkv(x)
        # (batch, N, 3*C) -> (3, batch, heads, N, head_dim)
        qkv = qkv.reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute(
            2, 0, 3, 1, 4
        )
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]

        attn = q.matmul(k.permute(0, 1, 3, 2))
        if self.mixer == "Local":
            attn += self.mask
        attn = nn.functional.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        x = (attn.matmul(v)).permute(0, 2, 1, 3).reshape((-1, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
196
+
197
+
198
class Block(nn.Module):
    """Residual block made of a token mixer and an MLP, each with a norm.

    Args:
        dim: Channel width of the tokens.
        num_heads: Head count (attention) or group count (conv mixer).
        mixer: ``"Global"`` or ``"Local"`` selects ``Attention``; ``"Conv"``
            selects ``ConvMixer``.
        local_mixer: Window / kernel size forwarded as ``local_k``.
        HW: Token grid size forwarded to the mixer.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop: Forwarded to ``Attention``.
        drop: Dropout probability for the attention projection and the MLP.
        drop_path: Stochastic-depth probability; values ``<= 0`` use
            ``Identity`` instead of ``DropPath``.
        act_layer: Activation name for the MLP.
        norm_layer: Either a string that is evaluated to a norm class (built
            with ``eps=epsilon``) or a callable invoked as ``norm_layer(dim)``.
        epsilon: Epsilon used when ``norm_layer`` is given as a string.
        prenorm: Selects the forward layout, see ``forward``.

    Raises:
        TypeError: If ``mixer`` is not one of the supported names.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mixer="Global",
        local_mixer=[7, 11],
        HW=None,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer="gelu",
        norm_layer="nn.LayerNorm",
        epsilon=1e-6,
        prenorm=True,
    ):
        super().__init__()
        norm_given_by_name = isinstance(norm_layer, str)

        # NOTE(review): eval() on a configuration string executes arbitrary
        # code; norm_layer must only ever come from trusted configuration.
        self.norm1 = (
            eval(norm_layer)(dim, eps=epsilon) if norm_given_by_name else norm_layer(dim)
        )

        if mixer == "Conv":
            self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
        elif mixer == "Global" or mixer == "Local":
            self.mixer = Attention(
                dim,
                num_heads=num_heads,
                mixer=mixer,
                HW=HW,
                local_k=local_mixer,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
            )
        else:
            raise TypeError("The mixer must be one of [Global, Local, Conv]")

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = Identity()

        self.norm2 = (
            eval(norm_layer)(dim, eps=epsilon) if norm_given_by_name else norm_layer(dim)
        )

        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.prenorm = prenorm

    def forward(self, x):
        """Apply the mixer and MLP sub-layers with residual connections.

        With ``prenorm`` true each norm is applied to the residual sum; with
        ``prenorm`` false each norm is applied to the sub-layer input.
        """
        if not self.prenorm:
            x = x + self.drop_path(self.mixer(self.norm1(x)))
            return x + self.drop_path(self.mlp(self.norm2(x)))
        x = self.norm1(x + self.drop_path(self.mixer(x)))
        return self.norm2(x + self.drop_path(self.mlp(x)))
262
+
263
+
264
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    Args:
        img_size: ``[height, width]`` the input must have.
        in_channels: Input channels of the first convolution in "pope" mode.
        embed_dim: Channel width of the produced tokens.
        sub_num: Number of stride-2 convolutions in "pope" mode. Only 2 and 3
            build a projection; other values leave ``proj`` undefined.
        patch_size: ``[height, width]`` of a patch in "linear" mode.
        mode: ``"pope"`` stacks stride-2 ``ConvBNLayer`` blocks; ``"linear"``
            uses a single strided convolution (which takes 1 input channel).
    """

    def __init__(
        self,
        img_size=[32, 100],
        in_channels=3,
        embed_dim=768,
        sub_num=2,
        patch_size=[4, 4],
        mode="pope",
    ):
        super().__init__()
        num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
        self.img_size = img_size
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.norm = None
        if mode == "pope":
            # Channel progression of the stride-2 stack; each consecutive pair
            # becomes one ConvBNLayer.
            if sub_num == 2:
                widths = [in_channels, embed_dim // 2, embed_dim]
            elif sub_num == 3:
                widths = [in_channels, embed_dim // 4, embed_dim // 2, embed_dim]
            else:
                widths = None
            if widths is not None:
                self.proj = nn.Sequential(
                    *[
                        ConvBNLayer(
                            in_channels=c_in,
                            out_channels=c_out,
                            kernel_size=3,
                            stride=2,
                            padding=1,
                            act="gelu",
                            bias_attr=True,
                        )
                        for c_in, c_out in zip(widths[:-1], widths[1:])
                    ]
                )
        elif mode == "linear":
            self.proj = nn.Conv2d(
                1, embed_dim, kernel_size=patch_size, stride=patch_size
            )
            # Parenthesize each axis: ``a // b * c // d`` evaluates left to
            # right as ``((a // b) * c) // d`` and over-counts whenever
            # patch_size[1] does not divide img_size[1].
            self.num_patches = (img_size[0] // patch_size[0]) * (
                img_size[1] // patch_size[1]
            )

    def forward(self, x):
        """Project a ``(B, C, H, W)`` image into a ``(B, tokens, embed_dim)`` sequence.

        Raises:
            AssertionError: If ``H`` or ``W`` differ from ``img_size``.
        """
        B, C, H, W = x.shape
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), "Input image size ({}*{}) doesn't match model ({}*{}).".format(
            H, W, self.img_size[0], self.img_size[1]
        )
        x = self.proj(x).flatten(2).permute(0, 2, 1)
        return x
351
+
352
+
353
class SubSample(nn.Module):
    """Downsample a feature map and return it as a normalized token sequence.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channel width of the produced tokens.
        types: ``"Pool"`` averages an average-pool and a max-pool and then
            applies a linear projection; any other value uses a strided
            3x3 convolution.
        stride: Stride of the pooling / convolution.
        sub_norm: String evaluated to a norm class, built with
            ``out_channels``.
        act: Optional callable returning an activation module; ``None``
            disables the activation.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        types="Pool",
        stride=[2, 1],
        sub_norm="nn.LayerNorm",
        act=None,
    ):
        super().__init__()
        self.types = types
        if types != "Pool":
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
            )
        else:
            self.avgpool = nn.AvgPool2d(
                kernel_size=[3, 5], stride=stride, padding=[1, 2]
            )
            self.maxpool = nn.MaxPool2d(
                kernel_size=[3, 5], stride=stride, padding=[1, 2]
            )
            self.proj = nn.Linear(in_channels, out_channels)
        # NOTE(review): eval() on a configuration string executes arbitrary
        # code; sub_norm must only ever come from trusted configuration.
        self.norm = eval(sub_norm)(out_channels)
        self.act = None if act is None else act()

    def forward(self, x):
        """Downsample ``x`` and return tokens laid out as ``(batch, tokens, channels)``."""
        if self.types == "Pool":
            pooled = (self.avgpool(x) + self.maxpool(x)) * 0.5
            tokens = self.proj(pooled.flatten(2).permute(0, 2, 1))
        else:
            tokens = self.conv(x).flatten(2).permute(0, 2, 1)
        tokens = self.norm(tokens)
        if self.act is None:
            return tokens
        return self.act(tokens)
401
+
402
+
403
class SVTRNet(nn.Module):
    """Three-stage token-mixing backbone built from ``Block`` units.

    The input image is turned into tokens by ``PatchEmbed``, a learned
    position embedding is added, and the tokens run through three stages of
    blocks. When ``patch_merging`` is ``"Conv"`` or ``"Pool"`` a ``SubSample``
    halves the grid height between stages. With ``last_stage`` the result is
    pooled to ``out_char_num`` columns and projected to ``out_channels``.

    Args:
        img_size: ``[height, width]`` of the input image.
        in_channels: Channels of the input image.
        embed_dim: Token width of each of the three stages.
        depth: Number of blocks in each stage.
        num_heads: Head count of each stage.
        mixer: One mixer name per block, across all stages.
        local_mixer: Local window size of each stage.
        patch_merging: ``"Conv"``, ``"Pool"``; anything else disables merging.
        mlp_ratio, qkv_bias, qk_scale: Forwarded to every block.
        drop_rate: Dropout after the position embedding and inside blocks.
        last_drop: Dropout of the output head(s).
        attn_drop_rate: Attention dropout inside blocks.
        drop_path_rate: Upper end of the linearly spaced per-block
            stochastic-depth probabilities.
        norm_layer, sub_norm: Strings evaluated to norm classes.
        epsilon: Epsilon of the norms.
        out_channels: Channels of the final projection(s).
        out_char_num: Output width of the adaptive average pool.
        block_unit: String evaluated to the block class.
        act: Activation name used inside blocks.
        last_stage: Whether to build and apply the pooling/projection head.
        sub_num: Number of stride-2 steps in the patch embedding.
        prenorm: Forwarded to blocks; when false a final norm is applied.
        use_lenhead: Whether to also return a length-head output.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        img_size=[32, 100],
        in_channels=3,
        embed_dim=[64, 128, 256],
        depth=[3, 6, 3],
        num_heads=[2, 4, 8],
        mixer=["Local"] * 6 + ["Global"] * 6,  # Local atten, Global atten, Conv
        local_mixer=[[7, 11], [7, 11], [7, 11]],
        patch_merging="Conv",  # Conv, Pool, None
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        last_drop=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer="nn.LayerNorm",
        sub_norm="nn.LayerNorm",
        epsilon=1e-6,
        out_channels=192,
        out_char_num=25,
        block_unit="Block",
        act="gelu",
        last_stage=True,
        sub_num=2,
        prenorm=True,
        use_lenhead=False,
        **kwargs
    ):
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.out_channels = out_channels
        self.prenorm = prenorm
        # Any value other than the two supported strategies disables merging.
        if patch_merging not in ("Conv", "Pool"):
            patch_merging = None

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            in_channels=in_channels,
            embed_dim=embed_dim[0],
            sub_num=sub_num,
        )
        num_patches = self.patch_embed.num_patches
        self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
        self.pos_drop = nn.Dropout(p=drop_rate)
        # NOTE(review): eval() on configuration strings (block_unit and
        # norm_layer below) executes arbitrary code; only use trusted configs.
        Block_unit = eval(block_unit)

        # One stochastic-depth probability per block, rising linearly.
        dpr = np.linspace(0, drop_path_rate, sum(depth))

        def make_stage(stage, stage_mixers, stage_dpr, grid):
            # Build the blocks of one stage; only the per-stage settings vary.
            return nn.ModuleList(
                [
                    Block_unit(
                        dim=embed_dim[stage],
                        num_heads=num_heads[stage],
                        mixer=stage_mixers[i],
                        HW=grid,
                        local_mixer=local_mixer[stage],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        drop=drop_rate,
                        act_layer=act,
                        attn_drop=attn_drop_rate,
                        drop_path=stage_dpr[i],
                        norm_layer=norm_layer,
                        epsilon=epsilon,
                        prenorm=prenorm,
                    )
                    for i in range(depth[stage])
                ]
            )

        first_end = depth[0]
        second_end = depth[0] + depth[1]

        self.blocks1 = make_stage(
            0, mixer[0:first_end], dpr[0:first_end], self.HW
        )
        if patch_merging is None:
            HW = self.HW
        else:
            self.sub_sample1 = SubSample(
                embed_dim[0],
                embed_dim[1],
                sub_norm=sub_norm,
                stride=[2, 1],
                types=patch_merging,
            )
            HW = [self.HW[0] // 2, self.HW[1]]
        self.patch_merging = patch_merging

        self.blocks2 = make_stage(
            1, mixer[first_end:second_end], dpr[first_end:second_end], HW
        )
        if patch_merging is None:
            HW = self.HW
        else:
            self.sub_sample2 = SubSample(
                embed_dim[1],
                embed_dim[2],
                sub_norm=sub_norm,
                stride=[2, 1],
                types=patch_merging,
            )
            HW = [self.HW[0] // 4, self.HW[1]]

        self.blocks3 = make_stage(2, mixer[second_end:], dpr[second_end:], HW)

        self.last_stage = last_stage
        if last_stage:
            self.avg_pool = nn.AdaptiveAvgPool2d([1, out_char_num])
            self.last_conv = nn.Conv2d(
                in_channels=embed_dim[2],
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            self.hardswish = Activation("hard_swish", inplace=True)  # nn.Hardswish()
            self.dropout = nn.Dropout(p=last_drop)
        if not prenorm:
            self.norm = eval(norm_layer)(embed_dim[-1], eps=epsilon)
        self.use_lenhead = use_lenhead
        if use_lenhead:
            self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
            self.hardswish_len = Activation(
                "hard_swish", inplace=True
            )  # nn.Hardswish()
            self.dropout_len = nn.Dropout(p=last_drop)

        torch.nn.init.xavier_normal_(self.pos_embed)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one sub-module according to its layer type."""
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward_features(self, x):
        """Run patch embedding and the three block stages; returns tokens."""
        merging = self.patch_merging is not None

        x = self.pos_drop(self.patch_embed(x) + self.pos_embed)

        for blk in self.blocks1:
            x = blk(x)
        if merging:
            # Fold tokens back into a feature map for the sub-sampler.
            grid = x.permute(0, 2, 1).reshape(
                [-1, self.embed_dim[0], self.HW[0], self.HW[1]]
            )
            x = self.sub_sample1(grid)

        for blk in self.blocks2:
            x = blk(x)
        if merging:
            grid = x.permute(0, 2, 1).reshape(
                [-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]
            )
            x = self.sub_sample2(grid)

        for blk in self.blocks3:
            x = blk(x)
        if not self.prenorm:
            x = self.norm(x)
        return x

    def forward(self, x):
        """Return the backbone output, plus the length-head output if enabled."""
        x = self.forward_features(x)
        if self.use_lenhead:
            len_x = self.len_conv(x.mean(1))
            len_x = self.dropout_len(self.hardswish_len(len_x))
        if self.last_stage:
            # Two merges divide the grid height by 4; without merging it is unchanged.
            h = self.HW[0] // 4 if self.patch_merging is not None else self.HW[0]
            x = self.avg_pool(
                x.permute(0, 2, 1).reshape([-1, self.embed_dim[2], h, self.HW[1]])
            )
            x = self.last_conv(x)
            x = self.hardswish(x)
            x = self.dropout(x)
        if self.use_lenhead:
            return x, len_x
        return x
@@ -0,0 +1,76 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+
6
class Hswish(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``.

    Args:
        inplace: Forwarded to ``relu6``; it acts on the temporary ``x + 3``.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        """Return the hard-swish of ``x`` as a new tensor."""
        gate = F.relu6(x + 3.0, inplace=self.inplace)
        return x * gate / 6.0
13
+
14
+
15
+ # out = max(0, min(1, slope*x+offset))
16
+ # paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
17
class Hsigmoid(nn.Module):
    """Hard-sigmoid activation in the Paddle form: ``relu6(1.2 * x + 3) / 6``.

    Args:
        inplace: Forwarded to ``relu6``; it acts on the temporary
            ``1.2 * x + 3``.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        """Return the hard-sigmoid of ``x`` as a new tensor."""
        # torch's own variant would be relu6(x + 3) / 6; the 1.2 factor
        # reproduces the Paddle definition instead.
        shifted = 1.2 * x + 3.0
        return F.relu6(shifted, inplace=self.inplace) / 6.0
26
+
27
+
28
class GELU(nn.Module):
    """GELU activation wrapper.

    Args:
        inplace: Stored on the module but not used by ``forward``.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        """Return ``gelu(x)`` as a new tensor."""
        return F.gelu(x)
35
+
36
+
37
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``.

    Args:
        inplace: When true, ``forward`` overwrites its input tensor and
            returns that same tensor.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        """Return the swish of ``x`` (the mutated input when in-place)."""
        gate = torch.sigmoid(x)
        if not self.inplace:
            return x * gate
        x.mul_(gate)
        return x
48
+
49
+
50
class Activation(nn.Module):
    """Activation layer selected by a case-insensitive name.

    Args:
        act_type: One of ``relu``, ``relu6``, ``hard_sigmoid``,
            ``hard_swish`` / ``hswish``, ``leakyrelu``, ``gelu``, ``swish``.
        inplace: Forwarded to the selected activation.

    Raises:
        NotImplementedError: For ``sigmoid`` and for any unknown name.
    """

    def __init__(self, act_type, inplace=True):
        super().__init__()
        # Factories are lambdas so that only the selected activation class is
        # looked up and instantiated.
        factories = {
            "relu": lambda: nn.ReLU(inplace=inplace),
            "relu6": lambda: nn.ReLU6(inplace=inplace),
            "hard_sigmoid": lambda: Hsigmoid(inplace),
            "hard_swish": lambda: Hswish(inplace=inplace),
            "hswish": lambda: Hswish(inplace=inplace),
            "leakyrelu": lambda: nn.LeakyReLU(inplace=inplace),
            "gelu": lambda: GELU(inplace=inplace),
            "swish": lambda: Swish(inplace=inplace),
        }
        # "sigmoid" is deliberately absent from the table: it is rejected just
        # like an unknown name.
        build = factories.get(act_type.lower())
        if build is None:
            raise NotImplementedError
        self.act = build()

    def forward(self, inputs):
        """Apply the selected activation to ``inputs``."""
        return self.act(inputs)