openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,119 @@
1
+ import re
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
class BaseRecLabelDecode(object):
    """Convert between text-label and text-index.

    Base decoder for recognition heads. It builds the character table
    (``self.character``) and the char -> index mapping (``self.dict``), and
    turns index sequences back into ``(text, mean_confidence)`` pairs.

    Args:
        character_dict_path: Path to a UTF-8 dictionary file with one
            character per line. When ``None`` the 36-symbol table
            ``0-9a-z`` is used.
        use_space_char: Append ``' '`` to the table. Only honoured when a
            dictionary file is given.
    """

    # Characters treated as one left-to-right run by ``pred_reverse``
    # (Latin letters, digits and a few symbols). Compiled once instead of
    # per character.
    _LTR_CHARS = re.compile('[a-zA-Z0-9 :*./%+-]')

    def __init__(self, character_dict_path=None, use_space_char=False):
        self.beg_str = 'sos'
        self.end_str = 'eos'
        # Set for dictionaries whose path mentions 'arabic'; decode() then
        # reorders the text with pred_reverse().
        self.reverse = False
        self.character_str = []

        if character_dict_path is None:
            self.character_str = '0123456789abcdefghijklmnopqrstuvwxyz'
            dict_character = list(self.character_str)
        else:
            with open(character_dict_path, 'rb') as fin:
                lines = fin.readlines()
                for line in lines:
                    line = line.decode('utf-8').strip('\n').strip('\r\n')
                    self.character_str.append(line)
            if use_space_char:
                self.character_str.append(' ')
            dict_character = list(self.character_str)
            if 'arabic' in character_dict_path:
                self.reverse = True

        # Subclasses insert their special tokens (blank, sos/eos, pad ...).
        dict_character = self.add_special_char(dict_character)
        # Later duplicates win, exactly like the original assignment loop.
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def pred_reverse(self, pred):
        """Reverse a right-to-left prediction, keeping LTR runs in order.

        Runs of characters matching ``_LTR_CHARS`` are kept as single
        chunks; every other character is its own chunk. The chunk order is
        then reversed.
        """
        pred_re = []
        c_current = ''
        for c in pred:
            if not self._LTR_CHARS.search(c):
                if c_current != '':
                    pred_re.append(c_current)
                pred_re.append(c)
                c_current = ''
            else:
                c_current += c
        if c_current != '':
            pred_re.append(c_current)

        return ''.join(pred_re[::-1])

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special tokens; identity here."""
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert text-index into text-label.

        Args:
            text_index: Per-sample index sequences. Indexed with boolean
                masks, so numpy arrays are expected.
            text_prob: Optional per-step probabilities aligned with
                ``text_index``.
            is_remove_duplicate: Collapse consecutive repeated indices
                (CTC style) before dropping ignored tokens.

        Returns:
            A list of ``(text, mean_confidence)`` tuples, one per sample.
            An empty decode gets confidence 0, whether or not ``text_prob``
            is given.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for batch_idx in range(len(text_index)):
            seq = text_index[batch_idx]
            selection = np.ones(len(seq), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = seq[1:] != seq[:-1]
            for ignored_token in ignored_tokens:
                selection &= seq != ignored_token

            char_list = [self.character[text_id] for text_id in seq[selection]]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                # One placeholder confidence per emitted character (not per
                # time step), so an empty decode falls through to 0 below.
                conf_list = [1] * len(char_list)
            if len(conf_list) == 0:
                conf_list = [0]

            text = ''.join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        """Indices removed by ``decode``."""
        return [0]  # for ctc blank

    def get_character_num(self):
        """Size of the character table, including special tokens."""
        return len(self.character)
93
+
94
+
95
class CTCLabelDecode(BaseRecLabelDecode):
    """Greedy (best-path) decoder for CTC recognition heads.

    The character table is prefixed with a ``'blank'`` entry at index 0.
    Decoding collapses consecutive repeated indices and then drops the
    ignored tokens reported by ``get_ignored_tokens``.
    """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super().__init__(character_dict_path, use_space_char)

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode scores, and optionally the ground-truth labels.

        ``preds`` is reduced over axis 2; presumably it is laid out as
        (batch, time, classes). TODO confirm against the head.
        Returns the decoded predictions, or ``(predictions, labels)`` when
        ``batch`` is given (labels are read from ``batch[1]``).
        """
        if isinstance(preds, torch.Tensor):
            preds = preds.detach().cpu().numpy()
        best_idx = preds.argmax(axis=2)
        best_prob = preds.max(axis=2)
        decoded = self.decode(best_idx, best_prob, is_remove_duplicate=True)
        if batch is None:
            return decoded
        ground_truth = self.decode(batch[1].cpu().numpy())
        return decoded, ground_truth

    def add_special_char(self, dict_character):
        """Put the CTC ``'blank'`` token in front of the dictionary."""
        with_blank = ['blank'] + dict_character
        return with_blank
