openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,73 @@
1
+ import copy
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ __all__ = ['build_optimizer']
7
+
8
+
9
def param_groups_weight_decay(model: nn.Module,
                              weight_decay=1e-5,
                              no_weight_decay_list=()):
    """Split a model's trainable parameters into two optimizer param groups.

    A parameter is exempt from weight decay when it is at most 1-D, when its
    name ends in ``.bias``, or when any entry of ``no_weight_decay_list`` is a
    substring of its name. Frozen parameters (``requires_grad=False``) are
    left out entirely.

    Args:
        model: module whose ``named_parameters()`` are grouped.
        weight_decay: decay value assigned to the non-exempt group.
        no_weight_decay_list: iterable of name substrings to exempt.

    Returns:
        ``[{'params': exempt, 'weight_decay': 0.0},
           {'params': decayed, 'weight_decay': weight_decay}]``
    """
    skip_keys = set(no_weight_decay_list)
    undecayed, decayed = [], []
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        exempt = (tensor.ndim <= 1 or param_name.endswith('.bias')
                  or any(key in param_name for key in skip_keys))
        target = undecayed if exempt else decayed
        target.append(tensor)

    return [
        {'params': undecayed, 'weight_decay': 0.0},
        {'params': decayed, 'weight_decay': weight_decay},
    ]
35
+
36
+
37
def build_optimizer(optim_config, lr_scheduler_config, epochs, step_each_epoch,
                    model):
    """Build a torch optimizer and its learning-rate scheduler from configs.

    Args:
        optim_config: dict with ``name`` (a ``torch.optim`` class name) plus
            that optimizer's constructor kwargs. It may also contain the
            builder-only option ``filter_bias_and_bn``, which is never
            forwarded to the optimizer. Not mutated.
        lr_scheduler_config: dict with ``name`` (a scheduler factory class
            from the sibling ``lr`` module) plus its kwargs. Not mutated.
        epochs: forwarded to the scheduler factory.
        step_each_epoch: forwarded to the scheduler factory.
        model: an ``nn.Module``, or an iterable of parameters / param groups.

    Returns:
        ``(optimizer, lr_scheduler)``.

    When ``model`` is a module, ``weight_decay > 0`` and
    ``filter_bias_and_bn`` is true, parameters are split with
    ``param_groups_weight_decay``. Names returned by
    ``model.no_weight_decay()``, if the model defines it, are exempted.
    The optimizer-level ``weight_decay`` is then set to 0, because each
    group carries its own value.
    """
    from . import lr

    config = copy.deepcopy(optim_config)
    # Remove this builder option unconditionally. If it were only popped in
    # the nn.Module branch, passing a parameter iterable would forward it to
    # the torch optimizer and raise TypeError.
    filter_bias_and_bn = config.pop('filter_bias_and_bn', False)

    if isinstance(model, nn.Module):
        weight_decay = config.get('weight_decay', 0.0)
        if weight_decay > 0.0 and filter_bias_and_bn:
            no_weight_decay = ()
            if hasattr(model, 'no_weight_decay'):
                no_weight_decay = model.no_weight_decay()
            parameters = param_groups_weight_decay(model, weight_decay,
                                                   no_weight_decay)
            config['weight_decay'] = 0.0
        else:
            parameters = model.parameters()
    else:
        # Iterable of parameters or param groups passed in.
        parameters = model

    optim_cls = getattr(torch.optim, config.pop('name'))
    optim = optim_cls(params=parameters, **config)

    lr_config = copy.deepcopy(lr_scheduler_config)
    lr_config.update({
        'epochs': epochs,
        'step_each_epoch': step_each_epoch,
        # Read the lr the optimizer actually uses, so configs that rely on
        # the optimizer's default lr do not raise KeyError here.
        'lr': optim.defaults['lr'],
    })
    scheduler_factory = getattr(lr, lr_config.pop('name'))(**lr_config)
    lr_scheduler = scheduler_factory(optimizer=optim)
    return optim, lr_scheduler
