magic-pdf 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. magic_pdf/data/batch_build_dataset.py +156 -0
  2. magic_pdf/data/dataset.py +44 -24
  3. magic_pdf/data/utils.py +108 -9
  4. magic_pdf/dict2md/ocr_mkcontent.py +4 -3
  5. magic_pdf/libs/pdf_image_tools.py +11 -6
  6. magic_pdf/libs/performance_stats.py +12 -1
  7. magic_pdf/libs/version.py +1 -1
  8. magic_pdf/model/batch_analyze.py +175 -201
  9. magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
  10. magic_pdf/model/pdf_extract_kit.py +5 -38
  11. magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
  12. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
  13. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
  14. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
  15. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
  16. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
  17. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
  18. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
  19. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
  20. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
  21. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
  22. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
  25. magic_pdf/model/sub_modules/model_init.py +50 -37
  26. magic_pdf/model/sub_modules/model_utils.py +17 -11
  27. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
  29. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
  32. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
  33. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
  34. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
  35. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
  36. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
  37. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
  38. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
  39. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
  40. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
  41. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
  42. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
  43. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
  44. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
  45. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
  46. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
  47. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
  48. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
  49. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
  50. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
  51. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
  52. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
  53. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
  54. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
  55. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
  56. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
  57. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
  58. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
  59. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
  60. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
  61. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
  62. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
  63. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
  64. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
  65. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
  66. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
  67. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
  68. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
  69. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
  70. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
  71. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
  72. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
  73. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
  74. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
  75. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
  76. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
  77. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
  78. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
  79. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
  80. magic_pdf/pdf_parse_union_core_v2.py +112 -74
  81. magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
  82. magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
  83. magic_pdf/resources/model_config/model_configs.yaml +1 -1
  84. magic_pdf/tools/cli.py +30 -12
  85. magic_pdf/tools/common.py +90 -12
  86. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +50 -40
  87. magic_pdf-1.3.0.dist-info/RECORD +202 -0
  88. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
  89. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
  90. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
  91. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
  92. magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
  93. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
  94. magic_pdf-1.2.2.dist-info/RECORD +0 -147
  95. /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
  96. /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
  97. /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
  98. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
  99. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
  100. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
  101. {magic_pdf-1.2.2.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,128 @@
1
+ t
2
+ a
3
+ _
4
+ i
5
+ m
6
+ g
7
+ /
8
+ 3
9
+ I
10
+ L
11
+ S
12
+ V
13
+ R
14
+ C
15
+ 2
16
+ 0
17
+ 1
18
+ v
19
+ l
20
+ 9
21
+ 7
22
+ 8
23
+ .
24
+ j
25
+ p
26
+
27
+
28
+
29
+
30
+ ி
31
+
32
+
33
+
34
+
35
+
36
+
37
+ 6
38
+
39
+
40
+
41
+ 5
42
+
43
+
44
+
45
+
46
+
47
+ 4
48
+
49
+
50
+
51
+
52
+
53
+
54
+
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+
73
+
74
+
75
+
76
+ s
77
+ c
78
+ e
79
+ n
80
+ w
81
+ F
82
+ T
83
+ O
84
+ P
85
+ K
86
+ A
87
+ N
88
+ G
89
+ Y
90
+ E
91
+ M
92
+ H
93
+ U
94
+ B
95
+ o
96
+ b
97
+ D
98
+ d
99
+ r
100
+ W
101
+ u
102
+ y
103
+ f
104
+ X
105
+ k
106
+ q
107
+ h
108
+ J
109
+ z
110
+ Z
111
+ Q
112
+ x
113
+ -
114
+ '
115
+ $
116
+ ,
117
+ %
118
+ @
119
+ é
120
+ !
121
+ #
122
+ +
123
+ É
124
+ &
125
+ :
126
+ (
127
+ ?
128
+
@@ -0,0 +1,151 @@
1
+ t
2
+ e
3
+ _
4
+ i
5
+ m
6
+ g
7
+ /
8
+ 5
9
+ I
10
+ L
11
+ S
12
+ V
13
+ R
14
+ C
15
+ 2
16
+ 0
17
+ 1
18
+ v
19
+ a
20
+ l
21
+ 3
22
+ 4
23
+ 8
24
+ 9
25
+ .
26
+ j
27
+ p
28
+
29
+
30
+
31
+
32
+
33
+ ి
34
+
35
+
36
+
37
+
38
+
39
+ 7
40
+ 6
41
+
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+
51
+
52
+
53
+
54
+ '
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+ [
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+ ;
73
+
74
+
75
+
76
+
77
+
78
+
79
+
80
+
81
+
82
+
83
+ -
84
+
85
+
86
+
87
+
88
+
89
+
90
+ ,
91
+
92
+
93
+ |
94
+ ?
95
+ :
96
+
97
+ "
98
+ (
99
+
100
+ !
101
+ +
102
+ )
103
+ *
104
+ =
105
+ &
106
+
107
+
108
+ ]
109
+ £
110
+ $
111
+ s
112
+ c
113
+ n
114
+ w
115
+ k
116
+ J
117
+ G
118
+ u
119
+ d
120
+ r
121
+ E
122
+ o
123
+ h
124
+ y
125
+ b
126
+ f
127
+ B
128
+ M
129
+ O
130
+ T
131
+ N
132
+ D
133
+ P
134
+ A
135
+ F
136
+ x
137
+ W
138
+ Y
139
+ U
140
+ H
141
+ K
142
+ X
143
+ z
144
+ Z
145
+ Q
146
+ q
147
+ É
148
+ %
149
+ #
150
+ @
151
+ é
@@ -0,0 +1,49 @@
1
+ lang:
2
+ ch:
3
+ det: ch_PP-OCRv3_det_infer.pth
4
+ rec: ch_PP-OCRv4_rec_infer.pth
5
+ dict: ppocr_keys_v1.txt
6
+ en:
7
+ det: en_PP-OCRv3_det_infer.pth
8
+ rec: en_PP-OCRv4_rec_infer.pth
9
+ dict: en_dict.txt
10
+ korean:
11
+ det: Multilingual_PP-OCRv3_det_infer.pth
12
+ rec: korean_PP-OCRv3_rec_infer.pth
13
+ dict: korean_dict.txt
14
+ japan:
15
+ det: Multilingual_PP-OCRv3_det_infer.pth
16
+ rec: japan_PP-OCRv3_rec_infer.pth
17
+ dict: japan_dict.txt
18
+ chinese_cht:
19
+ det: Multilingual_PP-OCRv3_det_infer.pth
20
+ rec: chinese_cht_PP-OCRv3_rec_infer.pth
21
+ dict: chinese_cht_dict.txt
22
+ ta:
23
+ det: Multilingual_PP-OCRv3_det_infer.pth
24
+ rec: ta_PP-OCRv3_rec_infer.pth
25
+ dict: ta_dict.txt
26
+ te:
27
+ det: Multilingual_PP-OCRv3_det_infer.pth
28
+ rec: te_PP-OCRv3_rec_infer.pth
29
+ dict: te_dict.txt
30
+ ka:
31
+ det: Multilingual_PP-OCRv3_det_infer.pth
32
+ rec: ka_PP-OCRv3_rec_infer.pth
33
+ dict: ka_dict.txt
34
+ latin:
35
+ det: en_PP-OCRv3_det_infer.pth
36
+ rec: latin_PP-OCRv3_rec_infer.pth
37
+ dict: latin_dict.txt
38
+ arabic:
39
+ det: Multilingual_PP-OCRv3_det_infer.pth
40
+ rec: arabic_PP-OCRv3_rec_infer.pth
41
+ dict: arabic_dict.txt
42
+ cyrillic:
43
+ det: Multilingual_PP-OCRv3_det_infer.pth
44
+ rec: cyrillic_PP-OCRv3_rec_infer.pth
45
+ dict: cyrillic_dict.txt
46
+ devanagari:
47
+ det: Multilingual_PP-OCRv3_det_infer.pth
48
+ rec: devanagari_PP-OCRv3_rec_infer.pth
49
+ dict: devanagari_dict.txt
@@ -0,0 +1 @@
1
+ # Copyright (c) Opendatalab. All rights reserved.
@@ -0,0 +1 @@
1
+ # Copyright (c) Opendatalab. All rights reserved.
@@ -0,0 +1,106 @@
1
+ import cv2
2
+ import copy
3
+ import numpy as np
4
+ import math
5
+ import time
6
+ import torch
7
+ from ...pytorchocr.base_ocr_v20 import BaseOCRV20
8
+ from . import pytorchocr_utility as utility
9
+ from ...pytorchocr.postprocess import build_post_process
10
+
11
+
12
class TextClassifier(BaseOCRV20):
    """Text-line orientation classifier running a PyTorch OCR model.

    Each crop is classified by the network; when the predicted label contains
    ``'180'`` and its score exceeds ``cls_thresh``, the crop is rotated by 180
    degrees before being returned.
    """

    def __init__(self, args, **kwargs):
        """Configure pre/post-processing and load the classification network.

        Args:
            args: configuration object. Reads ``device``, ``cls_image_shape``
                (comma-separated string parsed into three ints), ``cls_batch_num``,
                ``cls_thresh``, ``label_list``, ``cls_model_path``,
                ``cls_yaml_path``, ``limited_max_width`` and ``limited_min_width``.
            **kwargs: forwarded to ``BaseOCRV20.__init__``.
        """
        self.device = args.device
        # Parsed once; unpacked later as (channels, height, width).
        self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
        self.cls_batch_num = args.cls_batch_num
        self.cls_thresh = args.cls_thresh
        postprocess_params = {
            'name': 'ClsPostProcess',
            "label_list": args.label_list,
        }
        self.postprocess_op = build_post_process(postprocess_params)

        self.weights_path = args.cls_model_path
        self.yaml_path = args.cls_yaml_path
        # The network architecture is resolved from the weights path.
        network_config = utility.get_arch_config(self.weights_path)
        super(TextClassifier, self).__init__(network_config, **kwargs)

        self.limited_max_width = args.limited_max_width
        self.limited_min_width = args.limited_min_width

        self.load_pytorch_weights(self.weights_path)
        self.net.eval()
        self.net.to(self.device)

    def resize_norm_img(self, img):
        """Resize one crop to the classifier height, normalize it, right-pad it.

        The configured width is clamped to
        ``[limited_min_width, limited_max_width]``. The crop is resized to the
        configured height, keeping its aspect ratio unless that would exceed the
        clamped width. It is then scaled to ``[-1, 1]`` and copied into a
        zero-filled ``(C, H, W)`` float32 array.

        Args:
            img: image array with ``img.shape[0]`` = height and
                ``img.shape[1]`` = width (presumably HxWxC uint8 — confirm).

        Returns:
            np.ndarray of shape ``(imgC, imgH, imgW)``, dtype float32.
        """
        imgC, imgH, imgW = self.cls_image_shape
        h, w = img.shape[0], img.shape[1]
        ratio = w / float(h)
        imgW = max(min(imgW, self.limited_max_width), self.limited_min_width)
        # NOTE(review): the min-width clamp only influences this comparison;
        # the else-branch uses the unclamped width. Behavior kept as-is.
        ratio_imgH = max(math.ceil(imgH * ratio), self.limited_min_width)
        if ratio_imgH > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH)).astype('float32')
        if self.cls_image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        # Map [0, 1] -> [-1, 1].
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        """Classify the orientation of each crop and un-rotate flipped ones.

        Args:
            img_list: list of image crops. It is deep-copied, so the caller's
                arrays are never modified.

        Returns:
            Tuple ``(img_list, cls_res, elapse)``:
            - the copied crops in input order, with confidently 180-degree
              crops rotated back;
            - one ``[label, score]`` pair per crop, in input order;
            - accumulated inference plus post-processing time, in seconds.
        """
        img_list = copy.deepcopy(img_list)
        img_num = len(img_list)
        # Sorting by aspect ratio keeps similarly shaped crops in one batch.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))

        # A distinct list per slot (no shared aliasing between entries).
        cls_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.cls_batch_num
        elapse = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            # np.concatenate already returns a fresh contiguous array.
            norm_img_batch = np.concatenate([
                self.resize_norm_img(img_list[indices[ino]])[np.newaxis, :]
                for ino in range(beg_img_no, end_img_no)
            ])
            starttime = time.time()

            with torch.no_grad():
                inp = torch.from_numpy(norm_img_batch).to(self.device)
                prob_out = self.net(inp).cpu().numpy()

            cls_result = self.postprocess_op(prob_out)
            elapse += time.time() - starttime
            for rno, (label, score) in enumerate(cls_result):
                idx = indices[beg_img_no + rno]
                cls_res[idx] = [label, score]
                if '180' in label and score > self.cls_thresh:
                    img_list[idx] = cv2.rotate(img_list[idx], cv2.ROTATE_180)
        return img_list, cls_res, elapse
