magic-pdf 1.2.2__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. magic_pdf/data/batch_build_dataset.py +156 -0
  2. magic_pdf/data/dataset.py +56 -25
  3. magic_pdf/data/utils.py +108 -9
  4. magic_pdf/dict2md/ocr_mkcontent.py +4 -3
  5. magic_pdf/libs/pdf_image_tools.py +11 -6
  6. magic_pdf/libs/performance_stats.py +12 -1
  7. magic_pdf/libs/version.py +1 -1
  8. magic_pdf/model/batch_analyze.py +175 -201
  9. magic_pdf/model/doc_analyze_by_custom_model.py +142 -92
  10. magic_pdf/model/pdf_extract_kit.py +5 -38
  11. magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
  12. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
  13. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
  14. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
  15. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
  16. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
  17. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
  18. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
  19. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
  20. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
  21. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
  22. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
  25. magic_pdf/model/sub_modules/model_init.py +50 -37
  26. magic_pdf/model/sub_modules/model_utils.py +18 -12
  27. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
  29. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
  32. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
  33. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
  34. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
  35. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
  36. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
  37. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
  38. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
  39. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
  40. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
  41. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
  42. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
  43. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
  44. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
  45. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
  46. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
  47. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
  48. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
  49. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
  50. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
  51. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
  52. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
  53. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
  54. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
  55. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
  56. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
  57. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
  58. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
  59. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
  60. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
  61. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
  62. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
  63. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
  64. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
  65. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
  66. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
  67. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
  68. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
  69. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
  70. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
  71. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
  72. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
  73. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
  74. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
  75. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
  76. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
  77. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
  78. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
  79. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +15 -19
  80. magic_pdf/pdf_parse_union_core_v2.py +112 -74
  81. magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
  82. magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
  83. magic_pdf/resources/model_config/model_configs.yaml +1 -1
  84. magic_pdf/resources/slanet_plus/slanet-plus.onnx +0 -0
  85. magic_pdf/tools/cli.py +30 -12
  86. magic_pdf/tools/common.py +90 -12
  87. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/METADATA +92 -59
  88. magic_pdf-1.3.1.dist-info/RECORD +203 -0
  89. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/WHEEL +1 -1
  90. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
  91. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
  92. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
  93. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
  94. magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
  95. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
  96. magic_pdf-1.2.2.dist-info/RECORD +0 -147
  97. /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
  98. /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
  99. /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
  100. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/LICENSE.md +0 -0
  101. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/entry_points.txt +0 -0
  102. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,228 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from ..backbones.rec_svtrnet import Block, ConvBNLayer
5
+
6
+
7
class Im2Seq(nn.Module):
    """Turn a 4-D feature map into a sequence laid out as (batch, width, channels).

    The input is unpacked as ``(B, C, H, W)``. Axis 2 is squeezed away and the
    remaining three axes are reordered with ``permute(0, 2, 1)``.
    """

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        # Nothing is learned here, so the channel count passes through as-is.
        self.out_channels = in_channels

    def forward(self, x):
        # The unpack doubles as a rank check: a tensor that is not 4-D fails here.
        _batch, _channels, _height, _width = x.shape
        # squeeze(dim=2) only removes the axis when its size is 1; otherwise the
        # tensor stays 4-D and the three-axis permute below raises.
        sequence = x.squeeze(dim=2)
        return sequence.permute(0, 2, 1)
19
+
20
+
21
class EncoderWithRNN_(nn.Module):
    """Bidirectional encoder assembled from two separate one-directional LSTMs.

    ``rnn1`` consumes the sequence as given and ``rnn2`` consumes it flipped
    along dim 1. The second output is flipped back and both are concatenated on
    the last axis, which is why ``out_channels`` is ``hidden_size * 2``.
    """

    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.out_channels = hidden_size * 2
        shared_options = dict(bidirectional=False, batch_first=True, num_layers=2)
        # rnn1 is created before rnn2 so parameter registration (and therefore
        # random initialisation) happens in the same order as before.
        self.rnn1 = nn.LSTM(in_channels, hidden_size, **shared_options)
        self.rnn2 = nn.LSTM(in_channels, hidden_size, **shared_options)

    def forward(self, x):
        for rnn in (self.rnn1, self.rnn2):
            rnn.flatten_parameters()
        forward_out, _ = self.rnn1(x)
        reversed_out, _ = self.rnn2(torch.flip(x, [1]))
        # Undo the flip so both directions line up step by step.
        backward_out = torch.flip(reversed_out, [1])
        return torch.cat([forward_out, backward_out], 2)