@@ -0,0 +1,227 @@
1
+ import math
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ from torch.optim import lr_scheduler
6
+
7
+
8
class StepLR(object):
    """LambdaLR factory: multiply the LR by ``gamma`` every ``step_size`` epochs.

    ``step_size`` and ``warmup_epoch`` are given in epochs and converted to
    the step counter that ``lambda_func`` receives, using ``step_each_epoch``.
    During warmup the factor ramps linearly from 0 towards 1.
    """

    def __init__(self,
                 step_each_epoch,
                 step_size,
                 warmup_epoch=0,
                 gamma=0.1,
                 last_epoch=-1,
                 **kwargs):
        super(StepLR, self).__init__()
        self.step_size = step_each_epoch * step_size
        self.gamma = gamma
        self.last_epoch = last_epoch
        # Convert warmup to steps as well. Comparing an epoch count against
        # the step counter in lambda_func would end warmup after
        # `warmup_epoch` steps instead of `warmup_epoch` epochs.
        self.warmup_epoch = warmup_epoch * step_each_epoch

    def __call__(self, optimizer):
        """Return a ``LambdaLR`` driving ``optimizer`` with this schedule."""
        return lr_scheduler.LambdaLR(optimizer, self.lambda_func,
                                     self.last_epoch)

    def lambda_func(self, current_step):
        """Return the LR multiplier for ``current_step``."""
        if current_step < self.warmup_epoch:
            return float(current_step) / float(max(1, self.warmup_epoch))
        return self.gamma**(current_step // self.step_size)
31
+
32
+
33
class MultiStepLR(object):
    """LambdaLR factory: multiply the LR by ``gamma`` at each milestone epoch.

    ``milestones`` and ``warmup_epoch`` are given in epochs and converted to
    the step counter that ``lambda_func`` receives, using ``step_each_epoch``.
    During warmup the factor ramps linearly from 0 towards 1.
    """

    def __init__(self,
                 step_each_epoch,
                 milestones,
                 warmup_epoch=0,
                 gamma=0.1,
                 last_epoch=-1,
                 **kwargs):
        super(MultiStepLR, self).__init__()
        self.milestones = [step_each_epoch * e for e in milestones]
        self.gamma = gamma
        self.last_epoch = last_epoch
        # Convert warmup to steps, the same unit the milestones use.
        self.warmup_epoch = warmup_epoch * step_each_epoch

    def __call__(self, optimizer):
        """Return a ``LambdaLR`` driving ``optimizer`` with this schedule."""
        return lr_scheduler.LambdaLR(optimizer, self.lambda_func,
                                     self.last_epoch)

    def lambda_func(self, current_step):
        """Return the LR multiplier for ``current_step``."""
        if current_step < self.warmup_epoch:
            return float(current_step) / float(max(1, self.warmup_epoch))
        passed = sum(1 for m in self.milestones if m <= current_step)
        return self.gamma**passed
57
+
58
+
59
class ConstLR(object):
    """LambdaLR factory: constant factor 1.0 after an optional linear warmup.

    ``warmup_epoch`` is converted to steps with ``step_each_epoch``.
    """

    def __init__(self,
                 step_each_epoch,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(ConstLR, self).__init__()
        self.last_epoch = last_epoch
        self.warmup_epoch = warmup_epoch * step_each_epoch

    def __call__(self, optimizer):
        """Return a ``LambdaLR`` driving ``optimizer`` with this schedule."""
        return lr_scheduler.LambdaLR(optimizer, self.lambda_func,
                                     self.last_epoch)

    def lambda_func(self, current_step):
        """Return the LR multiplier for ``current_step``."""
        warmup_steps = self.warmup_epoch
        if current_step >= warmup_steps:
            return 1.0
        return float(current_step) / float(max(1.0, warmup_steps))
78
+
79
+
80
class LinearLR(object):
    """LambdaLR factory: linear warmup, then linear decay to 0 at ``epochs``.

    ``epochs`` and ``warmup_epoch`` are converted to steps with
    ``step_each_epoch``.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(LinearLR, self).__init__()
        self.epochs = epochs * step_each_epoch
        self.last_epoch = last_epoch
        self.warmup_epoch = warmup_epoch * step_each_epoch

    def __call__(self, optimizer):
        """Return a ``LambdaLR`` driving ``optimizer`` with this schedule."""
        return lr_scheduler.LambdaLR(optimizer, self.lambda_func,
                                     self.last_epoch)

    def lambda_func(self, current_step):
        """Return the LR multiplier for ``current_step`` (never negative)."""
        if current_step >= self.warmup_epoch:
            remaining = float(self.epochs - current_step)
            span = float(max(1, self.epochs - self.warmup_epoch))
            return max(0.0, remaining / span)
        return float(current_step) / float(max(1, self.warmup_epoch))
105
+
106
+
107
class CosineAnnealingLR(object):
    """LambdaLR factory: linear warmup followed by cosine decay.

    ``epochs`` and ``warmup_epoch`` are converted to steps with
    ``step_each_epoch``. With the default ``num_cycles=0.5`` the factor falls
    from 1 to 0 over the post-warmup span.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(CosineAnnealingLR, self).__init__()
        self.epochs = epochs * step_each_epoch
        self.last_epoch = last_epoch
        self.warmup_epoch = warmup_epoch * step_each_epoch

    def __call__(self, optimizer):
        """Return a ``LambdaLR`` driving ``optimizer`` with this schedule."""
        return lr_scheduler.LambdaLR(optimizer, self.lambda_func,
                                     self.last_epoch)

    def lambda_func(self, current_step, num_cycles=0.5):
        """Return the LR multiplier for ``current_step`` (clamped at 0)."""
        warmup = self.warmup_epoch
        if current_step < warmup:
            return float(current_step) / float(max(1, warmup))
        progress = float(current_step - warmup) / float(
            max(1, self.epochs - warmup))
        cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine))