@@ -0,0 +1,100 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from .nrtr_postprocess import NRTRLabelDecode
5
+
6
+
7
class IGTRLabelDecode(NRTRLabelDecode):
    """Decode IGTR recognizer outputs into ``(text, confidence)`` pairs.

    The character table (see ``add_special_char``) is ``'</s>'`` at index 0,
    then the dictionary characters, then ``'<s>'`` and ``'<pad>'``.
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(IGTRLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode predictions, and optionally the labels in ``batch[1]``.

        Accepted ``preds`` forms:
          * a list whose first item is a dict: ``preds[-1]`` is taken as the
            score tensor and decoded greedily (argmax over axis 2);
          * any other list: ``preds[0]`` holds token ids and ``preds[1]``
            their probabilities;
          * a tensor, an ndarray, or a dict (``preds['align'][-1]`` is used):
            decoded greedily, and the result also carries the top-5
            candidate characters per position.
        """
        if isinstance(preds, list):
            if isinstance(preds[0], dict):
                scores = preds[-1].detach().cpu().numpy()
                preds_idx = scores.argmax(axis=2)
                preds_prob = scores.max(axis=2)
            else:
                preds_idx = preds[0].detach().cpu().numpy()
                preds_prob = preds[1].detach().cpu().numpy()
            text = self.decode(preds_idx,
                               preds_prob,
                               is_remove_duplicate=False)
        else:
            if isinstance(preds, torch.Tensor):
                preds = preds.detach().cpu().numpy()
            elif isinstance(preds, dict):
                preds = preds['align'][-1].detach().cpu().numpy()
            preds_idx = preds.argmax(axis=2)
            # argsort is ascending: the last 5 columns are the 5 best
            # classes, ordered from lowest to highest score.
            preds_idx_top5 = preds.argsort(axis=2)[:, :, -5:]
            preds_prob = preds.max(axis=2)
            text = self.decode(preds_idx,
                               preds_prob,
                               is_remove_duplicate=False,
                               idx_top5=preds_idx_top5)
        if batch is None:
            return text
        label = batch[1]
        label = self.decode(label.detach().cpu().numpy())
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ['</s>'] + dict_character + ['<s>', '<pad>']
        return dict_character

    def decode(self,
               text_index,
               text_prob=None,
               is_remove_duplicate=False,
               idx_top5=None):
        """Convert text-index into text-label.

        Decoding stops at the first ``'</s>'``. ``'<s>'``/``'<pad>'`` and
        positions whose index cannot be looked up are skipped.

        Returns:
            One entry per sample. The entry is ``(text, mean_conf)``, or
            ``(text, [mean_conf, top5_chars])`` when ``idx_top5`` is given.
            ``mean_conf`` is 0 when nothing was decoded.
        """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            char_list_top5 = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                char_idx_top5 = []
                try:
                    char_idx = self.character[int(text_index[batch_idx][idx])]
                    if idx_top5 is not None:
                        for top5_i in idx_top5[batch_idx][idx]:
                            char_idx_top5.append(self.character[top5_i])
                except (IndexError, TypeError, ValueError, OverflowError):
                    # Best effort: skip positions whose index is outside the
                    # table or not an integer (previously a bare except).
                    continue
                if char_idx == '</s>':  # end
                    break
                if char_idx == '<s>' or char_idx == '<pad>':
                    continue
                char_list.append(char_idx)
                char_list_top5.append(char_idx_top5)
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if not conf_list:
                # Nothing decoded: report 0 (as BaseRecLabelDecode does)
                # instead of np.mean([]) -> nan.
                conf_list = [0]
            text = ''.join(char_list)
            if idx_top5 is not None:
                result_list.append(
                    (text, [np.mean(conf_list).tolist(), char_list_top5]))
            else:
                result_list.append((text, np.mean(conf_list).tolist()))
        return result_list
