openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,749 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
class RobustScannerDecoder(nn.Module):
    """RobustScanner decoder head: channel-reduction encoder plus a
    hybrid/position two-branch attention decoder.

    Args:
        out_channels (int): Number of output classes
            (charset size + unknown + start + padding).
        in_channels (int): Channel count of the backbone feature map.
        enc_outchannles (int): Output channels of the 1x1 reduction encoder.
        hybrid_dec_rnn_layers (int): RNN layers of the hybrid branch.
        hybrid_dec_dropout (float): Dropout rate of the hybrid branch RNN.
        position_dec_rnn_layers (int): RNN layers of the position branch.
        max_len (int): Maximum text length (excluding the start token).
        mask (bool): Whether to mask padded width via valid ratios.
        encode_value (bool): Use encoder output as the attention value.
    """

    def __init__(self,
                 out_channels,  # 90 + unknown + start + padding
                 in_channels,
                 enc_outchannles=128,
                 hybrid_dec_rnn_layers=2,
                 hybrid_dec_dropout=0,
                 position_dec_rnn_layers=2,
                 max_len=25,
                 mask=True,
                 encode_value=False,
                 **kwargs):
        super(RobustScannerDecoder, self).__init__()

        # Special token indices live at the top of the class range.
        start_idx = out_channels - 2
        padding_idx = out_channels - 1
        end_idx = 0

        # 1x1 conv that reduces backbone channels to the model dimension.
        self.encoder = ChannelReductionEncoder(in_channels=in_channels,
                                               out_channels=enc_outchannles)
        self.max_text_length = max_len + 1
        self.mask = mask

        # Two-branch (hybrid + position) attention decoder.
        self.decoder = Decoder(
            num_classes=out_channels,
            dim_input=in_channels,
            dim_model=enc_outchannles,
            hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers,
            hybrid_decoder_dropout=hybrid_dec_dropout,
            position_decoder_rnn_layers=position_dec_rnn_layers,
            max_len=max_len + 1,
            start_idx=start_idx,
            mask=mask,
            padding_idx=padding_idx,
            end_idx=end_idx,
            encode_value=encode_value)

    def forward(self, inputs, data=None):
        """Run one decode pass.

        data: [label, valid_ratio, 'length'] — only consulted in training
        mode and (for the valid ratios) when ``self.mask`` is set.
        """
        encoded = self.encoder(inputs)
        batch = encoded.shape[0]

        # One position index row per sample: [0, 1, ..., max_text_length-1].
        positions = torch.arange(0, self.max_text_length,
                                 device=inputs.device)
        positions = positions.unsqueeze(0).tile([batch, 1])

        ratios = data[-1] if self.mask else None

        if self.training:
            # Trim labels and positions to the longest label in the batch.
            steps = data[1].max()
            labels = data[0][:, :1 + steps]
            return self.decoder(inputs, encoded, labels, ratios,
                                positions[:, :1 + steps])

        return self.decoder(inputs,
                            encoded,
                            label=None,
                            valid_ratios=ratios,
                            word_positions=positions,
                            train_mode=False)