132
+
133
+
134
class OneCycleLR(object):
    """Factory for ``torch.optim.lr_scheduler.OneCycleLR`` sized in epochs.

    ``total_steps`` is ``epochs * step_each_epoch``. ``warmup_epoch`` becomes
    the ``pct_start`` fraction (``warmup_epoch / epochs``), and ``lr`` is the
    peak (``max_lr``) of the cycle.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 last_epoch=-1,
                 lr=0.00148,
                 warmup_epoch=1.0,
                 cycle_momentum=True,
                 **kwargs):
        super(OneCycleLR, self).__init__()
        self.epochs = epochs
        self.last_epoch = last_epoch
        self.step_each_epoch = step_each_epoch
        self.lr = lr
        self.pct_start = warmup_epoch / epochs
        self.cycle_momentum = cycle_momentum

    def __call__(self, optimizer):
        """Return a torch ``OneCycleLR`` bound to ``optimizer``."""
        # Forward last_epoch the same way the LambdaLR-based factories do, so
        # a resumed run continues the cycle instead of silently restarting.
        return lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.lr,
            total_steps=self.epochs * self.step_each_epoch,
            pct_start=self.pct_start,
            cycle_momentum=self.cycle_momentum,
            last_epoch=self.last_epoch,
        )
160
+
161
+
162
class PolynomialLR(object):
    """LambdaLR factory: linear warmup, then polynomial decay to ``lr_end``.

    ``epochs`` and ``warmup_epoch`` are converted to steps with
    ``step_each_epoch``. ``LambdaLR`` multiplies the returned factor by the
    optimizer's initial LR (``optimizer.defaults['lr']``, bound as
    ``lr_init``). The absolute target LR is therefore divided by ``lr_init``
    before it is returned.
    """

    def __init__(self,
                 step_each_epoch,
                 epochs,
                 lr_end=1e-7,
                 power=1.0,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(PolynomialLR, self).__init__()
        self.lr_end = lr_end
        self.power = power
        self.epochs = epochs * step_each_epoch
        self.warmup_epoch = warmup_epoch * step_each_epoch
        self.last_epoch = last_epoch

    def __call__(self, optimizer):
        """Return a ``LambdaLR`` driving ``optimizer`` with this schedule."""
        lr_lambda = partial(
            self.lambda_func,
            lr_init=optimizer.defaults['lr'],
        )
        return lr_scheduler.LambdaLR(optimizer, lr_lambda, self.last_epoch)

    def lambda_func(self, current_step, lr_init):
        """Return the LR multiplier for ``current_step`` given ``lr_init``."""
        if current_step < self.warmup_epoch:
            return float(current_step) / float(max(1, self.warmup_epoch))
        if current_step > self.epochs:
            return self.lr_end / lr_init  # LambdaLR multiplies by lr_init
        lr_range = lr_init - self.lr_end
        # Clamp like the sibling schedulers. When warmup spans the whole run
        # (epochs == warmup_epoch) the raw span is 0, which would raise
        # ZeroDivisionError at current_step == epochs.
        decay_steps = max(1, self.epochs - self.warmup_epoch)
        pct_remaining = 1 - (current_step - self.warmup_epoch) / decay_steps
        decay = lr_range * pct_remaining**self.power + self.lr_end
        return decay / lr_init  # LambdaLR multiplies by lr_init
198
+
199
+
200
class CdistNetLR(object):
    """LambdaLR factory for CDistNet: Noam-style ramp, then a fixed small LR.

    Before ``step2_epoch`` (converted to steps with ``step_each_epoch``) the
    factor is ``min(step ** -0.5, n_warmup_steps ** -1.5 * step)``. From then
    on the factor is ``step2_lr / init_lr``, which pins the effective LR to
    ``step2_lr`` (1e-5), assuming the optimizer's base LR equals ``lr``.
    """

    def __init__(self,
                 step_each_epoch,
                 lr=0.0442,
                 n_warmup_steps=10000,
                 step2_epoch=7,
                 last_epoch=-1,
                 **kwargs):
        super(CdistNetLR, self).__init__()
        self.last_epoch = last_epoch
        self.step2_epoch = step2_epoch * step_each_epoch
        self.n_current_steps = 0
        self.n_warmup_steps = n_warmup_steps
        self.init_lr = lr
        self.step2_lr = 0.00001

    def __call__(self, optimizer):
        """Return a ``LambdaLR`` driving ``optimizer`` with this schedule."""
        return lr_scheduler.LambdaLR(optimizer, self.lambda_func,
                                     self.last_epoch)

    def lambda_func(self, current_step):
        """Return the LR multiplier for ``current_step``."""
        if current_step < self.step2_epoch:
            if current_step <= 0:
                # step ** -0.5 diverges at 0, and LambdaLR evaluates step 0
                # on construction. The min() is the linear ramp term, i.e. 0,
                # so return it directly instead of computing an infinite
                # power (which emitted a divide-by-zero RuntimeWarning).
                return 0.0
            return min(current_step**-0.5,
                       self.n_warmup_steps**-1.5 * current_step)
        return self.step2_lr / self.init_lr
@@ -0,0 +1,72 @@
1
+ import copy
2
+
3
+ __all__ = ['build_post_process']
4
+
5
+ from .abinet_postprocess import ABINetLabelDecode
6
+ from .ar_postprocess import ARLabelDecode
7
+ from .ce_postprocess import CELabelDecode
8
+ from .char_postprocess import CharLabelDecode
9
+ from .cppd_postprocess import CPPDLabelDecode
10
+ from .ctc_postprocess import CTCLabelDecode
11
+ from .igtr_postprocess import IGTRLabelDecode
12
+ from .lister_postprocess import LISTERLabelDecode
13
+ from .mgp_postprocess import MPGLabelDecode
14
+ from .nrtr_postprocess import NRTRLabelDecode
15
+ from .smtr_postprocess import SMTRLabelDecode
16
+ from .srn_postprocess import SRNLabelDecode
17
+ from .visionlan_postprocess import VisionLANLabelDecode
18
+
19
support_dict = [
    'CTCLabelDecode', 'CharLabelDecode', 'CELabelDecode', 'CPPDLabelDecode',
    'NRTRLabelDecode', 'ABINetLabelDecode', 'ARLabelDecode', 'IGTRLabelDecode',
    'VisionLANLabelDecode', 'SMTRLabelDecode', 'SRNLabelDecode',
    'LISTERLabelDecode', 'GTCLabelDecode', 'MPGLabelDecode'
]


def build_post_process(config, global_config=None):
    """Instantiate a label-decode post-processor from a config dict.

    Args:
        config: dict whose ``name`` key is one of ``support_dict``; the other
            keys are passed to that class's constructor. Not mutated.
        global_config: optional dict merged over ``config`` (its keys win).

    Returns:
        The constructed post-process instance.

    Raises:
        ValueError: if ``name`` is not listed in ``support_dict``.
    """
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    if global_config is not None:
        config.update(global_config)
    # An explicit check rather than `assert`: asserts are stripped under
    # `python -O`, and this name comes from user-supplied config.
    if module_name not in support_dict:
        raise ValueError('post process only support {}'.format(support_dict))
    # Resolve the class by name from this module's namespace instead of
    # eval(), so the config string can never be executed as code.
    module_class = globals()[module_name](**config)
    return module_class
36
+
37
+
38
class GTCLabelDecode(object):
    """Decode outputs of a model with a GTC branch plus a CTC branch.

    ``gtc_label_decode`` is a post-process config dict (with a ``name`` key)
    built through ``build_post_process``. A ``CTCLabelDecode`` is always
    created alongside it with the same character-dictionary settings.
    """

    def __init__(self,
                 gtc_label_decode=None,
                 character_dict_path=None,
                 use_space_char=True,
                 only_gtc=False,
                 with_ratio=False,
                 **kwargs):
        if gtc_label_decode is None:
            raise TypeError(
                'GTCLabelDecode requires a `gtc_label_decode` config dict')
        # Work on a copy so the caller's config dict is not mutated.
        gtc_label_decode = dict(gtc_label_decode)
        gtc_label_decode['character_dict_path'] = character_dict_path
        gtc_label_decode['use_space_char'] = use_space_char
        self.gtc_label_decode = build_post_process(gtc_label_decode)
        self.ctc_label_decode = CTCLabelDecode(
            character_dict_path=character_dict_path,
            use_space_char=use_space_char)
        self.gtc_character = self.gtc_label_decode.character
        self.ctc_character = self.ctc_label_decode.character
        self.only_gtc = only_gtc
        self.with_ratio = with_ratio

    def get_character_num(self):
        """Return ``[len(gtc charset), len(ctc charset)]``."""
        return [len(self.gtc_character), len(self.ctc_character)]

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode ``preds['gtc_pred']`` and, unless ``only_gtc``, ``preds['ctc_pred']``.

        Returns the GTC result alone when ``only_gtc`` is set, otherwise
        ``[gtc, ctc]``. When ``batch`` is given, the GTC decoder receives
        ``batch[:-2]`` and the CTC decoder receives ``[None] + batch[-2:]``.
        """
        if batch is not None and self.with_ratio:
            # Drop the trailing batch entry; presumably the ratio field
            # added when with_ratio is enabled (TODO confirm with collate).
            batch = batch[:-1]
        gtc = self.gtc_label_decode(preds['gtc_pred'],
                                    batch[:-2] if batch is not None else None)
        if self.only_gtc:
            return gtc
        # list() so tuple batches work as well as list batches.
        ctc_batch = None if batch is None else [None] + list(batch[-2:])
        ctc = self.ctc_label_decode(preds['ctc_pred'], ctc_batch)
        return [gtc, ctc]