@@ -0,0 +1,59 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from openrec.postprocess.ctc_postprocess import BaseRecLabelDecode
5
+
6
+
7
class LISTERLabelDecode(BaseRecLabelDecode):
    """Decode LISTER recognizer outputs into ``(text, confidence)`` pairs.

    The character table (see ``add_special_char``) is ``'</s>'`` at index 0,
    then the dictionary characters, then ``'<pad>'``.
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=True,
                 **kwargs):
        super(LISTERLabelDecode, self).__init__(character_dict_path,
                                                use_space_char)

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Greedily decode ``preds[1]['logits']``, plus ``batch[1]`` labels.

        Assumes ``preds[1]`` is a dict holding a ``'logits'`` score tensor
        reduced over axis 2. TODO confirm the layout against the decoder.
        """
        preds = preds[1]['logits']
        if isinstance(preds, torch.Tensor):
            preds = preds.detach().cpu().numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if batch is None:
            return text
        label = batch[1]
        label = self.decode(label.detach().cpu().numpy())
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ['</s>'] + dict_character + ['<pad>']
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert text-index into text-label.

        Decoding stops at the first ``'</s>'``. ``'<s>'``/``'<pad>'`` and
        positions whose index cannot be looked up are skipped. Returns
        ``(text, mean_conf)`` per sample; ``mean_conf`` is 0 when nothing
        was decoded.
        """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                try:
                    char_idx = self.character[int(text_index[batch_idx][idx])]
                except (IndexError, TypeError, ValueError, OverflowError):
                    # Best effort: skip positions whose index is outside the
                    # table or not an integer (previously a bare except).
                    continue
                if char_idx == '</s>':  # end
                    break
                if char_idx == '<s>' or char_idx == '<pad>':
                    continue
                char_list.append(char_idx)
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if not conf_list:
                # Nothing decoded: report 0 (as BaseRecLabelDecode does)
                # instead of np.mean([]) -> nan.
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list
@@ -0,0 +1,143 @@
1
+ from .ctc_postprocess import BaseRecLabelDecode
2
+
3
+
4
class MPGLabelDecode(BaseRecLabelDecode):
    """Decode MGP-STR outputs (character, BPE and WordPiece heads).

    The character table is ``['[GO]', '[s]'] + dictionary``. In character
    decoding ``'[s]'`` ends a sequence and ``'[GO]'`` is skipped as padding.
    When ``only_char`` is False, GPT-2 (BPE) and BERT (WordPiece) tokenizers
    are loaded from ``transformers`` to decode the two subword heads.
    """
    SPACE = '[s]'
    GO = '[GO]'
    list_token = [GO, SPACE]

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 only_char=False,
                 **kwargs):
        super(MPGLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)
        self.only_char = only_char
        self.EOS = '[s]'
        self.PAD = '[GO]'
        if not only_char:
            # transformers==4.2.1
            from transformers import BertTokenizer, GPT2Tokenizer
            self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            self.wp_tokenizer = BertTokenizer.from_pretrained(
                'bert-base-uncased')

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode the heads' outputs.

        ``preds`` is either the character-head tensor or a list
        ``[char, bpe, wp]``. The first time step of every head is dropped
        before decoding.

        Returns:
            * ``char_text`` when ``batch`` is None;
            * ``(char_text, label)`` when ``only_char`` is set;
            * otherwise ``(char_text, bpe_text, wp_text, final_text, label)``.
        """
        if isinstance(preds, list):
            char_preds = preds[0].detach().cpu().numpy()
        else:
            char_preds = preds.detach().cpu().numpy()

        preds_idx = char_preds.argmax(axis=2)
        preds_prob = char_preds.max(axis=2)
        char_text = self.char_decode(preds_idx[:, 1:], preds_prob[:, 1:])
        if batch is None:
            return char_text
        label = batch[1]
        label = self.char_decode(label[:, 1:].detach().cpu().numpy())
        if self.only_char:
            return char_text, label
        else:
            bpe_preds = preds[1].detach().cpu().numpy()
            wp_preds = preds[2]

            bpe_preds_idx = bpe_preds.argmax(axis=2)
            bpe_preds_prob = bpe_preds.max(axis=2)
            bpe_text = self.bpe_decode(bpe_preds_idx[:, 1:],
                                       bpe_preds_prob[:, 1:])

            # Kept as a torch tensor: wp_decode uses torch ops on it.
            wp_preds = wp_preds.detach()
            wp_preds_prob, wp_preds_idx = wp_preds.max(-1)
            wp_text = self.wp_decode(wp_preds_idx[:, 1:], wp_preds_prob[:, 1:])

            final_text = self.final_decode(char_text, bpe_text, wp_text)
            return char_text, bpe_text, wp_text, final_text, label

    def add_special_char(self, dict_character):
        dict_character = self.list_token + dict_character
        return dict_character

    def final_decode(self, char_text, bpe_text, wp_text):
        """Pick, per sample, the head prediction with the highest confidence.

        Ties keep the earlier head (char, then BPE), since only a strictly
        greater confidence replaces the current choice.
        """
        result_list = []
        for (char_pred, char_pred_conf), (bpe_pred, bpe_pred_conf), (
                wp_pred, wp_pred_conf) in zip(char_text, bpe_text, wp_text):
            final_text = char_pred
            final_prob = char_pred_conf
            if bpe_pred_conf > final_prob:
                final_text = bpe_pred
                final_prob = bpe_pred_conf
            if wp_pred_conf > final_prob:
                final_text = wp_pred
                final_prob = wp_pred_conf
            result_list.append((final_text, final_prob))
        return result_list

    def char_decode(self, text_index, text_prob=None):
        """Convert character-head indices into ``(text, confidence)``.

        The confidence is the product of the per-step probabilities up to
        and including the ``'[s]'`` step, ``'[GO]'`` steps included.
        """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = 1.0
            for idx in range(len(text_index[batch_idx])):
                try:
                    char_idx = self.character[int(text_index[batch_idx][idx])]
                except (IndexError, TypeError, ValueError, OverflowError):
                    # Best effort: skip positions whose index is outside the
                    # table or not an integer (previously a bare except).
                    continue
                if text_prob is not None:
                    conf_list *= text_prob[batch_idx][idx]

                if char_idx == self.EOS:  # end
                    break
                if char_idx == self.PAD:
                    continue
                char_list.append(char_idx)

            text = ''.join(char_list)
            result_list.append((text, conf_list))
        return result_list

    def bpe_decode(self, text_index, text_prob):
        """Convert BPE-head indices into ``(text, confidence)``.

        Decoding stops at the first token that decodes to ``'#'``. The
        confidence is the product of the probabilities of the tokens kept.
        """
        result_list = []
        for text, probs in zip(text_index, text_prob):
            text_decoded = []
            conf_list = 1.0
            for bpeindx, prob in zip(text, probs):
                tokenstr = self.bpe_tokenizer.decode([bpeindx])
                if tokenstr == '#':
                    break
                text_decoded.append(tokenstr)
                conf_list *= prob
            text = ''.join(text_decoded)
            result_list.append((text, conf_list))
        return result_list

    def wp_decode(self, text_index, text_prob=None):
        """Convert WordPiece-head indices into ``(text, confidence)``.

        The text is cut at the first ``'[SEP]'``. If no ``'[SEP]'`` is
        present, the full decoded string and every probability are kept.
        Previously ``str.find`` returning -1 silently dropped the last
        character and the last probability. Expects torch tensors, as
        produced by ``__call__``.
        """
        result_list = []
        for batch_idx, text in enumerate(text_index):
            wp_pred = self.wp_tokenizer.decode(text)
            wp_pred_EOS = wp_pred.find('[SEP]')
            if wp_pred_EOS != -1:
                wp_pred = wp_pred[:wp_pred_EOS]
            if text_prob is not None:
                try:
                    # 102 is assumed to be the '[SEP]' id of the
                    # bert-base-uncased vocabulary; keep probs through it.
                    wp_pred_EOS_index = text.cpu().tolist().index(102) + 1
                except ValueError:
                    # No '[SEP]' predicted: score every position.
                    wp_pred_EOS_index = None
                wp_pred_max_prob = text_prob[batch_idx][:wp_pred_EOS_index]
                try:
                    wp_confidence_score = wp_pred_max_prob.cumprod(
                        dim=0)[-1].cpu().numpy().sum()
                except IndexError:
                    # Empty probability row: no confidence.
                    wp_confidence_score = 0.0
            else:
                wp_confidence_score = 1.0
            result_list.append((wp_pred, wp_confidence_score))
        return result_list
