openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,802 @@
1
+ import json
2
+ import numpy as np
3
+ import scipy.io as io
4
+
5
+ from tools.utils.utility import check_install
6
+
7
+ from tools.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
8
+
9
+
10
def get_socre_A(gt_dir, pred_dict):
    """Build the sigma/tau overlap tables for one image from in-memory labels.

    Args:
        gt_dir: Indexable collection of ground-truth entries (positions
            ``0 .. len-1``). Each entry provides ``["points"]`` (an object
            whose ``.tolist()`` yields ``[x, y]`` pairs) and ``["text"]``.
            An empty ``text`` marks the region as "don't care".
        pred_dict: Indexable collection of predictions (positions
            ``0 .. len-1``). Each entry provides ``["points"]`` (an object
            supporting ``.reshape(-1)``) and ``["texts"]`` (a string).

    Returns:
        dict with keys ``"sigma"`` and ``"global_tau"`` (arrays of shape
        ``(num_gt, num_det)`` after don't-care filtering) and
        ``"global_pred_str"`` / ``"global_gt_str"`` (dicts mapping detection /
        ground-truth index to its text; both stay empty unless at least one
        ground truth and one detection remain).
    """

    def input_reading_mod(pred_dict):
        """Flatten every prediction into ``["x0,y0,x1,y1,...", text]``."""
        det = []
        # Positional indexing is kept on purpose so that int-keyed mappings
        # behave exactly like lists.
        for i in range(len(pred_dict)):
            points = pred_dict[i]["points"]
            text = pred_dict[i]["texts"]
            det.append([",".join(map(str, points.reshape(-1))), text])
        return det

    def gt_reading_mod(gt_dict):
        """Convert ground-truth entries into the 6-slot record layout.

        Slots: ``["x:", xs, "y:", ys, transcription, flag]`` where ``flag`` is
        ``"#"`` for don't-care regions and ``"c"`` otherwise.
        """
        gt = []
        for i in range(len(gt_dict)):
            points = gt_dict[i]["points"].tolist()
            text = gt_dict[i]["text"]
            xs = [point[0] for point in points]
            ys = [point[1] for point in points]
            if text != "":
                transcription = np.array(
                    [text], dtype="U{}".format(len(text)))
                flag = np.array(["c"], dtype="<U1")
            else:
                transcription = np.array(["#"], dtype="<U1")
                flag = np.array(["#"], dtype="<U1")
            gt.append([
                np.array(["x:"], dtype="<U2"),
                np.array([xs], dtype="int16"),
                np.array(["y:"], dtype="<U2"),
                np.array([ys], dtype="int16"),
                transcription,
                flag,
            ])
        return gt

    def parse_detection(detection):
        """Split ``"x0,y0,x1,y1,..."`` into integer x and y lists."""
        coords = [int(float(value)) for value in detection[0].split(",")]
        return coords[0::2], coords[1::2]

    def gt_coordinates(gt):
        """Return the integer x and y lists stored in a ground-truth record."""
        gt_x = list(map(int, np.squeeze(gt[1])))
        gt_y = list(map(int, np.squeeze(gt[3])))
        return gt_x, gt_y

    def detection_filtering(detections, groundtruths, threshold=0.5):
        """Drop, in place, detections lying mostly inside a don't-care region."""
        for gt in groundtruths:
            if not ((gt[5] == "#") and (gt[1].shape[1] > 1)):
                continue
            gt_x, gt_y = gt_coordinates(gt)
            kept = []
            for detection in detections:
                det_x, det_y = parse_detection(detection)
                if not (iod(det_x, det_y, gt_x, gt_y) > threshold):
                    kept.append(detection)
            # Filter after every don't-care region so the next region only
            # sees surviving detections.
            detections[:] = kept
        return detections

    def sigma_calculation(det_x, det_y, gt_x, gt_y):
        """sigma = inter_area / gt_area (0 for a zero-area ground truth)."""
        gt_area = area(gt_x, gt_y)
        if gt_area == 0.0:
            # Same convention as tau_calculation: a degenerate polygon
            # contributes no overlap instead of dividing by zero.
            return 0
        return np.round(
            area_of_intersection(det_x, det_y, gt_x, gt_y) / gt_area, 2)

    def tau_calculation(det_x, det_y, gt_x, gt_y):
        """tau = inter_area / det_area (0 for a zero-area detection)."""
        det_area = area(det_x, det_y)
        if det_area == 0.0:
            return 0
        return np.round(
            area_of_intersection(det_x, det_y, gt_x, gt_y) / det_area, 2)

    detections = input_reading_mod(pred_dict)
    groundtruths = gt_reading_mod(gt_dir)
    # Filters detections overlapping with a don't-care (DC) area.
    detections = detection_filtering(detections, groundtruths)
    # DC regions are excluded from scoring.
    groundtruths = [gt for gt in groundtruths if not (gt[5] == "#")]

    local_sigma_table = np.zeros((len(groundtruths), len(detections)))
    local_tau_table = np.zeros((len(groundtruths), len(detections)))
    local_pred_str = {}
    local_gt_str = {}

    if groundtruths and detections:
        # Parse every detection once instead of once per ground truth.
        parsed = [parse_detection(detection) for detection in detections]
        for det_id, detection in enumerate(detections):
            local_pred_str[det_id] = detection[1].strip()

        for gt_id, gt in enumerate(groundtruths):
            gt_x, gt_y = gt_coordinates(gt)
            local_gt_str[gt_id] = str(gt[4].tolist()[0])
            for det_id, (det_x, det_y) in enumerate(parsed):
                local_sigma_table[gt_id, det_id] = sigma_calculation(
                    det_x, det_y, gt_x, gt_y)
                local_tau_table[gt_id, det_id] = tau_calculation(
                    det_x, det_y, gt_x, gt_y)

    single_data = {}
    single_data["sigma"] = local_sigma_table
    single_data["global_tau"] = local_tau_table
    single_data["global_pred_str"] = local_pred_str
    single_data["global_gt_str"] = local_gt_str
    return single_data