46
+
47
+
48
class EncoderWithRNN(nn.Module):
    """Two-layer bidirectional LSTM encoder for batch-first sequences."""

    def __init__(self, in_channels, hidden_size):
        super().__init__()
        # A bidirectional LSTM emits both directions concatenated on the last axis.
        self.out_channels = hidden_size * 2
        self.lstm = nn.LSTM(
            input_size=in_channels,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        # nn.LSTM returns (output, state); only the per-step output is kept.
        encoded, _state = self.lstm(x)
        return encoded
59
+
60
+
61
class EncoderWithFC(nn.Module):
    """Encoder that applies one biased linear projection to the last axis."""

    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.out_channels = hidden_size
        self.fc = nn.Linear(
            in_features=in_channels, out_features=hidden_size, bias=True
        )

    def forward(self, x):
        return self.fc(x)
74
+
75
+
76
class EncoderWithSVTR(nn.Module):
    """Neck that mixes a feature map with global SVTR blocks plus a shortcut.

    Pipeline, as wired in ``forward``: ``conv1`` -> ``conv2`` reduce the
    channels to ``hidden_dims``; the map is flattened to a token sequence and
    run through ``depth`` ``Block`` modules (``mixer="Global"``) and a
    LayerNorm; the sequence is reshaped back to a map, projected to
    ``in_channels`` by ``conv3``, concatenated with the untouched input on the
    channel axis, and finally reduced by ``conv4`` and ``conv1x1`` to ``dims``
    channels (exposed as ``out_channels``).

    Args:
        in_channels: channel count of the incoming feature map.
        dims: channel count of the output map.
        depth: number of ``Block`` modules.
        hidden_dims: token width used inside the blocks.
        use_guide: when true, the input is detached from the autograd graph so
            no gradient flows back through this module into ``x``.
        num_heads, qkv_bias, mlp_ratio, drop_rate, attn_drop_rate, drop_path,
        qk_scale: forwarded unchanged to every ``Block``.
        kernel_size: kernel of ``conv1``; its per-axis half is used as padding.
    """

    def __init__(
        self,
        in_channels,
        dims=64,  # XS
        depth=2,
        hidden_dims=120,
        use_guide=False,
        num_heads=8,
        qkv_bias=True,
        mlp_ratio=2.0,
        drop_rate=0.1,
        kernel_size=[3, 3],  # only read here and forwarded to ConvBNLayer
        attn_drop_rate=0.1,
        drop_path=0.0,
        qk_scale=None,
    ):
        super(EncoderWithSVTR, self).__init__()
        self.depth = depth
        self.use_guide = use_guide
        self.conv1 = ConvBNLayer(
            in_channels,
            in_channels // 8,
            kernel_size=kernel_size,
            padding=[kernel_size[0] // 2, kernel_size[1] // 2],
            act="swish",
        )
        self.conv2 = ConvBNLayer(
            in_channels // 8, hidden_dims, kernel_size=1, act="swish"
        )

        self.svtr_block = nn.ModuleList(
            [
                Block(
                    dim=hidden_dims,
                    num_heads=num_heads,
                    mixer="Global",
                    HW=None,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer="swish",
                    attn_drop=attn_drop_rate,
                    drop_path=drop_path,
                    norm_layer="nn.LayerNorm",
                    epsilon=1e-05,
                    prenorm=False,
                )
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish")
        # last conv-nxn, the input is concat of input tensor and conv3 output tensor
        # (kernel size is left to ConvBNLayer's default, which is not visible here)
        self.conv4 = ConvBNLayer(
            2 * in_channels, in_channels // 8, padding=1, act="swish"
        )

        self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish")
        self.out_channels = dims
        # Re-initialise every submodule created above.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule according to its layer type."""
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        # for use guide
        if self.use_guide:
            # The previous `z.stop_gradient = True` is a Paddle idiom; on a torch
            # tensor it merely sets an unused Python attribute and gradients
            # kept flowing. detach() is what actually cuts the graph.
            z = x.detach().clone()
        else:
            z = x
        # for short cut
        h = z
        # reduce dim
        z = self.conv1(z)
        z = self.conv2(z)
        # SVTR global block: (B, C, H, W) -> (B, H*W, C) token sequence
        _, C, H, W = z.shape
        z = z.flatten(2).permute(0, 2, 1)

        for blk in self.svtr_block:
            z = blk(z)

        z = self.norm(z)
        # last stage: tokens back to a (B, C, H, W) map
        z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
        z = self.conv3(z)
        # concatenate the shortcut and the mixed features on the channel axis
        z = torch.cat((h, z), dim=1)
        z = self.conv1x1(self.conv4(z))

        return z
187
+
188
+
189
class SequenceEncoder(nn.Module):
    """Dispatch wrapper that pairs ``Im2Seq`` with an optional sequence encoder.

    ``encoder_type`` selects the behaviour:

    * ``"reshape"`` - only ``Im2Seq`` is applied.
    * ``"fc"`` / ``"rnn"`` - ``Im2Seq`` first, then the encoder built with
      ``hidden_size``.
    * ``"svtr"`` - the encoder (built with ``**kwargs``) runs on the feature
      map first and ``Im2Seq`` is applied to its output.

    ``out_channels`` mirrors the selected encoder, or ``Im2Seq`` when only
    reshaping.
    """

    def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
        super().__init__()
        self.encoder_reshape = Im2Seq(in_channels)
        self.out_channels = self.encoder_reshape.out_channels
        self.encoder_type = encoder_type
        self.only_reshape = encoder_type == "reshape"
        if self.only_reshape:
            return

        support_encoder_dict = {
            "reshape": Im2Seq,
            "fc": EncoderWithFC,
            "rnn": EncoderWithRNN,
            "svtr": EncoderWithSVTR,
        }
        assert encoder_type in support_encoder_dict, "{} must in {}".format(
            encoder_type, support_encoder_dict.keys()
        )

        encoder_cls = support_encoder_dict[encoder_type]
        reshaped_channels = self.encoder_reshape.out_channels
        if encoder_type == "svtr":
            # SVTR takes its own keyword configuration rather than hidden_size.
            self.encoder = encoder_cls(reshaped_channels, **kwargs)
        else:
            self.encoder = encoder_cls(reshaped_channels, hidden_size)
        self.out_channels = self.encoder.out_channels

    def forward(self, x):
        if self.encoder_type == "svtr":
            # SVTR works on the feature map, so the reshape comes afterwards.
            return self.encoder_reshape(self.encoder(x))

        sequence = self.encoder_reshape(x)
        if self.only_reshape:
            return sequence
        return self.encoder(sequence)
@@ -0,0 +1,33 @@
1
+
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+ from __future__ import unicode_literals
6
+
7
+ import copy
8
+
9
+ __all__ = ['build_post_process']
10
+
11
+
12
def build_post_process(config, global_config=None):
    """Instantiate the post-processing class named by ``config['name']``.

    Args:
        config: mapping with a ``'name'`` key plus the constructor keyword
            arguments. It is deep-copied, so the caller's mapping is untouched.
        global_config: optional mapping merged over ``config`` (after ``'name'``
            has been removed) before the class is constructed.

    Returns:
        An instance of the selected post-processing class.

    Raises:
        KeyError: if ``config`` has no ``'name'`` entry.
        AssertionError: if the name is not a supported post-process. The type is
            kept for callers of the earlier ``assert``-based check, but it is
            raised explicitly so it still fires under ``python -O``.
    """
    # Imported lazily, as before, so importing this package stays cheap.
    from .db_postprocess import DBPostProcess
    from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, TableLabelDecode, \
        NRTRLabelDecode, SARLabelDecode, ViTSTRLabelDecode, RFLLabelDecode
    from .cls_postprocess import ClsPostProcess
    from .rec_postprocess import CANLabelDecode

    # Explicit name -> class table. It replaces eval(module_name), whose only
    # guard was an assert that disappears when Python runs with -O.
    support_dict = {
        'DBPostProcess': DBPostProcess,
        'CTCLabelDecode': CTCLabelDecode,
        'AttnLabelDecode': AttnLabelDecode,
        'ClsPostProcess': ClsPostProcess,
        'SRNLabelDecode': SRNLabelDecode,
        'TableLabelDecode': TableLabelDecode,
        'NRTRLabelDecode': NRTRLabelDecode,
        'SARLabelDecode': SARLabelDecode,
        'ViTSTRLabelDecode': ViTSTRLabelDecode,
        'CANLabelDecode': CANLabelDecode,
        'RFLLabelDecode': RFLLabelDecode,
    }

    config = copy.deepcopy(config)
    module_name = config.pop('name')
    if global_config is not None:
        config.update(global_config)

    try:
        module_cls = support_dict[module_name]
    except (KeyError, TypeError):
        # TypeError covers an unhashable name, which the old list check rejected too.
        raise AssertionError(
            'post process only support {}, but got {}'.format(
                list(support_dict), module_name)) from None
    return module_cls(**config)
@@ -0,0 +1,20 @@
1
+ import torch
2
+
3
+
4
class ClsPostProcess:
    """Map classifier scores to ``(label, score)`` pairs.

    ``label_list`` is indexed by class id. For every row of ``preds`` the
    highest-scoring class along axis 1 is reported together with its score.
    """

    def __init__(self, label_list, **kwargs):
        super().__init__()
        self.label_list = label_list

    def __call__(self, preds, label=None, *args, **kwargs):
        # Torch tensors are moved to host memory and handled as numpy arrays.
        if isinstance(preds, torch.Tensor):
            preds = preds.cpu().numpy()

        best_ids = preds.argmax(axis=1)
        decode_out = []
        for row, class_id in enumerate(best_ids):
            decode_out.append((self.label_list[class_id], preds[row, class_id]))

        if label is not None:
            # Ground-truth ids are decoded with a fixed score of 1.0.
            decoded_label = [(self.label_list[class_id], 1.0) for class_id in label]
            return decode_out, decoded_label
        return decode_out
@@ -0,0 +1,179 @@
1
+ """
2
+ This code is adapted from:
3
+ https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
4
+ """
5
+ from __future__ import absolute_import
6
+ from __future__ import division
7
+ from __future__ import print_function
8
+
9
+ import numpy as np
10
+ import cv2
11
+ import torch
12
+ from shapely.geometry import Polygon
13
+ import pyclipper
14
+
15
+
16
class DBPostProcess(object):
    """
    Post-processing for Differentiable Binarization (DB) detection output.

    The probability map is thresholded, contours are extracted from the binary
    map, and each contour that scores at least ``box_thresh`` is expanded and
    returned as a four-point box scaled to the source image size.
    """

    def __init__(self,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 score_mode="fast",
                 **kwargs):
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        # Boxes whose shorter side is below this many pixels are dropped.
        self.min_size = 3
        self.score_mode = score_mode
        assert score_mode in [
            "slow", "fast"
        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)

        # Optional 2x2 kernel applied to the binary map before contour search.
        if use_dilation:
            self.dilation_kernel = np.array([[1, 1], [1, 1]])
        else:
            self.dilation_kernel = None

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        Extract scored boxes from one binarized map.

        pred: probability map the box scores are averaged over.
        _bitmap: single 2-D map (H, W) whose values are binarized as {0, 1}.
        dest_width, dest_height: size the box coordinates are rescaled to.

        Returns (boxes, scores): an int16 array of boxes and a list of scores.
        '''
        height, width = _bitmap.shape

        found = cv2.findContours((_bitmap * 255).astype(np.uint8),
                                 cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        # findContours returns 3 values in some OpenCV versions and 2 in others.
        if len(found) == 3:
            contours = found[1]
        elif len(found) == 2:
            contours = found[0]

        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for contour in contours[:num_contours]:
            points, short_side = self.get_mini_boxes(contour)
            if short_side < self.min_size:
                continue
            points = np.array(points)

            if self.score_mode == "fast":
                score = self.box_score_fast(pred, points.reshape(-1, 2))
            else:
                score = self.box_score_slow(pred, contour)
            if self.box_thresh > score:
                continue

            expanded = self.unclip(points).reshape(-1, 1, 2)
            box, short_side = self.get_mini_boxes(expanded)
            # The expanded box has to clear a slightly larger size limit.
            if short_side < self.min_size + 2:
                continue
            box = np.array(box)

            # Rescale from map coordinates to the destination size and clamp.
            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.astype(np.int16))
            scores.append(score)
        return np.array(boxes, dtype=np.int16), scores

    def unclip(self, box):
        '''
        Grow `box` outward by area * unclip_ratio / perimeter using pyclipper.
        '''
        poly = Polygon(box)
        distance = poly.area * self.unclip_ratio / poly.length
        offsetter = pyclipper.PyclipperOffset()
        offsetter.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        return np.array(offsetter.Execute(distance))

    def get_mini_boxes(self, contour):
        '''
        Return the minimum-area rectangle of `contour` as four ordered corners
        together with the length of its shorter side.
        '''
        rect = cv2.minAreaRect(contour)
        # Sort by x: the first two corners form the left pair, the last two the right.
        left_a, left_b, right_a, right_b = sorted(
            cv2.boxPoints(rect), key=lambda pt: pt[0])

        # Within each pair the corner with the smaller y is taken first.
        if left_b[1] > left_a[1]:
            left_low_y, left_high_y = left_a, left_b
        else:
            left_low_y, left_high_y = left_b, left_a
        if right_b[1] > right_a[1]:
            right_low_y, right_high_y = right_a, right_b
        else:
            right_low_y, right_high_y = right_b, right_a

        box = [left_low_y, right_low_y, right_high_y, left_high_y]
        return box, min(rect[1])

    def box_score_fast(self, bitmap, _box):
        '''
        box_score_fast: use bbox mean score as the mean score
        '''
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xs = box[:, 0]
        ys = box[:, 1]
        # Integer bounding window of the box, clamped to the map.
        xmin = np.clip(np.floor(xs.min()).astype(np.int64), 0, w - 1)
        xmax = np.clip(np.ceil(xs.max()).astype(np.int64), 0, w - 1)
        ymin = np.clip(np.floor(ys.min()).astype(np.int64), 0, h - 1)
        ymax = np.clip(np.ceil(ys.max()).astype(np.int64), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        # Shift the box into window coordinates before rasterising it.
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        window = bitmap[ymin:ymax + 1, xmin:xmax + 1]
        return cv2.mean(window, mask)[0]

    def box_score_slow(self, bitmap, contour):
        '''
        box_score_slow: use polygon mean score as the mean score
        '''
        h, w = bitmap.shape[:2]
        polygon = np.reshape(contour.copy(), (-1, 2))

        xmin = np.clip(np.min(polygon[:, 0]), 0, w - 1)
        xmax = np.clip(np.max(polygon[:, 0]), 0, w - 1)
        ymin = np.clip(np.min(polygon[:, 1]), 0, h - 1)
        ymax = np.clip(np.max(polygon[:, 1]), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)

        # Shift the polygon into window coordinates before rasterising it.
        polygon[:, 0] = polygon[:, 0] - xmin
        polygon[:, 1] = polygon[:, 1] - ymin

        cv2.fillPoly(mask, polygon.reshape(1, -1, 2).astype(np.int32), 1)
        window = bitmap[ymin:ymax + 1, xmin:xmax + 1]
        return cv2.mean(window, mask)[0]

    def __call__(self, outs_dict, shape_list):
        '''
        Convert a batch of predictions into boxes.

        outs_dict: mapping whose 'maps' entry holds the predictions; channel 0
            of every item is used as the probability map.
        shape_list: per-item (src_h, src_w, ratio_h, ratio_w).

        Returns a list with one {'points': boxes} dict per batch item.
        '''
        pred = outs_dict['maps']
        if isinstance(pred, torch.Tensor):
            pred = pred.cpu().numpy()
        pred = pred[:, 0, :, :]
        segmentation = pred > self.thresh

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, _ratio_h, _ratio_w = shape_list[batch_index]
            mask = segmentation[batch_index]
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(mask).astype(np.uint8), self.dilation_kernel)
            # Scores are computed but only the boxes are reported.
            boxes, _scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                    src_w, src_h)

            boxes_batch.append({'points': boxes})
        return boxes_batch