75
+
76
+
77
+ class BaseDecoder(nn.Module):
78
+
79
+ def __init__(self, **kwargs):
80
+ super().__init__()
81
+
82
+ def forward_train(self, feat, out_enc, targets, img_metas):
83
+ raise NotImplementedError
84
+
85
+ def forward_test(self, feat, out_enc, img_metas):
86
+ raise NotImplementedError
87
+
88
+ def forward(self,
89
+ feat,
90
+ out_enc,
91
+ label=None,
92
+ valid_ratios=None,
93
+ word_positions=None,
94
+ train_mode=True):
95
+ self.train_mode = train_mode
96
+
97
+ if train_mode:
98
+ return self.forward_train(feat, out_enc, label, valid_ratios,
99
+ word_positions)
100
+ return self.forward_test(feat, out_enc, valid_ratios, word_positions)
101
+
102
+
103
class ChannelReductionEncoder(nn.Module):
    """Change the channel number with a one by one convolutional layer.

    The 1x1 convolution weight is Xavier-normal initialized (gain 1.0),
    matching the original behavior.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(ChannelReductionEncoder, self).__init__()

        self.layer = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=1,
                               stride=1,
                               padding=0)
        # Initialize the existing weight in place instead of building a
        # detached Parameter and reassigning it behind the always-true
        # "use_xavier_normal = 1" flag of the original code.
        nn.init.xavier_normal_(self.layer.weight, gain=1.0)

    def forward(self, feat):
        """
        Args:
            feat (Tensor): Image features with the shape of
                :math:`(N, C_{in}, H, W)`.

        Returns:
            Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`.
        """
        return self.layer(feat)
138
+
139
+
140
def masked_fill(x, mask, value):
    """Return a copy of ``x`` with positions selected by ``mask`` set to ``value``.

    Bug fix: the original called ``torch.full(x.shape, value, x.dtype)``,
    passing ``x.dtype`` to the positional ``out`` parameter — ``dtype`` is
    keyword-only in ``torch.full`` — so the call raised a TypeError. The
    fill tensor is also created on ``x``'s device so CUDA inputs work.
    """
    y = torch.full(x.shape, value, dtype=x.dtype, device=x.device)
    return torch.where(mask, y, x)
143
+
144
+
145
class DotProductAttentionLayer(nn.Module):
    """(Optionally scaled) dot-product attention over flattened positions.

    Inputs are channel-first: query ``(N, C, Lq)``, key/value ``(N, C, Lk)``;
    the output glimpse is ``(N, C, Lq)``.
    """

    def __init__(self, dim_model=None):
        super().__init__()
        # 1/sqrt(d) scaling when a model dimension is given, else none.
        self.scale = 1. if dim_model is None else dim_model**-0.5

    def forward(self, query, key, value, mask=None):
        # (N, C, Lq) -> (N, Lq, C), then scores (N, Lq, Lk).
        scores = query.permute(0, 2, 1) @ key
        scores = scores * self.scale

        if mask is not None:
            # mask: (N, Lk) of bools marking padded key positions.
            batch, key_len = mask.size()
            scores = scores.masked_fill(mask.view(batch, 1, key_len),
                                        float('-inf'))

        attn = F.softmax(scores, dim=2)
        # (N, C, Lk) -> (N, Lk, C); glimpse (N, Lq, C) -> back to (N, C, Lq).
        glimpse = attn @ value.transpose(1, 2)
        return glimpse.permute(0, 2, 1).contiguous()
167
+
168
+
169
class SequenceAttentionDecoder(BaseDecoder):
    """Sequence attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        padding_idx (int): The index of `<PAD>`.
        dropout (float): Dropout rate.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 dropout=0,
                 return_feature=False,
                 encode_value=False):
        super().__init__()

        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.max_seq_len = max_seq_len
        self.start_idx = start_idx
        self.mask = mask

        # Character embedding; <PAD> rows stay zero via padding_idx.
        self.embedding = nn.Embedding(self.num_classes,
                                      self.dim_model,
                                      padding_idx=padding_idx)

        # LSTM over the embedded target sequence produces the queries.
        self.sequence_layer = nn.LSTM(input_size=dim_model,
                                      hidden_size=dim_model,
                                      num_layers=rnn_layers,
                                      batch_first=True,
                                      dropout=dropout)

        self.attention_layer = DotProductAttentionLayer()

        self.prediction = None
        if not self.return_feature:
            # <PAD> is never predicted, hence C - 1 output classes.
            pred_num_classes = num_classes - 1
            # Projection input matches the attention value dimension:
            # encoder channels if encode_value else raw feature channels.
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def forward_train(self, feat, out_enc, targets, valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets (Tensor): a tensor of shape :math:`(N, T)`. Each element is the index of a
                character.
            valid_ratios (Tensor): valid length ratio of img.
        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """

        tgt_embedding = self.embedding(targets)

        # Sanity-check that feature/encoder/embedding dims agree.
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, len_q, c_q = tgt_embedding.shape
        assert c_q == self.dim_model
        assert len_q <= self.max_seq_len

        query, _ = self.sequence_layer(tgt_embedding)

        # (N, T, D_m) -> (N, D_m, T), the layout the attention layer expects.
        query = query.permute(0, 2, 1).contiguous()

        # Flatten spatial positions: key is (N, D_m, H*W).
        key = out_enc.view(n, c_enc, h * w)

        if self.encode_value:
            value = key
        else:
            value = feat.view(n, c_feat, h * w)

        mask = None
        if valid_ratios is not None:
            # Mark padded width (beyond each sample's valid ratio) as True
            # so attention ignores those positions.
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        # Back to (N, T, D) for the linear projection.
        attn_out = attn_out.permute(0, 2, 1).contiguous()

        if self.return_feature:
            return attn_out

        out = self.prediction(attn_out)

        return out

    def forward_test(self, feat, out_enc, valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): valid length ratio of img.

        Returns:
            Tensor: The output logit sequence tensor of shape
            :math:`(N, T, C-1)`.
        """
        batch_size = feat.shape[0]

        # Seed the whole sequence with <SOS>; decoded tokens overwrite it
        # step by step below.
        decode_sequence = (torch.ones((batch_size, self.max_seq_len),
                                      dtype=torch.int64,
                                      device=feat.device) * self.start_idx)

        outputs = []
        for i in range(self.max_seq_len):
            step_out = self.forward_test_step(feat, out_enc, decode_sequence,
                                              i, valid_ratios)
            outputs.append(step_out)
            # Greedy decoding: feed the argmax back as the next input token.
            max_idx = torch.argmax(step_out, dim=1, keepdim=False)
            if i < self.max_seq_len - 1:
                decode_sequence[:, i + 1] = max_idx

        # Stack per-step outputs into (N, T, C-1).
        outputs = torch.stack(outputs, 1)

        return outputs

    def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
                          valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            decode_sequence (Tensor): Shape :math:`(N, T)`. The tensor that
                stores history decoding result.
            current_step (int): Current decoding step.
            valid_ratios (Tensor): valid length ratio of img

        Returns:
            Tensor: Shape :math:`(N, C-1)`. The logit tensor of predicted
            tokens at current time step.
        """

        embed = self.embedding(decode_sequence)

        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, _, c_q = embed.shape
        assert c_q == self.dim_model

        # Run the LSTM over the full (partially decoded) sequence; only the
        # current step's output is used below.
        query, _ = self.sequence_layer(embed)
        query = query.transpose(1, 2)
        key = torch.reshape(out_enc, (n, c_enc, h * w))
        if self.encode_value:
            value = key
        else:
            value = torch.reshape(feat, (n, c_feat, h * w))

        mask = None
        if valid_ratios is not None:
            # Same valid-width masking as in forward_train.
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        # [n, c, l]
        attn_out = self.attention_layer(query, key, value, mask)
        # Keep only the glimpse for the current decoding step.
        out = attn_out[:, :, current_step]

        if self.return_feature:
            return out

        out = self.prediction(out)
        # Unlike forward_train (raw logits), test steps return probabilities;
        # the greedy argmax in forward_test is unaffected either way.
        out = F.softmax(out, dim=-1)

        return out
381
+
382
+
383
class PositionAwareLayer(nn.Module):
    """Row-wise LSTM followed by a two-layer 3x3 conv mixer.

    Each image row is treated as a width-wise sequence, then the result is
    mixed spatially with two 3x3 convolutions. Input and output are both
    ``(N, C, H, W)`` with ``C == dim_model``.
    """

    def __init__(self, dim_model, rnn_layers=2):
        super().__init__()

        self.dim_model = dim_model

        self.rnn = nn.LSTM(input_size=dim_model,
                           hidden_size=dim_model,
                           num_layers=rnn_layers,
                           batch_first=True)

        self.mixer = nn.Sequential(
            nn.Conv2d(dim_model, dim_model, kernel_size=3, stride=1,
                      padding=1),
            nn.ReLU(True),
            nn.Conv2d(dim_model, dim_model, kernel_size=3, stride=1,
                      padding=1),
        )

    def forward(self, img_feature):
        batch, channels, height, width = img_feature.shape

        # Fold every image row into the batch dim: (N*H, W, C) sequences.
        seq = img_feature.permute(0, 2, 3, 1).contiguous()
        seq = seq.view(batch * height, width, channels)
        seq, _ = self.rnn(seq)

        # Restore the (N, C, H, W) layout before spatial mixing.
        feat = seq.view(batch, height, width, channels)
        feat = feat.permute(0, 3, 1, 2).contiguous()

        return self.mixer(feat)
411
+
412
+
413
class PositionAttentionDecoder(BaseDecoder):
    """Position attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 mask=True,
                 return_feature=False,
                 encode_value=False):
        super().__init__()

        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.max_seq_len = max_seq_len
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.mask = mask

        # One embedding per position id; +1 leaves room for position
        # index == max_seq_len.
        self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)

        self.position_aware_module = PositionAwareLayer(
            self.dim_model, rnn_layers)

        self.attention_layer = DotProductAttentionLayer()

        self.prediction = None
        if not self.return_feature:
            # `<PAD>` (the last class) is never predicted, hence C - 1.
            pred_num_classes = num_classes - 1
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def _get_position_index(self, length, batch_size):
        """Return an int64 tensor of shape ``(batch_size, length)`` holding
        position ids ``0 .. length - 1`` for every sample.

        Fix: the previous implementation called
        ``torch.range(0, length, step=1, dtype='int64')`` which raises a
        TypeError (``dtype`` must be a ``torch.dtype``, not a string) and
        used the deprecated, endpoint-inclusive ``torch.range``. We use
        ``torch.arange`` instead, matching the mmocr reference
        implementation.
        """
        position_index = torch.arange(0, length, dtype=torch.int64)
        # Every sample shares the same ids, so expand rather than stacking
        # per-sample copies.
        return position_index.unsqueeze(0).repeat(batch_size, 1)

    def forward_train(self, feat, out_enc, targets, valid_ratios,
                      position_index):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of a
                character.
            valid_ratios (Tensor): valid length ratio of img.
            position_index (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it will be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, len_q = targets.shape
        assert len_q <= self.max_seq_len

        position_out_enc = self.position_aware_module(out_enc)

        # Query: position embeddings, shaped to (n, dim_model, T) for the
        # attention layer.
        query = self.embedding(position_index)
        query = query.permute(0, 2, 1).contiguous()
        key = position_out_enc.view(n, c_enc, h * w)
        if self.encode_value:
            value = out_enc.view(n, c_enc, h * w)
        else:
            value = feat.view(n, c_feat, h * w)

        # Mask out padded columns beyond each sample's valid width.
        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        attn_out = attn_out.permute(0, 2, 1).contiguous()  # (n, T, dim_v)

        if self.return_feature:
            return attn_out

        return self.prediction(attn_out)

    def forward_test(self, feat, out_enc, valid_ratios, position_index):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): valid length ratio of img
            position_index (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input

        position_out_enc = self.position_aware_module(out_enc)

        query = self.embedding(position_index)
        query = query.permute(0, 2, 1).contiguous()
        key = position_out_enc.view(n, c_enc, h * w)
        if self.encode_value:
            value = torch.reshape(out_enc, (n, c_enc, h * w))
        else:
            value = torch.reshape(feat, (n, c_feat, h * w))

        # Same width masking as in training.
        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        attn_out = attn_out.transpose(1, 2)  # [n, len_q, dim_v]

        if self.return_feature:
            return attn_out

        return self.prediction(attn_out)
578
+
579
+
580
class RobustScannerFusionLayer(nn.Module):
    """Gated fusion of two equally-shaped glimpse tensors.

    Concatenates both inputs along ``dim``, projects to twice the model
    width, and applies a GLU so one half of the projection gates the
    other. Output has the same shape as either input.
    """

    def __init__(self, dim_model, dim=-1):
        super().__init__()

        self.dim_model = dim_model
        self.dim = dim
        self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2)

    def forward(self, x0, x1):
        assert x0.shape == x1.shape
        combined = torch.cat((x0, x1), self.dim)
        projected = self.linear_layer(combined)
        return F.glu(projected, self.dim)
596
+
597
+
598
class Decoder(BaseDecoder):
    """Decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        padding_idx (int): The index of `<PAD>`.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 num_classes=None,
                 dim_input=512,
                 dim_model=128,
                 hybrid_decoder_rnn_layers=2,
                 hybrid_decoder_dropout=0,
                 position_decoder_rnn_layers=2,
                 max_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 end_idx=0,
                 encode_value=False):
        super().__init__()
        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.max_seq_len = max_len
        self.encode_value = encode_value
        self.start_idx = start_idx
        self.padding_idx = padding_idx
        self.end_idx = end_idx
        self.mask = mask

        # init hybrid decoder
        # Content branch: attends with features conditioned on the
        # previously decoded characters.
        self.hybrid_decoder = SequenceAttentionDecoder(
            num_classes=num_classes,
            rnn_layers=hybrid_decoder_rnn_layers,
            dim_input=dim_input,
            dim_model=dim_model,
            max_seq_len=max_len,
            start_idx=start_idx,
            mask=mask,
            padding_idx=padding_idx,
            dropout=hybrid_decoder_dropout,
            encode_value=encode_value,
            return_feature=True)

        # init position decoder
        # Position branch: attends with learned position embeddings only.
        self.position_decoder = PositionAttentionDecoder(
            num_classes=num_classes,
            rnn_layers=position_decoder_rnn_layers,
            dim_input=dim_input,
            dim_model=dim_model,
            max_seq_len=max_len,
            mask=mask,
            encode_value=encode_value,
            return_feature=True)

        # GLU-gated fusion of the two glimpses; feature width depends on
        # whether the attention value came from out_enc or feat.
        self.fusion_module = RobustScannerFusionLayer(
            self.dim_model if encode_value else dim_input)

        # NOTE(review): the class docstring says the output size is C - 1,
        # but the head below projects to `num_classes` — confirm which is
        # intended for this port.
        pred_num_classes = num_classes
        self.prediction = nn.Linear(dim_model if encode_value else dim_input,
                                    pred_num_classes)

    def forward_train(self, feat, out_enc, target, valid_ratios,
                      word_positions):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            target (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of a
                character.
            valid_ratios (Tensor):
            word_positions (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`.
        """

        # Teacher-forced glimpses from both branches, fused per step, then
        # projected to class logits.
        hybrid_glimpse = self.hybrid_decoder.forward_train(
            feat, out_enc, target, valid_ratios)
        position_glimpse = self.position_decoder.forward_train(
            feat, out_enc, target, valid_ratios, word_positions)

        fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse)

        out = self.prediction(fusion_out)

        return out

    def forward_test(self, feat, out_enc, valid_ratios, word_positions):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor):
            word_positions (Tensor): The position of each word.
        Returns:
            Tensor: The output logit sequence tensor of shape
            :math:`(N, T, C-1)`.
        """
        seq_len = self.max_seq_len
        batch_size = feat.shape[0]

        # Greedy autoregressive decoding buffer, pre-filled with <SOS>.
        decode_sequence = (torch.ones(
            (batch_size, seq_len), dtype=torch.int64, device=feat.device) *
                           self.start_idx)

        # The position branch does not depend on decoded characters, so all
        # of its glimpses can be computed in one pass up front.
        position_glimpse = self.position_decoder.forward_test(
            feat, out_enc, valid_ratios, word_positions)

        outputs = []
        for i in range(seq_len):
            # The hybrid branch is stepped one character at a time, fed the
            # sequence decoded so far.
            hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
                feat, out_enc, decode_sequence, i, valid_ratios)

            fusion_out = self.fusion_module(hybrid_glimpse_step,
                                            position_glimpse[:, i, :])

            char_out = self.prediction(fusion_out)
            char_out = F.softmax(char_out, -1)
            outputs.append(char_out)
            max_idx = torch.argmax(char_out, dim=1, keepdim=False)
            if i < seq_len - 1:
                decode_sequence[:, i + 1] = max_idx
                # Stop once every sample contains end_idx somewhere.
                # NOTE(review): undecoded positions still hold start_idx, so
                # with the defaults start_idx == end_idx == 0 this condition
                # is true immediately — callers presumably pass distinct
                # indices; confirm.
                if (decode_sequence == self.end_idx).any(dim=-1).all():
                    break
        # (N, steps, C); `steps` may be < seq_len when decoding stops early.
        outputs = torch.stack(outputs, 1)

        return outputs