155
+
156
+
157
def get_socre_B(gt_dir, img_id, pred_dict):
    """Build the sigma/tau overlap tables for one image from a ``.mat`` label.

    Args:
        gt_dir: Directory holding the ground-truth files.
        img_id: Image identifier; the file read is
            ``<gt_dir>/poly_gt_img<img_id>.mat`` and its ``"polygt"`` entry is
            used. Each record is indexed as ``[1]`` (x coordinates), ``[3]``
            (y coordinates), ``[4]`` (transcription) and ``[5]`` (``"#"``
            marks a "don't care" region).
        pred_dict: Indexable collection of predictions (positions
            ``0 .. len-1``). Each entry provides ``["points"]`` (an object
            supporting ``.reshape(-1)``) and ``["texts"]`` (a string).

    Returns:
        dict with keys ``"sigma"`` and ``"global_tau"`` (arrays of shape
        ``(num_gt, num_det)`` after don't-care filtering) and
        ``"global_pred_str"`` / ``"global_gt_str"`` (dicts mapping detection /
        ground-truth index to its text; both stay empty unless at least one
        ground truth and one detection remain).
    """

    def input_reading_mod(pred_dict):
        """Flatten every prediction into ``["x0,y0,x1,y1,...", text]``."""
        det = []
        # Positional indexing is kept on purpose so that int-keyed mappings
        # behave exactly like lists.
        for i in range(len(pred_dict)):
            points = pred_dict[i]["points"]
            text = pred_dict[i]["texts"]
            det.append([",".join(map(str, points.reshape(-1))), text])
        return det

    def gt_reading_mod(gt_dir, gt_id):
        """Load the ``polygt`` array of one image's ground-truth file."""
        gt = io.loadmat("%s/poly_gt_img%s.mat" % (gt_dir, gt_id))
        return gt["polygt"]

    def parse_detection(detection):
        """Split ``"x0,y0,x1,y1,..."`` into integer x and y lists."""
        coords = [int(float(value)) for value in detection[0].split(",")]
        return coords[0::2], coords[1::2]

    def gt_coordinates(gt):
        """Return the integer x and y lists stored in a ground-truth record."""
        gt_x = list(map(int, np.squeeze(gt[1])))
        gt_y = list(map(int, np.squeeze(gt[3])))
        return gt_x, gt_y

    def detection_filtering(detections, groundtruths, threshold=0.5):
        """Drop, in place, detections lying mostly inside a don't-care region."""
        for gt in groundtruths:
            if not ((gt[5] == "#") and (gt[1].shape[1] > 1)):
                continue
            gt_x, gt_y = gt_coordinates(gt)
            kept = []
            for detection in detections:
                det_x, det_y = parse_detection(detection)
                if not (iod(det_x, det_y, gt_x, gt_y) > threshold):
                    kept.append(detection)
            # Filter after every don't-care region so the next region only
            # sees surviving detections.
            detections[:] = kept
        return detections

    def sigma_calculation(det_x, det_y, gt_x, gt_y):
        """sigma = inter_area / gt_area (0 for a zero-area ground truth)."""
        gt_area = area(gt_x, gt_y)
        if gt_area == 0.0:
            # Same convention as tau_calculation: a degenerate polygon
            # contributes no overlap instead of dividing by zero.
            return 0
        return np.round(
            area_of_intersection(det_x, det_y, gt_x, gt_y) / gt_area, 2)

    def tau_calculation(det_x, det_y, gt_x, gt_y):
        """tau = inter_area / det_area (0 for a zero-area detection)."""
        det_area = area(det_x, det_y)
        if det_area == 0.0:
            return 0
        return np.round(
            area_of_intersection(det_x, det_y, gt_x, gt_y) / det_area, 2)

    detections = input_reading_mod(pred_dict)
    groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
    # Filters detections overlapping with a don't-care (DC) area.
    detections = detection_filtering(detections, groundtruths)
    # DC regions are excluded from scoring.
    groundtruths = [gt for gt in groundtruths if not (gt[5] == "#")]

    local_sigma_table = np.zeros((len(groundtruths), len(detections)))
    local_tau_table = np.zeros((len(groundtruths), len(detections)))
    local_pred_str = {}
    local_gt_str = {}

    if groundtruths and detections:
        # Parse every detection once instead of once per ground truth.
        parsed = [parse_detection(detection) for detection in detections]
        for det_id, detection in enumerate(detections):
            local_pred_str[det_id] = detection[1].strip()

        for gt_id, gt in enumerate(groundtruths):
            gt_x, gt_y = gt_coordinates(gt)
            local_gt_str[gt_id] = str(gt[4].tolist()[0])
            for det_id, (det_x, det_y) in enumerate(parsed):
                local_sigma_table[gt_id, det_id] = sigma_calculation(
                    det_x, det_y, gt_x, gt_y)
                local_tau_table[gt_id, det_id] = tau_calculation(
                    det_x, det_y, gt_x, gt_y)

    single_data = {}
    single_data["sigma"] = local_sigma_table
    single_data["global_tau"] = local_tau_table
    single_data["global_pred_str"] = local_pred_str
    single_data["global_gt_str"] = local_gt_str
    return single_data