@@ -0,0 +1,75 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from .ctc_postprocess import BaseRecLabelDecode
5
+
6
+
7
+ class NRTRLabelDecode(BaseRecLabelDecode):
8
+ """Convert between text-label and text-index."""
9
+
10
+ def __init__(self,
11
+ character_dict_path=None,
12
+ use_space_char=True,
13
+ **kwargs):
14
+ super(NRTRLabelDecode, self).__init__(character_dict_path,
15
+ use_space_char)
16
+
17
+ def __call__(self, preds, batch=None, *args, **kwargs):
18
+ preds = preds['res']
19
+ if len(preds) == 2:
20
+ preds_id = preds[0]
21
+ preds_prob = preds[1]
22
+ if isinstance(preds_id, torch.Tensor):
23
+ preds_id = preds_id.detach().cpu().numpy()
24
+ if isinstance(preds_prob, torch.Tensor):
25
+ preds_prob = preds_prob.detach().cpu().numpy()
26
+ if preds_id[0][0] == 2:
27
+ preds_idx = preds_id[:, 1:]
28
+ preds_prob = preds_prob[:, 1:]
29
+ else:
30
+ preds_idx = preds_id
31
+ text = self.decode(preds_idx,
32
+ preds_prob,
33
+ is_remove_duplicate=False)
34
+ if batch is None:
35
+ return text
36
+ label = self.decode(batch[1][:, 1:].cpu().numpy())
37
+ else:
38
+ if isinstance(preds, torch.Tensor):
39
+ preds = preds.detach().cpu().numpy()
40
+ preds_idx = preds.argmax(axis=2)
41
+ preds_prob = preds.max(axis=2)
42
+ text = self.decode(preds_idx,
43
+ preds_prob,
44
+ is_remove_duplicate=False)
45
+ if batch is None:
46
+ return text
47
+ label = self.decode(batch[1][:, 1:].cpu().numpy())
48
+ return text, label
49
+
50
+ def add_special_char(self, dict_character):
51
+ dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
52
+ return dict_character
53
+
54
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
55
+ """convert text-index into text-label."""
56
+ result_list = []
57
+ batch_size = len(text_index)
58
+ for batch_idx in range(batch_size):
59
+ char_list = []
60
+ conf_list = []
61
+ for idx in range(len(text_index[batch_idx])):
62
+ try:
63
+ char_idx = self.character[int(text_index[batch_idx][idx])]
64
+ except:
65
+ continue
66
+ if char_idx == '</s>': # end
67
+ break
68
+ char_list.append(char_idx)
69
+ if text_prob is not None:
70
+ conf_list.append(text_prob[batch_idx][idx])
71
+ else:
72
+ conf_list.append(1)
73
+ text = ''.join(char_list)
74
+ result_list.append((text, np.mean(conf_list).tolist()))
75
+ return result_list
@@ -0,0 +1,73 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from .ctc_postprocess import BaseRecLabelDecode
5
+
6
+
7
+ class SMTRLabelDecode(BaseRecLabelDecode):
8
+ """Convert between text-label and text-index."""
9
+
10
+ BOS = '<s>'
11
+ EOS = '</s>'
12
+ IN_F = '<INF>' # ignore
13
+ IN_B = '<INB>' # ignore
14
+ PAD = '<pad>'
15
+
16
+ def __init__(self,
17
+ character_dict_path=None,
18
+ use_space_char=True,
19
+ next_mode=True,
20
+ **kwargs):
21
+ super(SMTRLabelDecode, self).__init__(character_dict_path,
22
+ use_space_char)
23
+ self.next_mode = next_mode
24
+
25
+ def __call__(self, preds, batch=None, *args, **kwargs):
26
+ if isinstance(preds, list):
27
+ preds = preds[-1]
28
+ if isinstance(preds, torch.Tensor):
29
+ preds = preds.detach().cpu().numpy()
30
+ preds_idx = preds.argmax(axis=2)
31
+ preds_prob = preds.max(axis=2)
32
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
33
+ if batch is None:
34
+ return text
35
+ label = batch[1]
36
+ label = self.decode(label[:, 1:].detach().cpu().numpy())
37
+ return text, label
38
+
39
+ def add_special_char(self, dict_character):
40
+ dict_character = [self.EOS] + dict_character + [
41
+ self.BOS, self.IN_F, self.IN_B, self.PAD
42
+ ]
43
+ self.num_character = len(dict_character)
44
+ return dict_character
45
+
46
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
47
+ """convert text-index into text-label."""
48
+ result_list = []
49
+ batch_size = len(text_index)
50
+ for batch_idx in range(batch_size):
51
+ char_list = []
52
+ conf_list = []
53
+ for idx in range(len(text_index[batch_idx])):
54
+ try:
55
+ char_idx = self.character[int(text_index[batch_idx][idx])]
56
+ except:
57
+ continue
58
+ if char_idx == '</s>': # end
59
+ break
60
+ if char_idx == '<s>' or char_idx == '<pad>':
61
+ continue
62
+ char_list.append(char_idx)
63
+
64
+ if text_prob is not None:
65
+ conf_list.append(text_prob[batch_idx][idx])
66
+ else:
67
+ conf_list.append(1)
68
+ if self.next_mode or text_prob is None:
69
+ text = ''.join(char_list)
70
+ else:
71
+ text = ''.join(char_list[::-1])
72
+ result_list.append((text, np.mean(conf_list).tolist()))
73
+ return result_list
@@ -0,0 +1,80 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from .ctc_postprocess import BaseRecLabelDecode
5
+
6
+
7
+ class SRNLabelDecode(BaseRecLabelDecode):
8
+ """Convert between text-label and text-index."""
9
+
10
+ def __init__(self,
11
+ character_dict_path=None,
12
+ use_space_char=False,
13
+ **kwargs):
14
+ super(SRNLabelDecode, self).__init__(character_dict_path,
15
+ use_space_char)
16
+ self.max_len = 25
17
+
18
+ def add_special_char(self, dict_character):
19
+ dict_character = dict_character + ['<BOS>', '<EOS>']
20
+ self.start_idx = len(dict_character) - 2
21
+ self.end_idx = len(dict_character) - 1
22
+ return dict_character
23
+
24
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
25
+ """convert text-index into text-label."""
26
+ result_list = []
27
+ ignored_tokens = self.get_ignored_tokens()
28
+ # [B,25]
29
+ batch_size = len(text_index)
30
+ for batch_idx in range(batch_size):
31
+ char_list = []
32
+ conf_list = []
33
+ for idx in range(len(text_index[batch_idx])):
34
+ # print(f"text_index[{batch_idx}][{idx}]:{text_index[batch_idx][idx]}")
35
+ if text_index[batch_idx][idx] in ignored_tokens:
36
+ continue
37
+ if int(text_index[batch_idx][idx]) == int(self.end_idx):
38
+ if text_prob is None and idx == 0:
39
+ continue
40
+ else:
41
+ break
42
+ if is_remove_duplicate:
43
+ # only for predict
44
+ if idx > 0 and text_index[batch_idx][
45
+ idx - 1] == text_index[batch_idx][idx]:
46
+ continue
47
+ char_list.append(self.character[int(
48
+ text_index[batch_idx][idx])])
49
+ if text_prob is not None:
50
+ conf_list.append(text_prob[batch_idx][idx])
51
+ else:
52
+ conf_list.append(1)
53
+ text = ''.join(char_list)
54
+ result_list.append((text, np.mean(conf_list).tolist()))
55
+ return result_list
56
+
57
+ def __call__(self, preds, batch=None, *args, **kwargs):
58
+
59
+ if isinstance(preds, torch.Tensor):
60
+ preds = preds.reshape([-1, self.max_len, preds.shape[-1]])
61
+ preds = preds.detach().cpu().numpy()
62
+ else:
63
+ preds = preds[-1]
64
+ preds = preds.reshape([-1, self.max_len,
65
+ preds.shape[-1]]).detach().cpu().numpy()
66
+
67
+ preds_idx = preds.argmax(axis=2)
68
+ preds_prob = preds.max(axis=2)
69
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
70
+
71
+ if batch is None:
72
+ return text
73
+
74
+ label = batch[1].cpu().numpy()
75
+ # print(f"label.shape:{label.shape}")
76
+ label = self.decode(label, is_remove_duplicate=False)
77
+ return text, label
78
+
79
+ def get_ignored_tokens(self):
80
+ return [self.start_idx, self.end_idx]