@@ -0,0 +1,217 @@
1
+ import sys
2
+
3
+ import numpy as np
4
+ import time
5
+ import torch
6
+ from ...pytorchocr.base_ocr_v20 import BaseOCRV20
7
+ from . import pytorchocr_utility as utility
8
+ from ...pytorchocr.data import create_operators, transform
9
+ from ...pytorchocr.postprocess import build_post_process
10
+
11
+
12
class TextDetector(BaseOCRV20):
    """Text-region detector running a PyTorch OCR detection model.

    Pre- and post-processing are configured from ``args.det_algorithm``.
    Calling the instance on an image returns the clipped and filtered boxes
    together with the elapsed time.
    """

    def __init__(self, args, **kwargs):
        """Build the pre/post-processing pipelines and load the network.

        Args:
            args: configuration object. Always reads ``det_algorithm``,
                ``device``, ``det_limit_side_len``, ``det_limit_type``,
                ``det_model_path`` and ``det_yaml_path``. Also reads the
                algorithm-specific fields referenced in each branch below.
            **kwargs: forwarded to ``BaseOCRV20.__init__``.

        Raises:
            ValueError: if ``args.det_algorithm`` is not one of
                ``DB``, ``DB++``, ``EAST``, ``SAST``, ``PSE``, ``FCE``.
        """
        self.args = args
        self.det_algorithm = args.det_algorithm
        self.device = args.device
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type,
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        if self.det_algorithm in ("DB", "DB++"):
            # DB and DB++ share the same post-processing configuration.
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
            if self.det_algorithm == "DB++":
                # DB++ replaces the normalization step: own mean, unit std.
                pre_process_list[1] = {
                    'NormalizeImage': {
                        'std': [1.0, 1.0, 1.0],
                        'mean':
                        [0.48109378172549, 0.45752457890196, 0.40787054090196],
                        'scale': '1./255.',
                        'order': 'hwc'
                    }
                }
        elif self.det_algorithm == "EAST":
            postprocess_params['name'] = 'EASTPostProcess'
            postprocess_params["score_thresh"] = args.det_east_score_thresh
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
        elif self.det_algorithm == "SAST":
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'resize_long': args.det_limit_side_len
                }
            }
            postprocess_params['name'] = 'SASTPostProcess'
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
            self.det_sast_polygon = args.det_sast_polygon
            if self.det_sast_polygon:
                postprocess_params["sample_pts_num"] = 6
                postprocess_params["expand_scale"] = 1.2
                postprocess_params["shrink_ratio_of_width"] = 0.2
            else:
                postprocess_params["sample_pts_num"] = 2
                postprocess_params["expand_scale"] = 1.0
                postprocess_params["shrink_ratio_of_width"] = 0.3
        elif self.det_algorithm == "PSE":
            postprocess_params['name'] = 'PSEPostProcess'
            postprocess_params["thresh"] = args.det_pse_thresh
            postprocess_params["box_thresh"] = args.det_pse_box_thresh
            postprocess_params["min_area"] = args.det_pse_min_area
            postprocess_params["box_type"] = args.det_pse_box_type
            postprocess_params["scale"] = args.det_pse_scale
            self.det_pse_box_type = args.det_pse_box_type
        elif self.det_algorithm == "FCE":
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'rescale_img': [1080, 736]
                }
            }
            postprocess_params['name'] = 'FCEPostProcess'
            postprocess_params["scales"] = args.scales
            postprocess_params["alpha"] = args.alpha
            postprocess_params["beta"] = args.beta
            postprocess_params["fourier_degree"] = args.fourier_degree
            postprocess_params["box_type"] = args.det_fce_box_type
        else:
            # Fail loudly with a catchable error instead of terminating the
            # whole process with a success exit status.
            raise ValueError(
                "unknown det_algorithm:{}".format(self.det_algorithm))

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)

        self.weights_path = args.det_model_path
        self.yaml_path = args.det_yaml_path
        # The network architecture is resolved from the weights path.
        network_config = utility.get_arch_config(self.weights_path)
        super(TextDetector, self).__init__(network_config, **kwargs)
        self.load_pytorch_weights(self.weights_path)
        self.net.eval()
        self.net.to(self.device)

    def order_points_clockwise(self, pts):
        """Order four 2-D points as top-left, top-right, bottom-right, bottom-left.

        Reference:
        https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py

        Args:
            pts: array of four points indexed as ``pts[:, 0]`` = x and
                ``pts[:, 1]`` = y.

        Returns:
            float32 array ``[tl, tr, br, bl]``.
        """
        # Sort the points by x-coordinate.
        xSorted = pts[np.argsort(pts[:, 0]), :]

        # The two smallest-x points form the left side, the rest the right side.
        leftMost = xSorted[:2, :]
        rightMost = xSorted[2:, :]

        # Within each side, the smaller y is the top point.
        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
        (tl, bl) = leftMost

        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
        (tr, br) = rightMost

        rect = np.array([tl, tr, br, bl], dtype="float32")
        return rect

    def clip_det_res(self, points, img_height, img_width):
        """Clamp each point into the image and truncate it to an integer value.

        ``points`` is modified in place and also returned.
        """
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and size-filter quadrilateral boxes.

        Boxes whose top-edge width or left-edge height is 3 px or less
        (after ordering and clipping) are dropped.

        Returns:
            np.ndarray of the kept boxes.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip boxes (e.g. polygons) to the image without reordering or filtering."""
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def __call__(self, img):
        """Detect text regions in ``img``.

        Args:
            img: input image array. Its original ``shape[0:2]``
                (height, width) is used to clip the output boxes.

        Returns:
            ``(dt_boxes, elapse)``:
            - ``dt_boxes`` are the post-processed boxes after clipping and
              filtering;
            - ``elapse`` is the time in seconds spent on inference and
              post-processing.
            ``(None, 0)`` is returned when preprocessing yields nothing.
        """
        # Only the original shape is needed later, so skip copying the pixels.
        ori_shape = img.shape
        data = transform({'image': img}, self.preprocess_op)
        # NOTE(review): guard in case transform reports a failed op by
        # returning None; unpacking None would otherwise raise TypeError.
        if data is None:
            return None, 0
        img, shape_list = data
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()
        starttime = time.time()

        with torch.no_grad():
            inp = torch.from_numpy(img).to(self.device)
            outputs = self.net(inp)

        preds = {}
        if self.det_algorithm == "EAST":
            preds['f_geo'] = outputs['f_geo'].cpu().numpy()
            preds['f_score'] = outputs['f_score'].cpu().numpy()
        elif self.det_algorithm == 'SAST':
            preds['f_border'] = outputs['f_border'].cpu().numpy()
            preds['f_score'] = outputs['f_score'].cpu().numpy()
            preds['f_tco'] = outputs['f_tco'].cpu().numpy()
            preds['f_tvo'] = outputs['f_tvo'].cpu().numpy()
        elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
            preds['maps'] = outputs['maps'].cpu().numpy()
        elif self.det_algorithm == 'FCE':
            for i, output in enumerate(outputs.values()):
                preds['level_{}'.format(i)] = output
        else:
            raise NotImplementedError

        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]['points']
        # Polygon outputs are only clipped; quadrilaterals are also ordered
        # and size-filtered.
        keep_polygons = (
            (self.det_algorithm == "SAST" and self.det_sast_polygon)
            or (self.det_algorithm in ["PSE", "FCE"]
                and self.postprocess_op.box_type == 'poly'))
        if keep_polygons:
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_shape)

        elapse = time.time() - starttime
        return dt_boxes, elapse