275
+
276
+
277
def get_score_C(gt_label, text, pred_bboxes):
    """Get score for CentripetalText (CT) prediction.

    Args:
        gt_label: Indexable collection of ground-truth boxes; each element
            exposes ``.numpy()`` and the resulting array is read through
            ``shape[1] // 2`` points (presumably a ``(1, 2 * num_points)``
            layout -- confirm against the caller).
        text: Indexable collection where ``text[i][0]`` is the transcription
            of box ``i``; ``"###"`` marks a "don't care" region.
        pred_bboxes: Iterable of 2-D arrays, one per predicted polygon.

    Returns:
        dict with ``"sigma"`` and ``"global_tau"`` arrays of shape
        ``(num_gt, num_det)`` after don't-care filtering, plus
        ``"global_pred_str"`` and ``"global_gt_str"`` set to ``""``.
    """
    check_install("Polygon", "Polygon3")
    import Polygon as plg

    def gt_reading_mod(gt_label, text):
        """Pair every ground-truth box with its transcription."""
        return [{
            "transcription": text[i][0],
            "points": gt_label[i].numpy(),
        } for i in range(len(gt_label))]

    def get_intersection(pD, pG):
        """Return the area shared by two polygons (0 when disjoint)."""
        pInt = pD & pG
        if len(pInt) == 0:
            return 0
        return pInt.area()

    def gt_polygon(gt):
        """Build the polygon of one ground-truth record."""
        point_num = gt["points"].shape[1] // 2
        corners = np.array(gt["points"]).reshape(point_num, 2).astype("int32")
        return plg.Polygon(corners)

    def det_polygon(detection):
        """Build the polygon of one flattened detection."""
        # Detections are flattened from column-reversed boxes, so even
        # positions hold the second coordinate and odd positions the first.
        det_y = detection[0::2]
        det_x = detection[1::2]
        det_p = np.concatenate((np.array(det_x), np.array(det_y)))
        return plg.Polygon(det_p.reshape(2, -1).transpose())

    def detection_filtering(detections, groundtruths, threshold=0.5):
        """Drop, in place, detections lying mostly inside a don't-care region."""
        for gt in groundtruths:
            point_num = gt["points"].shape[1] // 2
            if not (gt["transcription"] == "###" and (point_num > 1)):
                continue
            gt_p = gt_polygon(gt)

            kept = []
            for detection in detections:
                det_p = det_polygon(detection)
                try:
                    overlap = get_intersection(det_p, gt_p) / det_p.area()
                except Exception:
                    # Best effort: report the offending polygon and keep the
                    # detection rather than reusing a stale overlap value.
                    print(detection[1::2], detection[0::2], gt_p)
                    kept.append(detection)
                    continue
                if not (overlap > threshold):
                    kept.append(detection)
            # Rebuild the list directly; comparing array entries against a
            # ``[]`` marker is not a reliable way to detect removed items.
            detections[:] = kept
        return detections

    def sigma_calculation(det_p, gt_p):
        """sigma = inter_area / gt_area"""
        gt_area = gt_p.area()
        if gt_area == 0.0:
            return 0
        return get_intersection(det_p, gt_p) / gt_area

    def tau_calculation(det_p, gt_p):
        """tau = inter_area / det_area"""
        det_area = det_p.area()
        if det_area == 0.0:
            return 0
        return get_intersection(det_p, gt_p) / det_area

    detections = [item[:, ::-1].reshape(-1) for item in pred_bboxes]
    groundtruths = gt_reading_mod(gt_label, text)

    # Filters detections overlapping with a don't-care (DC) area.
    detections = detection_filtering(detections, groundtruths)

    # NOTE: the reference implementation uses 'orin' to indicate '#'; here
    # 'anno' is used, which may cause a slight drop in fscore (about 0.12).
    groundtruths = [
        gt for gt in groundtruths if not (gt["transcription"] == "###")
    ]

    local_sigma_table = np.zeros((len(groundtruths), len(detections)))
    local_tau_table = np.zeros((len(groundtruths), len(detections)))

    if groundtruths and detections:
        # Build each detection polygon once instead of once per ground truth.
        det_polygons = [det_polygon(detection) for detection in detections]
        for gt_id, gt in enumerate(groundtruths):
            gt_p = gt_polygon(gt)
            for det_id, det_p in enumerate(det_polygons):
                local_sigma_table[gt_id, det_id] = sigma_calculation(
                    det_p, gt_p)
                local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)

    data = {}
    data["sigma"] = local_sigma_table
    data["global_tau"] = local_tau_table
    data["global_pred_str"] = ""
    data["global_gt_str"] = ""
    return data