@@ -0,0 +1,37 @@
1
+ import torch
2
+
3
+ from .nrtr_postprocess import NRTRLabelDecode
4
+
5
+
6
class ABINetLabelDecode(NRTRLabelDecode):
    """Decode ABINet predictions (alignment or vision branch) into text."""

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(ABINetLabelDecode, self).__init__(character_dict_path,
                                                use_space_char)

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Return decoded text, or ``(text, label)`` when ``batch`` is given.

        ``preds`` may be a dict with ``'align'`` / ``'vision'`` entries, a
        tensor, or an array; argmax/max are taken over axis 2, so a 3-D array
        is presumably expected (TODO confirm: batch, seq, classes).
        """
        if isinstance(preds, dict):
            # Prefer the last alignment output when any exist, otherwise fall
            # back to the vision branch.
            chosen = (preds['align'][-1]
                      if len(preds['align']) > 0 else preds['vision'])
            preds = chosen.detach().cpu().numpy()
        elif isinstance(preds, torch.Tensor):
            preds = preds.detach().cpu().numpy()

        best_idx = preds.argmax(axis=2)
        best_prob = preds.max(axis=2)
        text = self.decode(best_idx, best_prob, is_remove_duplicate=False)
        if batch is None:
            return text
        return text, self.decode(batch[1].cpu().numpy())

    def add_special_char(self, dict_character):
        """Prepend the ``</s>`` end token to the character list."""
        return ['</s>'] + dict_character
@@ -0,0 +1,63 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from .ctc_postprocess import BaseRecLabelDecode
5
+
6
+
7
class ARLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for autoregressive outputs.

    ``add_special_char`` lays the character table out as
    ``['</s>'] + <dictionary characters> + ['<s>', '<pad>']``.
    """

    BOS = '<s>'
    EOS = '</s>'
    PAD = '<pad>'

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=True,
                 **kwargs):
        super(ARLabelDecode, self).__init__(character_dict_path,
                                            use_space_char)

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode predictions and, when ``batch`` is given, its labels.

        Args:
            preds: tensor/array whose ``argmax(axis=2)`` indexes
                ``self.character``, or a list whose last element is used.
            batch: optional batch; ``batch[1]`` is decoded as the label after
                dropping its first column.

        Returns:
            ``text`` if ``batch`` is None, else ``(text, label)``; each is a
            list of ``(string, confidence)`` tuples from ``decode``.
        """
        if isinstance(preds, list):
            preds = preds[-1]
        if isinstance(preds, torch.Tensor):
            preds = preds.detach().cpu().numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob)
        if batch is None:
            return text
        label = batch[1]
        # Column 0 is skipped -- presumably the BOS token of the target; confirm.
        label = self.decode(label[:, 1:].detach().cpu().numpy())
        return text, label

    def add_special_char(self, dict_character):
        """Wrap the dictionary with EOS in front and BOS/PAD at the end."""
        dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
        return dict_character

    def decode(self, text_index, text_prob=None):
        """Convert index sequences into ``(text, confidence)`` tuples.

        For each row, decoding stops at the first ``EOS``. ``BOS``/``PAD``
        tokens, and indices that cannot be looked up in ``self.character``,
        are skipped. The confidence is the mean of ``text_prob`` over the
        kept positions (1 per position when ``text_prob`` is None), or 0.0
        when no position was kept.
        """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                try:
                    char_idx = self.character[int(text_index[batch_idx][idx])]
                except (IndexError, KeyError, TypeError, ValueError):
                    # Out-of-table or non-integer index: skip this position.
                    # (Narrowed from a bare except, which also swallowed
                    # KeyboardInterrupt/SystemExit.)
                    continue
                if char_idx == self.EOS:  # end
                    break
                if char_idx == self.BOS or char_idx == self.PAD:
                    continue
                char_list.append(char_idx)
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            # np.mean([]) is NaN and emits a RuntimeWarning; an empty
            # decode (e.g. EOS first) reports 0.0 instead.
            conf = np.mean(conf_list).tolist() if conf_list else 0.0
            result_list.append((text, conf))
        return result_list
@@ -0,0 +1,43 @@
1
+ import torch
2
+
3
+ from .ctc_postprocess import BaseRecLabelDecode
4
+
5
+
6
class CELabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for single-token outputs.

    ``add_special_char`` adds no special tokens, so ``self.character`` holds
    only the dictionary characters.
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(CELabelDecode, self).__init__(character_dict_path,
                                            use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode predictions and, optionally, labels.

        Args:
            preds: tensor/array scored along its last axis, or a tuple/list
                whose last element is used.
            label: optional label indices; they are flattened before
                decoding.

        Returns:
            ``text`` when ``label`` is None, otherwise ``(text, label)``.
        """
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        if isinstance(preds, torch.Tensor):
            # detach/cpu first: a bare .numpy() raises on tensors that track
            # gradients or live on a GPU (matches the sibling decoders).
            preds = preds.detach().cpu().numpy()
        preds_idx = preds.argmax(axis=-1)
        preds_prob = preds.max(axis=-1)
        text = self.decode(preds_idx, preds_prob)
        if label is None:
            return text
        if isinstance(label, torch.Tensor):
            label = label.detach().cpu().numpy()
        label = self.decode(label.flatten())
        return text, label

    def decode(self, text_index, text_prob=None):
        """Map each index to one character.

        Returns:
            A list of ``(character, confidence)`` tuples. The confidence is
            ``text_prob[i]`` when given, otherwise 1.0.
        """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            text = self.character[text_index[batch_idx]]
            if text_prob is not None:
                conf_list = text_prob[batch_idx]
            else:
                conf_list = 1.0
            result_list.append((text, conf_list))
        return result_list

    def add_special_char(self, dict_character):
        """No special tokens are added for this decoder."""
        return dict_character
