openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,504 @@
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from itertools import permutations
18
+ from typing import Any, Optional
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch import Tensor
25
+ from torch.nn.modules import transformer
26
+
27
+
28
class DecoderLayer(nn.Module):
    """Two-stream (XLNet-style) Transformer decoder layer in pre-LN form.

    Every sub-block (self-attention, cross-attention, feed-forward) is fed a
    LayerNorm'd input and added back as a residual, instead of normalising
    after the residual sum as PyTorch's default (post-LN) decoder layer does.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation='gelu',
        layer_norm_eps=1e-5,
    ):
        super().__init__()
        # Both attention blocks work on batch-first tensors.
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True)

        # Feed-forward sub-block: linear1 -> activation -> dropout -> linear2.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # norm1 precedes cross-attention, norm2 precedes the feed-forward
        # block; norm_q / norm_c normalise the query / content streams.
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps)

        # Residual dropouts after self-attention, cross-attention and FFN.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = transformer._get_activation_fn(activation)

    def __setstate__(self, state):
        # Older pickles may lack the activation attribute; fall back to GELU.
        state.setdefault('activation', F.gelu)
        super().__setstate__(state)

    def forward_stream(
        self,
        tgt: Tensor,
        tgt_norm: Tensor,
        tgt_kv: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor],
        tgt_key_padding_mask: Optional[Tensor],
    ):
        """Run one stream (query or content) through the three sub-blocks.

        Args:
            tgt: residual input of the stream.
            tgt_norm: ``tgt`` already passed through a LayerNorm; taken as a
                separate argument so the caller can reuse it.
            tgt_kv: keys/values for self-attention (expected LayerNorm'd).
            memory: keys/values for cross-attention; the original author
                notes it is LayerNorm'd by the encoder.
            tgt_mask: attention mask for self-attention.
            tgt_key_padding_mask: key padding mask for self-attention.

        Returns:
            Tuple of (updated stream, self-attention weights,
            cross-attention weights).
        """
        update, sa_weights = self.self_attn(
            tgt_norm,
            tgt_kv,
            tgt_kv,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        out = tgt + self.dropout1(update)

        update, ca_weights = self.cross_attn(self.norm1(out), memory, memory)
        # Keep the most recent cross-attention map so callers can inspect it.
        self.attn_map = ca_weights
        out = out + self.dropout2(update)

        hidden = self.activation(self.linear1(self.norm2(out)))
        out = out + self.dropout3(self.linear2(self.dropout(hidden)))
        return out, sa_weights, ca_weights

    def forward(
        self,
        query,
        content,
        memory,
        query_mask: Optional[Tensor] = None,
        content_mask: Optional[Tensor] = None,
        content_key_padding_mask: Optional[Tensor] = None,
        update_content: bool = True,
    ):
        """Update the query stream and, optionally, the content stream.

        Both streams attend over the normalised content stream. When
        ``update_content`` is False the ``content`` argument is returned
        unchanged (the same object).
        """
        normed_query = self.norm_q(query)
        normed_content = self.norm_c(content)
        new_query, _, _ = self.forward_stream(query, normed_query,
                                              normed_content, memory,
                                              query_mask,
                                              content_key_padding_mask)
        if not update_content:
            return new_query, content
        new_content, _, _ = self.forward_stream(content, normed_content,
                                                normed_content, memory,
                                                content_mask,
                                                content_key_padding_mask)
        return new_query, new_content
122
+
123
+
124
class Decoder(nn.Module):
    """Stack of cloned two-stream decoder layers plus a final query norm."""

    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm):
        super().__init__()
        self.layers = transformer._get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        query,
        content,
        memory,
        query_mask: Optional[Tensor] = None,
        content_mask: Optional[Tensor] = None,
        content_key_padding_mask: Optional[Tensor] = None,
    ):
        """Apply all layers and return ``norm`` of the final query stream.

        The content stream is updated by every layer except the last one,
        whose content output would never be read.
        """
        final_idx = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            query, content = layer(
                query,
                content,
                memory,
                query_mask,
                content_mask,
                content_key_padding_mask,
                update_content=idx != final_idx,
            )
        return self.norm(query)
155
+
156
+
157
class TokenEmbedding(nn.Module):
    """Token-id embedding whose output is scaled by ``sqrt(embed_dim)``."""

    def __init__(self, charset_size: int, embed_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(charset_size, embed_dim)
        self.embed_dim = embed_dim

    def forward(self, tokens: torch.Tensor):
        """Look up ``tokens`` and multiply the vectors by sqrt(embed_dim)."""
        scale = math.sqrt(self.embed_dim)
        vectors = self.embedding(tokens)
        return scale * vectors
166
+
167
+
168
+ class PARSeqDecoder(nn.Module):
169
+
170
+ def __init__(self,
171
+ in_channels,
172
+ out_channels,
173
+ max_label_length=25,
174
+ embed_dim=384,
175
+ dec_num_heads=12,
176
+ dec_mlp_ratio=4,
177
+ dec_depth=1,
178
+ perm_num=6,
179
+ perm_forward=True,
180
+ perm_mirrored=True,
181
+ decode_ar=True,
182
+ refine_iters=1,
183
+ dropout=0.1,
184
+ **kwargs: Any) -> None:
185
+ super().__init__()
186
+ self.pad_id = out_channels - 1
187
+ self.eos_id = 0
188
+ self.bos_id = out_channels - 2
189
+ self.max_label_length = max_label_length
190
+ self.decode_ar = decode_ar
191
+ self.refine_iters = refine_iters
192
+
193
+ decoder_layer = DecoderLayer(embed_dim, dec_num_heads,
194
+ embed_dim * dec_mlp_ratio, dropout)
195
+ self.decoder = Decoder(decoder_layer,
196
+ num_layers=dec_depth,
197
+ norm=nn.LayerNorm(embed_dim))
198
+
199
+ # Perm/attn mask stuff
200
+ self.rng = np.random.default_rng()
201
+ self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
202
+ self.perm_forward = perm_forward
203
+ self.perm_mirrored = perm_mirrored
204
+
205
+ # We don't predict <bos> nor <pad>
206
+ self.head = nn.Linear(embed_dim, out_channels - 2)
207
+ self.text_embed = TokenEmbedding(out_channels, embed_dim)
208
+
209
+ # +1 for <eos>
210
+ self.pos_queries = nn.Parameter(
211
+ torch.Tensor(1, max_label_length + 1, embed_dim))
212
+ self.dropout = nn.Dropout(p=dropout)
213
+ # Encoder has its own init.
214
+ self.apply(self._init_weights)
215
+ nn.init.trunc_normal_(self.pos_queries, std=0.02)
216
+
217
+ def _init_weights(self, module: nn.Module):
218
+ """Initialize the weights using the typical initialization schemes used
219
+ in SOTA models."""
220
+
221
+ if isinstance(module, nn.Linear):
222
+ nn.init.trunc_normal_(module.weight, std=0.02)
223
+ if module.bias is not None:
224
+ nn.init.zeros_(module.bias)
225
+ elif isinstance(module, nn.Embedding):
226
+ nn.init.trunc_normal_(module.weight, std=0.02)
227
+ if module.padding_idx is not None:
228
+ module.weight.data[module.padding_idx].zero_()
229
+ elif isinstance(module, nn.Conv2d):
230
+ nn.init.kaiming_normal_(module.weight,
231
+ mode='fan_out',
232
+ nonlinearity='relu')
233
+ if module.bias is not None:
234
+ nn.init.zeros_(module.bias)
235
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
236
+ nn.init.ones_(module.weight)
237
+ nn.init.zeros_(module.bias)
238
+
239
+ @torch.jit.ignore
240
+ def no_weight_decay(self):
241
+ param_names = {'text_embed.embedding.weight', 'pos_queries'}
242
+ return param_names
243
+
244
+ def decode(
245
+ self,
246
+ tgt: torch.Tensor,
247
+ memory: torch.Tensor,
248
+ tgt_mask: Optional[Tensor] = None,
249
+ tgt_padding_mask: Optional[Tensor] = None,
250
+ tgt_query: Optional[Tensor] = None,
251
+ tgt_query_mask: Optional[Tensor] = None,
252
+ pos_query: torch.Tensor = None,
253
+ ):
254
+ N, L = tgt.shape
255
+ # <bos> stands for the null context. We only supply position information for characters after <bos>.
256
+ null_ctx = self.text_embed(tgt[:, :1])
257
+
258
+ if tgt_query is None:
259
+ tgt_query = pos_query[:, :L]
260
+ tgt_emb = pos_query[:, :L - 1] + self.text_embed(tgt[:, 1:])
261
+ tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
262
+
263
+ tgt_query = self.dropout(tgt_query)
264
+ return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask,
265
+ tgt_mask, tgt_padding_mask)
266
+
267
+ def forward(self, x, data=None, pos_query=None):
268
+ if self.training:
269
+ return self.training_step([x, pos_query, data[0]])
270
+ else:
271
+ return self.forward_test(x, pos_query)
272
+
273
+ def forward_test(self,
274
+ memory: Tensor,
275
+ pos_query: Tensor = None,
276
+ max_length: Optional[int] = None) -> Tensor:
277
+ _device = memory.get_device()
278
+ testing = max_length is None
279
+ max_length = (self.max_label_length if max_length is None else min(
280
+ max_length, self.max_label_length))
281
+ bs = memory.shape[0]
282
+ # +1 for <eos> at end of sequence.
283
+ num_steps = max_length + 1
284
+ # memory = self.encode(images)
285
+
286
+ # Query positions up to `num_steps`
287
+ if pos_query is None:
288
+ pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
289
+ else:
290
+ pos_queries = pos_query
291
+
292
+ # Special case for the forward permutation. Faster than using `generate_attn_masks()`
293
+ tgt_mask = query_mask = torch.triu(
294
+ torch.full((num_steps, num_steps), float('-inf'), device=_device),
295
+ 1)
296
+ self.attn_maps = []
297
+ if self.decode_ar:
298
+ tgt_in = torch.full((bs, num_steps),
299
+ self.pad_id,
300
+ dtype=torch.long,
301
+ device=_device)
302
+ tgt_in[:, 0] = self.bos_id
303
+
304
+ logits = []
305
+ for i in range(num_steps):
306
+ j = i + 1 # next token index
307
+ # Efficient decoding:
308
+ # Input the context up to the ith token. We use only one query (at position = i) at a time.
309
+ # This works because of the lookahead masking effect of the canonical (forward) AR context.
310
+ # Past tokens have no access to future tokens, hence are fixed once computed.
311
+ tgt_out = self.decode(
312
+ tgt_in[:, :j],
313
+ memory,
314
+ tgt_mask[:j, :j],
315
+ tgt_query=pos_queries[:, i:j],
316
+ tgt_query_mask=query_mask[i:j, :j],
317
+ pos_query=pos_queries,
318
+ )
319
+ self.attn_maps.append(self.decoder.layers[-1].attn_map)
320
+ # the next token probability is in the output's ith token position
321
+ p_i = self.head(tgt_out)
322
+ logits.append(p_i)
323
+ if j < num_steps:
324
+ # greedy decode. add the next token index to the target input
325
+ tgt_in[:, j] = p_i.squeeze().argmax(-1)
326
+ # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
327
+ if testing and (tgt_in == self.eos_id).any(dim=-1).all():
328
+ break
329
+
330
+ logits = torch.cat(logits, dim=1)
331
+ else:
332
+ # No prior context, so input is just <bos>. We query all positions.
333
+ tgt_in = torch.full((bs, 1),
334
+ self.bos_id,
335
+ dtype=torch.long,
336
+ device=_device)
337
+ tgt_out = self.decode(tgt_in,
338
+ memory,
339
+ tgt_query=pos_queries,
340
+ pos_query=pos_queries)
341
+ logits = self.head(tgt_out)
342
+
343
+ if self.refine_iters:
344
+ # For iterative refinement, we always use a 'cloze' mask.
345
+ # We can derive it from the AR forward mask by unmasking the token context to the right.
346
+ query_mask[torch.triu(
347
+ torch.ones(num_steps,
348
+ num_steps,
349
+ dtype=torch.bool,
350
+ device=_device), 2)] = 0
351
+ bos = torch.full((bs, 1),
352
+ self.bos_id,
353
+ dtype=torch.long,
354
+ device=_device)
355
+ for i in range(self.refine_iters):
356
+ # Prior context is the previous output.
357
+ tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
358
+ tgt_padding_mask = (tgt_in == self.eos_id).int().cumsum(
359
+ -1) > 0 # mask tokens beyond the first EOS token.
360
+ tgt_out = self.decode(
361
+ tgt_in,
362
+ memory,
363
+ tgt_mask,
364
+ tgt_padding_mask,
365
+ tgt_query=pos_queries,
366
+ tgt_query_mask=query_mask[:, :tgt_in.shape[1]],
367
+ pos_query=pos_queries,
368
+ )
369
+ logits = self.head(tgt_out)
370
+
371
+ return F.softmax(logits, -1)
372
+
373
+ def gen_tgt_perms(self, tgt, _device):
374
+ """Generate shared permutations for the whole batch.
375
+
376
+ This works because the same attention mask can be used for the shorter
377
+ sequences because of the padding mask.
378
+ """
379
+ # We don't permute the position of BOS, we permute EOS separately
380
+ max_num_chars = tgt.shape[1] - 2
381
+ # Special handling for 1-character sequences
382
+ if max_num_chars == 1:
383
+ return torch.arange(3, device=_device).unsqueeze(0)
384
+ perms = [torch.arange(max_num_chars, device=_device)
385
+ ] if self.perm_forward else []
386
+ # Additional permutations if needed
387
+ max_perms = math.factorial(max_num_chars)
388
+ if self.perm_mirrored:
389
+ max_perms //= 2
390
+ num_gen_perms = min(self.max_gen_perms, max_perms)
391
+ # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions
392
+ # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars.
393
+ if max_num_chars < 5:
394
+ # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
395
+ # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
396
+ if max_num_chars == 4 and self.perm_mirrored:
397
+ selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
398
+ else:
399
+ selector = list(range(max_perms))
400
+ perm_pool = torch.as_tensor(list(
401
+ permutations(range(max_num_chars), max_num_chars)),
402
+ device=_device)[selector]
403
+ # If the forward permutation is always selected, no need to add it to the pool for sampling
404
+ if self.perm_forward:
405
+ perm_pool = perm_pool[1:]
406
+ perms = torch.stack(perms)
407
+ if len(perm_pool):
408
+ i = self.rng.choice(len(perm_pool),
409
+ size=num_gen_perms - len(perms),
410
+ replace=False)
411
+ perms = torch.cat([perms, perm_pool[i]])
412
+ else:
413
+ perms.extend([
414
+ torch.randperm(max_num_chars, device=_device)
415
+ for _ in range(num_gen_perms - len(perms))
416
+ ])
417
+ perms = torch.stack(perms)
418
+ if self.perm_mirrored:
419
+ # Add complementary pairs
420
+ comp = perms.flip(-1)
421
+ # Stack in such a way that the pairs are next to each other.
422
+ perms = torch.stack([perms, comp
423
+ ]).transpose(0, 1).reshape(-1, max_num_chars)
424
+ # NOTE:
425
+ # The only meaningful way of permuting the EOS position is by moving it one character position at a time.
426
+ # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS
427
+ # positions will always be much less than the number of permutations (unless a low perm_num is set).
428
+ # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly
429
+ # distribute it across the chosen number of permutations.
430
+ # Add position indices of BOS and EOS
431
+ bos_idx = perms.new_zeros((len(perms), 1))
432
+ eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1)
433
+ perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1)
434
+ # Special handling for the reverse direction. This does two things:
435
+ # 1. Reverse context for the characters
436
+ # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode)
437
+ if len(perms) > 1:
438
+ perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1,
439
+ device=_device)
440
+ return perms
441
+
442
+ def generate_attn_masks(self, perm, _device):
443
+ """Generate attention masks given a sequence permutation (includes pos.
444
+ for bos and eos tokens)
445
+
446
+ :param perm: the permutation sequence. i = 0 is always the BOS
447
+ :return: lookahead attention masks
448
+ """
449
+ sz = perm.shape[0]
450
+ mask = torch.zeros((sz, sz), device=_device)
451
+ for i in range(sz):
452
+ query_idx = perm[i]
453
+ masked_keys = perm[i + 1:]
454
+ mask[query_idx, masked_keys] = float('-inf')
455
+ content_mask = mask[:-1, :-1].clone()
456
+ mask[torch.eye(sz, dtype=torch.bool,
457
+ device=_device)] = float('-inf') # mask "self"
458
+ query_mask = mask[1:, :-1]
459
+ return content_mask, query_mask
460
+
461
+ def training_step(self, batch):
462
+ memory, pos_query, tgt = batch
463
+ bs = memory.shape[0]
464
+ if pos_query is None:
465
+ pos_query = self.pos_queries.expand(bs, -1, -1)
466
+
467
+ # Prepare the target sequences (input and output)
468
+ tgt_perms = self.gen_tgt_perms(tgt, memory.get_device())
469
+ tgt_in = tgt[:, :-1]
470
+ tgt_out = tgt[:, 1:]
471
+ # The [EOS] token is not depended upon by any other token in any permutation ordering
472
+ tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
473
+
474
+ loss = 0
475
+ loss_numel = 0
476
+ n = (tgt_out != self.pad_id).sum().item()
477
+ for i, perm in enumerate(tgt_perms):
478
+ tgt_mask, query_mask = self.generate_attn_masks(
479
+ perm, memory.get_device())
480
+ out = self.decode(
481
+ tgt_in,
482
+ memory,
483
+ tgt_mask,
484
+ tgt_padding_mask,
485
+ tgt_query_mask=query_mask,
486
+ pos_query=pos_query,
487
+ )
488
+ logits = self.head(out)
489
+ if i == 0:
490
+ final_out = logits
491
+ loss += n * F.cross_entropy(logits.flatten(end_dim=1),
492
+ tgt_out.flatten(),
493
+ ignore_index=self.pad_id)
494
+ loss_numel += n
495
+ # After the second iteration (i.e. done with canonical and reverse orderings),
496
+ # remove the [EOS] tokens for the succeeding perms
497
+ if i == 1:
498
+ tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id,
499
+ tgt_out)
500
+ n = (tgt_out != self.pad_id).sum().item()
501
+ loss /= loss_numel
502
+
503
+ # self.log('loss', loss)
504
+ return [loss, final_out]
@@ -0,0 +1,70 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.init import trunc_normal_
5
+
6
+ from openrec.modeling.common import Block
7
+
8
+
9
+ class RCTCDecoder(nn.Module):
10
+
11
+ def __init__(self,
12
+ in_channels,
13
+ out_channels=6625,
14
+ return_feats=False,
15
+ **kwargs):
16
+ super(RCTCDecoder, self).__init__()
17
+ self.char_token = nn.Parameter(
18
+ torch.zeros([1, 1, in_channels], dtype=torch.float32),
19
+ requires_grad=True,
20
+ )
21
+ trunc_normal_(self.char_token, mean=0, std=0.02)
22
+ self.fc = nn.Linear(
23
+ in_channels,
24
+ out_channels,
25
+ bias=True,
26
+ )
27
+ self.fc_kv = nn.Linear(
28
+ in_channels,
29
+ 2 * in_channels,
30
+ bias=True,
31
+ )
32
+ self.w_atten_block = Block(dim=in_channels,
33
+ num_heads=in_channels // 32,
34
+ mlp_ratio=4.0,
35
+ qkv_bias=False)
36
+ self.out_channels = out_channels
37
+ self.return_feats = return_feats
38
+
39
+ def forward(self, x, data=None):
40
+
41
+ B, C, H, W = x.shape
42
+ x = self.w_atten_block(x.permute(0, 2, 3,
43
+ 1).reshape(-1, W, C)).reshape(
44
+ B, H, W, C).permute(0, 3, 1, 2)
45
+ # B, D, 8, 32
46
+ x_kv = self.fc_kv(x.flatten(2).transpose(1, 2)).reshape(
47
+ B, H * W, 2, C).permute(2, 0, 3, 1) # 2, b, c, hw
48
+ x_k, x_v = x_kv.unbind(0) # b, c, hw
49
+ char_token = self.char_token.tile([B, 1, 1])
50
+ attn_ctc2d = char_token @ x_k # b, 1, hw
51
+ attn_ctc2d = attn_ctc2d.reshape([-1, 1, H, W])
52
+ attn_ctc2d = F.softmax(attn_ctc2d, 2) # b, 1, h, w
53
+ attn_ctc2d = attn_ctc2d.permute(0, 3, 1, 2) # b, w, 1, h
54
+ x_v = x_v.reshape(B, C, H, W)
55
+ # B, W, H, C
56
+ feats = attn_ctc2d @ x_v.permute(0, 3, 2, 1) # b, w, 1, c
57
+ feats = feats.squeeze(2) # b, w, c
58
+
59
+ predicts = self.fc(feats)
60
+
61
+ if self.return_feats:
62
+ result = (feats, predicts)
63
+ else:
64
+ result = predicts
65
+
66
+ if not self.training:
67
+ predicts = F.softmax(predicts, dim=2)
68
+ result = predicts
69
+
70
+ return result