394
+
395
+
396
def combine_results(all_data, rec_flag=True):
    """Aggregate per-image sigma/tau tables into detection and end-to-end scores.

    Three matching passes are run per image, in this order: one-to-one,
    one-to-many (one ground truth split over several detections) and
    many-to-one (one detection covering several ground truths).  A one-to-one
    match credits 1.0 to both recall and precision; split/merge matches credit
    ``fsc_k`` per involved region.

    Args:
        all_data: iterable of dicts, one per image, each with the keys
            ``"sigma"`` and ``"global_tau"`` (2-D tables indexed
            ``[gt_id, det_id]``) and ``"global_pred_str"`` /
            ``"global_gt_str"`` (containers of transcriptions indexed by
            detection id / ground-truth id).
        rec_flag: when true, matched pairs are also compared by transcription
            (exact, then case-insensitive) to count end-to-end hits.  A
            transcription that cannot be looked up counts as "no hit" instead
            of raising.

    Returns:
        dict with the keys ``total_num_gt``, ``total_num_det``,
        ``global_accumulative_recall``, ``hit_str_count``, ``recall``,
        ``precision``, ``f_score``, ``seqerr``, ``recall_e2e``,
        ``precision_e2e`` and ``f_score_e2e``.  Ratios whose denominator is
        zero fall back to 0 (``seqerr`` falls back to 1).
    """
    tr = 0.7  # sigma threshold (share of the ground truth that is covered)
    tp = 0.6  # tau threshold (share of the detection that is covered)
    fsc_k = 0.8  # credit per region for split / merge matches
    k = 2  # minimum number of overlapping regions for a split / merge case

    # Running totals are updated in place, in match order, so the floating
    # point accumulation order is the same across all images.
    totals = {"recall": 0, "precision": 0, "hits": 0}

    def lookup_text(texts, index):
        """Return texts[index], or None when no such transcription exists."""
        try:
            return texts[index]
        except (KeyError, IndexError, TypeError):
            return None

    def texts_match(pred_text, gt_text):
        """Exact comparison first, then a case-insensitive one."""
        if pred_text is None or gt_text is None:
            return False
        return pred_text == gt_text or pred_text.lower() == gt_text.lower()

    def count_hit(pred_texts, det_id, gt_texts, gt_ids):
        """Return 1 if the detection text matches any listed ground truth."""
        if not rec_flag:
            return 0
        pred_text = lookup_text(pred_texts, det_id)
        for gt_id in gt_ids:
            if texts_match(pred_text, lookup_text(gt_texts, gt_id)):
                return 1
        return 0

    def credit(recall_gain, precision_gain):
        totals["recall"] = totals["recall"] + recall_gain
        totals["precision"] = totals["precision"] + precision_gain

    def match_one_to_one(sigma, tau, gt_flag, det_flag, pred_texts, gt_texts):
        for gt_id in range(sigma.shape[0]):
            sigma_dets = np.where(sigma[gt_id, :] > tr)[0]
            tau_dets = np.where(tau[gt_id, :] > tp)[0]
            if sigma_dets.size != 1 or tau_dets.size != 1:
                continue
            # The qualifying detection must not qualify for any other
            # ground truth either (checked separately for sigma and tau).
            if np.count_nonzero(sigma[:, sigma_dets] > tr) != 1:
                continue
            if np.count_nonzero(tau[:, tau_dets] > tp) != 1:
                continue

            credit(1.0, 1.0)
            gt_flag[gt_id] = True
            det_flag[sigma_dets] = True
            totals["hits"] += count_hit(pred_texts, sigma_dets.tolist()[0],
                                        gt_texts, [gt_id])

    def match_one_to_many(sigma, tau, gt_flag, det_flag, pred_texts, gt_texts):
        for gt_id in range(sigma.shape[0]):
            # skip ground truths that were already matched
            if gt_flag[gt_id]:
                continue
            if np.count_nonzero(sigma[gt_id, :] > 0) < k:
                continue

            # all still-unmatched detections that lie mostly inside this
            # ground truth
            dets = np.where((tau[gt_id, :] >= tp) & ~det_flag)[0]
            if dets.size == 1:
                det_id = dets.tolist()[0]
                if tau[gt_id, det_id] >= tp and sigma[gt_id, det_id] >= tr:
                    # degenerated into a one-to-one case
                    credit(1.0, 1.0)
                    gt_flag[gt_id] = True
                    det_flag[det_id] = True
                    totals["hits"] += count_hit(pred_texts, det_id, gt_texts,
                                                [gt_id])
            elif np.sum(sigma[gt_id, dets]) >= tr:
                gt_flag[gt_id] = True
                det_flag[dets] = True
                # only the first candidate's transcription is compared
                totals["hits"] += count_hit(pred_texts, dets.tolist()[0],
                                            gt_texts, [gt_id])
                credit(fsc_k, dets.size * fsc_k)

    def match_many_to_one(sigma, tau, gt_flag, det_flag, pred_texts, gt_texts):
        for det_id in range(sigma.shape[1]):
            # skip detections that were already matched
            if det_flag[det_id]:
                continue
            if np.count_nonzero(tau[:, det_id] > 0) < k:
                continue

            # all still-unmatched ground truths that overlap this detection;
            # the sigma column is compared against tp here, as in the
            # original implementation
            gts = np.where((sigma[:, det_id] >= tp) & ~gt_flag)[0]
            if gts.size == 1:
                gt_id = gts.tolist()[0]
                if tau[gt_id, det_id] >= tp and sigma[gt_id, det_id] >= tr:
                    # degenerated into a one-to-one case
                    credit(1.0, 1.0)
                    gt_flag[gt_id] = True
                    det_flag[det_id] = True
                    totals["hits"] += count_hit(pred_texts, det_id, gt_texts,
                                                [gt_id])
            elif np.sum(tau[gts, det_id]) >= tp:
                det_flag[det_id] = True
                gt_flag[gts] = True
                totals["hits"] += count_hit(pred_texts, det_id, gt_texts,
                                            gts.tolist())
                credit(gts.size * fsc_k, fsc_k)

    def ratio(numerator, denominator, fallback):
        try:
            return numerator / denominator
        except ZeroDivisionError:
            return fallback

    # Materialise first so a malformed entry fails before any scoring starts.
    per_image = [(data["sigma"], data["global_tau"], data["global_pred_str"],
                  data["global_gt_str"]) for data in all_data]

    total_num_gt = 0
    total_num_det = 0
    for raw_sigma, raw_tau, pred_texts, gt_texts in per_image:
        sigma = np.asarray(raw_sigma)
        tau = np.asarray(raw_tau)

        num_gt = sigma.shape[0]
        num_det = sigma.shape[1]
        total_num_gt = total_num_gt + num_gt
        total_num_det = total_num_det + num_det

        # Flags are shared by the three passes: a region matched in an
        # earlier pass is not considered again in a later one.
        gt_flag = np.zeros(num_gt, dtype=bool)
        det_flag = np.zeros(num_det, dtype=bool)

        match_one_to_one(sigma, tau, gt_flag, det_flag, pred_texts, gt_texts)
        match_one_to_many(sigma, tau, gt_flag, det_flag, pred_texts, gt_texts)
        match_many_to_one(sigma, tau, gt_flag, det_flag, pred_texts, gt_texts)

    global_accumulative_recall = totals["recall"]
    global_accumulative_precision = totals["precision"]
    hit_str_count = totals["hits"]

    recall = ratio(global_accumulative_recall, total_num_gt, 0)
    precision = ratio(global_accumulative_precision, total_num_det, 0)
    f_score = ratio(2 * precision * recall, precision + recall, 0)
    seqerr = 1 - ratio(float(hit_str_count), global_accumulative_recall, 0) \
        if global_accumulative_recall != 0 else 1
    recall_e2e = ratio(float(hit_str_count), total_num_gt, 0)
    precision_e2e = ratio(float(hit_str_count), total_num_det, 0)
    f_score_e2e = ratio(2 * precision_e2e * recall_e2e,
                        precision_e2e + recall_e2e, 0)

    final = {
        "total_num_gt": total_num_gt,
        "total_num_det": total_num_det,
        "global_accumulative_recall": global_accumulative_recall,
        "hit_str_count": hit_str_count,
        "recall": recall,
        "precision": precision,
        "f_score": f_score,
        "seqerr": seqerr,
        "recall_e2e": recall_e2e,
        "precision_e2e": precision_e2e,
        "f_score_e2e": f_score_e2e,
    }
    return final