@@ -0,0 +1,108 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from .ctc_postprocess import BaseRecLabelDecode
5
+
6
+
7
class CharLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index, also decoding a second
    per-position character ("box") prediction alongside the main sequence.

    ``add_special_char`` lays the table out as
    ``['blank', '<unk>', '<s>', '</s>'] + <dictionary characters>``.
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=True,
                 **kwargs):
        super(CharLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as NumPy data; tensors are detached and moved to CPU."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return x

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode main and char-box predictions, and optionally labels.

        Args:
            preds: either a sequence of length >= 4 whose first three items
                are (presumably) token ids, their probabilities, and
                char-box scores, or a shorter sequence of
                ``(logits, char-box scores)``. Tensors or NumPy arrays are
                accepted.
            label: optional label indices; the first column is dropped
                before decoding.

        Returns:
            ``(text, text_box)`` when ``label`` is None, otherwise
            ``(text, text_box, label)``.
        """
        if len(preds) >= 4:
            preds_id = self._to_numpy(preds[0])
            preds_prob = self._to_numpy(preds[1])
            char_preds = self._to_numpy(preds[2])
            if preds_id[0][0] == 2:
                # Index 2 is '<s>' in this table; drop that leading column.
                # NOTE(review): char_preds is not trimmed to match, so
                # char-box position k lines up with preds_id column k, not
                # k + 1 -- confirm this offset is intended.
                preds_idx = preds_id[:, 1:]
                preds_prob = preds_prob[:, 1:]
            else:
                preds_idx = preds_id
        else:
            preds_logit = self._to_numpy(preds[0])
            char_preds = self._to_numpy(preds[1])
            preds_idx = preds_logit.argmax(axis=2)
            preds_prob = preds_logit.max(axis=2)

        # Presumably char_preds scores dictionary characters only; +4 maps
        # them past the four special tokens at the front of the table.
        char_preds_idx = char_preds.argmax(-1) + 4
        char_preds_prob = char_preds.max(-1)
        text, text_box = self.decode(preds_idx, preds_prob, char_preds_idx,
                                     char_preds_prob)

        if label is None:
            return text, text_box
        label = self.decode(label[:, 1:])
        return text, text_box, label

    def add_special_char(self, dict_character):
        """Prepend the blank, unknown, BOS and EOS tokens."""
        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
        return dict_character

    def decode(
        self,
        text_index,
        text_prob=None,
        char_text_index=None,
        char_text_prob=None,
        is_remove_duplicate=False,
    ):
        """Convert index sequences into ``(text, confidence)`` tuples.

        A row stops at the first ``'</s>'``. Positions whose index (or
        char-box index) cannot be looked up are skipped. The confidence is
        the mean per-position probability (1 per position when no
        probabilities are given), or 0.0 for an empty row.
        ``is_remove_duplicate`` is accepted but unused.

        Returns:
            ``(result_list, box_result_list)`` when ``char_text_index`` is
            given, otherwise ``result_list``.
        """
        result_list = []
        box_result_list = []
        with_box = char_text_index is not None
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            char_box_list = []
            conf_box_list = []
            for idx in range(len(text_index[batch_idx])):
                try:
                    char_idx = self.character[int(text_index[batch_idx][idx])]
                    if with_box:
                        char_box_idx = self.character[int(
                            char_text_index[batch_idx][idx])]
                except (IndexError, KeyError, TypeError, ValueError):
                    # Narrowed from a bare except, which also swallowed
                    # KeyboardInterrupt/SystemExit.
                    continue
                if char_idx == '</s>':  # end
                    break
                char_list.append(char_idx)

                if with_box:
                    char_box_list.append(char_box_idx)

                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

                if char_text_prob is not None:
                    conf_box_list.append(char_text_prob[batch_idx][idx])
                else:
                    conf_box_list.append(1)
            # np.mean([]) is NaN plus a RuntimeWarning; empty rows report 0.0.
            text = ''.join(char_list)
            conf = np.mean(conf_list).tolist() if conf_list else 0.0
            result_list.append((text, conf))

            if with_box:
                text_box = ''.join(char_box_list)
                box_conf = (np.mean(conf_box_list).tolist()
                            if conf_box_list else 0.0)
                box_result_list.append((text_box, box_conf))
        if with_box:
            return result_list, box_result_list
        return result_list
@@ -0,0 +1,42 @@
1
+ import torch
2
+
3
+ from .nrtr_postprocess import NRTRLabelDecode
4
+
5
+
6
class CPPDLabelDecode(NRTRLabelDecode):
    """Convert between text-label and text-index for CPPD outputs.

    ``add_special_char`` prepends a single ``'</s>'`` token to the dictionary
    characters.
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super().__init__(character_dict_path, use_space_char)

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode predictions and, when ``batch`` is given, its labels.

        Args:
            preds: one of
                - a tuple: its last element is used, taking
                  ``['align'][-1]`` of it when that element is a dict;
                - a list: its last element is used;
                - a ``torch.Tensor``;
                - a dict: ``preds['align'][-1]`` is used;
                - anything else, which is assumed to already be an array.
            batch: optional batch; ``batch[1]`` is decoded as the label.

        Returns:
            ``text`` when ``batch`` is None, otherwise ``(text, label)``.
        """
        if isinstance(preds, tuple):
            tail = preds[-1]
            chosen = tail['align'][-1] if isinstance(tail, dict) else tail
            preds = chosen.detach().cpu().numpy()
        elif isinstance(preds, list):
            preds = preds[-1].detach().cpu().numpy()
        elif isinstance(preds, torch.Tensor):
            preds = preds.detach().cpu().numpy()
        elif isinstance(preds, dict):
            preds = preds['align'][-1].detach().cpu().numpy()

        best_idx = preds.argmax(axis=2)
        best_prob = preds.max(axis=2)
        text = self.decode(best_idx, best_prob, is_remove_duplicate=False)
        if batch is None:
            return text
        gt = batch[1].detach().cpu().numpy()
        return text, self.decode(gt)

    def add_special_char(self, dict_character):
        """Return the character list with ``'</s>'`` placed at index 0."""
        return ['</s>'] + dict_character