openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,535 @@
1
+ """This code is refer from:
2
+ https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/LISTER
3
+ """
4
+
5
+ # Copyright (2023) Alibaba Group and its affiliates
6
+ # --------------------------------------------------------
7
+ # To decode arbitrary-length text images.
8
+ # --------------------------------------------------------
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.nn.init import trunc_normal_
15
+
16
+ from openrec.modeling.encoders.focalsvtr import FocalNetBlock
17
+
18
+
19
class LocalSelfAttention(nn.Module):
    """Window-limited self-attention over a 1-D character sequence.

    Each position attends only to a window of ``window_size`` neighbouring
    positions, gathered from a zero-padded copy of the input.  Optionally a
    learned positional bias is added to the windowed keys/values, and in
    ``mlm`` mode the centre of every window (the token itself) is masked out.
    """

    def __init__(self,
                 feat_dim,
                 nhead,
                 window_size: int,
                 add_pos_bias=False,
                 qkv_drop=0.0,
                 proj_drop=0.0,
                 mlm=False):
        super().__init__()
        assert feat_dim % nhead == 0
        self.q_fc = nn.Linear(feat_dim, feat_dim)
        self.kv_fc = nn.Linear(feat_dim, feat_dim * 2)

        self.nhead = nhead
        self.head_dim = feat_dim // nhead
        self.window_size = window_size
        # One learned bias row per window offset, shared by keys and values.
        if add_pos_bias:
            self.kv_pos_bias = nn.Parameter(torch.zeros(window_size, feat_dim))
            trunc_normal_(self.kv_pos_bias, std=.02)
        else:
            self.kv_pos_bias = None
        self.qkv_dropout = nn.Dropout(qkv_drop)

        self.proj = nn.Linear(feat_dim, feat_dim)
        self.proj_dropout = nn.Dropout(proj_drop)
        self.mlm = mlm
        if mlm:
            print('Use mlm.')

    def _gen_t_index(self, real_len, device):
        """Build gather indices of shape [T, w]; row t holds t .. t+w-1."""
        base = torch.arange(real_len, dtype=torch.long, device=device)
        rows = [base + offset for offset in range(self.window_size)]
        return torch.stack(rows).t()  # [T, w]

    def _apply_attn_mask(self, attn_score):
        """MLM mode: forbid each token from attending to itself."""
        centre = self.window_size // 2
        attn_score[:, :, :, :, centre] = float('-inf')
        return attn_score

    def forward(self, x, mask):
        """
        Args:
            x: [b, T, C] input features.
            mask: [b, T], 1 for valid positions, 0 for padding.
        Returns:
            [b, T, C] attended features, re-masked with ``mask``.
        """
        b, T, C = x.size()
        # Zero out padded positions before anything else.
        x = x * mask.unsqueeze(-1)

        q = self.q_fc(self.qkv_dropout(x))  # [b, T, C]

        # Gather a length-w window around every position from a padded copy.
        half = self.window_size // 2
        padded = F.pad(x, (0, 0, half, half))  # [b, T+w-1, C]
        batch_idx = torch.arange(b, dtype=torch.long,
                                 device=x.device).contiguous().view(b, 1, 1)
        time_idx = self._gen_t_index(T, x.device).unsqueeze(0)
        windows = padded[batch_idx, time_idx]  # [b, T, w, C]

        if self.kv_pos_bias is not None:
            windows = self.qkv_dropout(
                windows + self.kv_pos_bias.unsqueeze(0).unsqueeze(1))
        else:
            windows = self.qkv_dropout(windows)
        k, v = self.kv_fc(windows).chunk(2, -1)  # both [b, T, w, C]

        # Split heads.
        q = q.contiguous().view(b, T, self.nhead, -1)  # [b, T, h, C/h]
        k = k.contiguous().view(b, T, self.window_size, self.nhead,
                                -1).transpose(2, 3)  # [b, T, h, w, C/h]
        v = v.contiguous().view(b, T, self.window_size, self.nhead,
                                -1).transpose(2, 3)

        # Attention scores; the log(w)/d scaling of qk refers to:
        # https://kexue.fm/archives/8823
        scores = q.unsqueeze(3).matmul(
            k.transpose(-1, -2) / self.head_dim *
            math.log(self.window_size))  # [b, T, h, 1, w]
        if self.mlm:
            scores = self._apply_attn_mask(scores)
        weights = scores.softmax(-1)
        out = weights.matmul(v).squeeze(-2).contiguous().view(b, T,
                                                              -1)  # [b, T, C]

        out = self.proj_dropout(self.proj(out))
        return out * mask.unsqueeze(-1)
105
+
106
+
107
class LocalAttentionBlock(nn.Module):
    """Pre-norm transformer block whose attention is LocalSelfAttention.

    When ``init_values > 0`` the residual branches are gated by learnable
    per-channel LayerScale parameters ``gamma_1``/``gamma_2``; otherwise
    plain (unit-gain) residuals are used.
    """

    def __init__(self,
                 feat_dim,
                 nhead,
                 window_size,
                 add_pos_bias: bool,
                 drop=0.0,
                 proj_drop=0.0,
                 init_values=1e-6,
                 mlm=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(feat_dim)
        self.sa = LocalSelfAttention(feat_dim,
                                     nhead,
                                     window_size,
                                     add_pos_bias,
                                     drop,
                                     proj_drop,
                                     mlm=mlm)
        self.norm2 = nn.LayerNorm(feat_dim)
        hidden_dim = feat_dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, feat_dim),
            nn.Dropout(drop),
        )
        if init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones(feat_dim),
                                        requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones(feat_dim),
                                        requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = 1.0, 1.0

    def forward(self, x, mask):
        """Attention then MLP, each with a (scaled) residual; output is
        re-masked so padded positions stay zero."""
        attn_out = self.sa(self.norm1(x), mask)
        x = x + self.gamma_1 * attn_out
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.gamma_2 * mlp_out
        return x * mask.unsqueeze(-1)
148
+
149
+
150
class LocalAttentionModule(nn.Module):
    """A stack of LocalAttentionBlock layers applied under a validity mask.

    Only the first block receives the learned key/value positional bias.
    With ``detach_grad`` set, the input is detached so gradients do not flow
    back into whatever produced ``x``.
    """

    def __init__(self,
                 feat_dim,
                 nhead,
                 window_size,
                 num_layers,
                 drop_rate=0.0,
                 proj_drop_rate=0.0,
                 detach_grad=False,
                 mlm=False):
        super().__init__()
        blocks = [
            LocalAttentionBlock(
                feat_dim,
                nhead,
                window_size,
                add_pos_bias=(i == 0),
                drop=drop_rate,
                proj_drop=proj_drop_rate,
                mlm=mlm,
            ) for i in range(num_layers)
        ]
        self.attn_blocks = nn.ModuleList(blocks)
        self.detach_grad = detach_grad

    def forward(self, x, mask):
        """Run every block in sequence; ``mask`` is [b, T] validity."""
        if self.detach_grad:
            x = x.detach()
        for block in self.attn_blocks:
            x = block(x, mask)
        return x
182
+
183
+
184
def softmax_m1(x: torch.Tensor, dim: int):
    """Softmax variant built from ``exp(x) - 1`` instead of ``exp(x)``.

    For non-negative inputs this can suppress entries at zero completely
    (``exp(0) - 1 == 0``), which a regular softmax cannot.  Intended for
    x >= 0; callers must guarantee at least one strictly positive entry
    along ``dim``, otherwise the normalizer is zero and the result is NaN.

    Args:
        x: input tensor, expected element-wise >= 0.
        dim: dimension along which to normalize.

    Returns:
        Tensor of the same shape as ``x``, summing to 1 along ``dim``.
    """
    # torch.expm1 computes exp(x) - 1 without the catastrophic cancellation
    # that ``x.exp() - 1`` suffers for x near 0.
    fx = torch.expm1(x)
    return fx / fx.sum(dim, keepdim=True)
189
+
190
+
191
+ class BilinearLayer(nn.Module):
192
+
193
+ def __init__(self, in1, in2, out, bias=True):
194
+ super(BilinearLayer, self).__init__()
195
+ self.weight = nn.Parameter(torch.randn(out, in1, in2))
196
+ if bias:
197
+ self.bias = nn.Parameter(torch.zeros(out))
198
+ else:
199
+ self.bias = None
200
+ torch.nn.init.xavier_normal_(self.weight, 0.1)
201
+
202
+ def forward(self, x1, x2):
203
+ '''
204
+ input:
205
+ x1: [b, T1, in1]
206
+ x2: [b, T2, in2]
207
+ output:
208
+ y: [b, T1, T2, out]
209
+ '''
210
+ y = torch.einsum('bim,omn->bino', x1, self.weight) # [b, T1, in2, out]
211
+ y = torch.einsum('bino,bjn->bijo', y, x2) # [b, T1, T2, out]
212
+ if self.bias is not None:
213
+ y = y + self.bias.contiguous().view(1, 1, 1, -1)
214
+ return y
215
+
216
+
217
class NeighborDecoder(nn.Module):
    """Find neighbors for each character In this version, each iteration shares
    the same decoder with the local vision decoder."""

    def __init__(self,
                 num_classes,
                 feat_dim,
                 max_len=1000,
                 detach_grad=False,
                 **kwargs):
        super().__init__()
        # Learned embedding appended to the visual token sequence to act as
        # the EOS slot.
        self.eos_emb = nn.Parameter(torch.ones(feat_dim))
        trunc_normal_(self.eos_emb, std=.02)
        self.q_fc = nn.Linear(feat_dim, feat_dim, bias=True)
        self.k_fc = nn.Linear(feat_dim, feat_dim)

        # Scores, for every visual token i, how likely token j is its
        # "neighbor" (the next character position).
        self.neighbor_navigator = BilinearLayer(feat_dim, feat_dim, 1)

        # Per-character classifier over the visual vocabulary.
        self.vis_cls = nn.Linear(feat_dim, num_classes)

        # Probability threshold on the EOS slot used to stop decoding.
        self.p_threshold = 0.6
        self.max_len = max_len or 1000  # to avoid endless loop
        self.max_ch = max_len or 1000

        self.detach_grad = detach_grad
        # NOTE(review): 'attn_scaling' is a required kwarg here — raises
        # KeyError if the caller omits it; confirm all configs supply it.
        self.attn_scaling = kwargs['attn_scaling']

    def align_chars(self, start_map, nb_map, max_ch=None):
        """Roll the neighbor (transition) matrix out from the start map to
        produce one attention map per character position.

        Args:
            start_map: [b, N+1] attention over visual tokens (+EOS) for the
                first character.
            nb_map: [b, N+1, N+1] row-stochastic neighbor matrix.
            max_ch: number of rollout steps; mandatory while training.

        Returns:
            char_maps: [b, L, N+1] per-character attention maps.
            char_masks: [b, L] 1.0 while still decoding, 0.0 after EOS fired.
        """
        if self.training:
            assert max_ch is not None
        max_ch = max_ch or self.max_ch  # required during training to be efficient
        b, N = nb_map.shape[:2]

        char_map = start_map  # [b, N]
        # Per-sample counter; becomes non-zero once the EOS probability has
        # exceeded the threshold for that sample.
        all_finished = torch.zeros(b, dtype=torch.long, device=nb_map.device)
        char_maps = []
        char_masks = []
        for i in range(max_ch):
            char_maps.append(char_map)
            char_mask = (all_finished == 0).float()
            char_masks.append(char_mask)
            if i == max_ch - 1:
                break
            all_finished = all_finished + (char_map[:, -1] >
                                           self.p_threshold).long()
            if not self.training:
                # check if end: stop early once every sample has finished
                if (all_finished > 0).sum().item() == b:
                    break
            if self.training:
                # One soft transition step through the neighbor matrix.
                char_map = char_map.unsqueeze(1).matmul(nb_map).squeeze(1)
            else:
                # char_map_dt = (char_map.detach() * 50).softmax(-1)
                # At inference, progressively sharpen the map before the
                # transition so the rollout does not blur over steps.
                k = min(1 + i * 2, 16)
                char_map_dt = softmax_m1(char_map.detach() * k, dim=-1)
                char_map = char_map_dt.unsqueeze(1).matmul(nb_map).squeeze(1)

        char_maps = torch.stack(char_maps, dim=1)  # [b, L, N], L = n_char + 1
        char_masks = torch.stack(char_masks, dim=1)  # [b, L], 0 denotes masked
        return char_maps, char_masks

    def forward(self, x: torch.FloatTensor, max_char: int = None):
        """Decode a character sequence from a feature map.

        Args:
            x: [b, c, h, w] visual feature map.
            max_char: maximum characters to decode (required when training).

        Returns:
            dict with logits [b, L, nC], char_feats, char_maps, char_masks,
            the feature-map height h, and the neighbor matrix nb_map.
        """
        b, c, h, w = x.size()
        x = x.flatten(2).transpose(1, 2)  # [b, N, c], N = h x w
        g = x.mean(1)  # global representation, [b, c]

        # append eos emb to x
        x_ext = torch.cat(
            [x, self.eos_emb.unsqueeze(0).expand(b, -1).unsqueeze(1)],
            dim=1)  # [b, N+1, c]

        # locate the first character feature
        q_start = self.q_fc(g)  # [b, c]
        k_feat = self.k_fc(x_ext)  # [b, N+1, c]
        start_map = k_feat.matmul(q_start.unsqueeze(-1)).squeeze(
            -1)  # [b, N+1]
        # scaling, referring to: https://kexue.fm/archives/8823
        if self.attn_scaling:
            start_map = start_map / (c**0.5)
        start_map = start_map.softmax(1)

        # Neighbor discovering
        q_feat = self.q_fc(x)
        nb_map = self.neighbor_navigator(q_feat,
                                         k_feat).squeeze(-1)  # [b, N, N+1]
        if self.attn_scaling:
            nb_map = nb_map / (c**0.5)
        nb_map = nb_map.softmax(2)
        # The EOS row maps to itself so the rollout is absorbing at EOS.
        last_neighbor = torch.zeros(h * w + 1, device=x.device)
        last_neighbor[-1] = 1.0
        nb_map = torch.cat(
            [
                nb_map,
                last_neighbor.contiguous().view(1, 1, -1).expand(b, -1, -1)
            ],
            dim=1)  # to complete the neighbor matrix, (N+1) x (N+1)

        # string (feature) decoding
        char_maps, char_masks = self.align_chars(start_map, nb_map, max_char)
        char_feats = char_maps.matmul(x_ext)  # [b, L, c]
        char_feats = char_feats * char_masks.unsqueeze(-1)
        logits = self.vis_cls(char_feats)  # [b, L, nC]

        results = dict(
            logits=logits,
            char_feats=char_feats,
            char_maps=char_maps,
            char_masks=char_masks,
            h=h,
            nb_map=nb_map,
        )
        return results
329
+
330
+
331
class FeatureMapEnhancer(nn.Module):
    """ Merge the global and local features

    Re-injects the decoded character features back into the visual feature
    map (via the character attention maps), then refines the merged map with
    a small stack of FocalNetBlock layers.
    """

    def __init__(self,
                 feat_dim,
                 num_layers=1,
                 focal_level=3,
                 max_kh=1,
                 layerscale_value=1e-6,
                 drop_rate=0.0):
        # feat_dim: channel width shared by the map and character features.
        # num_layers: number of FocalNetBlock merge layers.
        super().__init__()
        self.norm1 = nn.LayerNorm(feat_dim)
        self.merge_layer = nn.ModuleList([
            FocalNetBlock(
                dim=feat_dim,
                mlp_ratio=4,
                drop=drop_rate,
                focal_level=focal_level,
                max_kh=max_kh,
                focal_window=3,
                use_layerscale=True,
                layerscale_value=layerscale_value,
            ) for i in range(num_layers)
        ])
        # self.scale = 1. / (feat_dim ** 0.5)
        self.norm2 = nn.LayerNorm(feat_dim)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, feat_map, feat_char, char_attn_map):
        """
        feat_map: [b, C, h, w] visual feature map (unpacked as b, C, h, w below)
        feat_char: [b, T, C], T include the EOS token
        char_attn_map: [b, T, N], N exclude the EOS token

        return: [b, C, h, w]
        """
        b, C, h, w = feat_map.size()
        feat_map = feat_map.flatten(2).transpose(1, 2)
        # 1. restore the char feats into the visual map
        # char_feat_map = char_attn_map.transpose(1, 2).matmul(feat_char * self.scale) # [b, N, C]
        char_feat_map = char_attn_map.transpose(1, 2).matmul(
            feat_char)  # [b, N, C]
        char_feat_map = self.norm1(char_feat_map)
        feat_map = feat_map + char_feat_map

        # 2. merge
        # vis_mask = vis_mask.contiguous().view(b, h, -1) # [b, h, w]
        for blk in self.merge_layer:
            # FocalNetBlock reads its spatial extent from .H/.W attributes.
            blk.H, blk.W = h, w
            feat_map = blk(feat_map)
        feat_map = self.dropout(self.norm2(feat_map))
        feat_map = feat_map.transpose(1, 2).reshape(b, C, h, w)  # [b, C, h, w]
        # feat_map = feat_map * vis_mask.unsqueeze(1)
        return feat_map
388
+
389
+
390
class LISTERDecoder(nn.Module):
    """LISTER decoder: a neighbor decoder plus an optional feature
    enhancement module (FEM) that iteratively refines the feature map.

    Args:
        in_channels: channel dim of the input feature map.
        out_channels: number of character classes + 1; the last index is
            used as the CE ignore_index.
        max_len: maximum decoding length (excluding EOS).
        use_fem: whether to apply iterative feature enhancement.
        detach_grad: detach character maps before the merge step.
        nhead / window_size / num_sa_layers: local attention settings.
        iters: number of refinement iterations when ``use_fem`` is on
            (forced to at least 1).
        num_mg_layers: number of merge layers in the FeatureMapEnhancer.
        coef: loss weights for (rec, eos, entropy) respectively.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 max_len=25,
                 use_fem=True,
                 detach_grad=False,
                 nhead=8,
                 window_size=11,
                 iters=2,
                 num_sa_layers=1,
                 num_mg_layers=1,
                 coef=(1.0, 0.01, 0.001),
                 **kwargs):
        super().__init__()
        num_classes = out_channels - 1
        self.ignore_index = num_classes
        self.max_len = max_len
        self.use_fem = use_fem
        self.detach_grad = detach_grad
        # At least one refinement iteration whenever FEM is enabled.
        self.iters = max(1, iters) if use_fem else 0
        feat_dim = in_channels
        self.decoder = NeighborDecoder(num_classes,
                                       feat_dim,
                                       max_len=max_len,
                                       detach_grad=detach_grad,
                                       **kwargs)
        # BUGFIX: build the FEM modules whenever use_fem is set. The old
        # condition (`iters > 0 and use_fem`) skipped them for iters <= 0,
        # yet self.iters was forced to >= 1 above, so forward() would then
        # crash with AttributeError on self.cntx_module.
        if self.use_fem:
            self.cntx_module = LocalAttentionModule(feat_dim,
                                                    nhead,
                                                    window_size,
                                                    num_sa_layers,
                                                    drop_rate=0.1,
                                                    proj_drop_rate=0.1,
                                                    detach_grad=detach_grad,
                                                    mlm=kwargs.get(
                                                        'mlm', False))
            self.merge_layer = FeatureMapEnhancer(feat_dim,
                                                  num_layers=num_mg_layers)
        self.celoss_fn = nn.CrossEntropyLoss(reduction='mean',
                                             ignore_index=self.ignore_index)
        self.coef = coef  # for loss of rec, eos and ent respectively
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for conv/linear weights; zero the bias
        # only when the layer actually has one (replaces a bare except).
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, data=None):
        """Decode x ([b, C, h, w]); data is (labels, label_lens) in training.

        Returns [loss_dict, last_stage_result]; loss_dict is None in eval.
        """
        if data is not None:
            labels, label_lens = data
            label_lens = label_lens + 1  # account for the EOS token
            max_char = label_lens.max()
        else:
            max_char = self.max_len

        res_vis = self.decoder(x, max_char=max_char)
        res_list = [res_vis]
        if self.use_fem:
            for _ in range(self.iters):
                char_feat_cntx = self.cntx_module(res_list[-1]['char_feats'],
                                                  res_list[-1]['char_masks'])
                char_maps = res_list[-1]['char_maps']
                if self.detach_grad:
                    char_maps = char_maps.detach()
                # Drop the EOS column before merging back into the map.
                feat_map = self.merge_layer(
                    x,
                    char_feat_cntx,
                    char_maps[:, :, :-1],
                )
                res_list.append(self.decoder(feat_map, max_char))
        if self.training:
            # Sum the per-stage losses over the initial pass + refinements.
            loss_dict = self.get_loss(res_list[0], labels, label_lens)
            for it in range(self.iters):
                loss_dict_i = self.get_loss(res_list[it + 1], labels,
                                            label_lens)
                for k, v in loss_dict_i.items():
                    loss_dict[k] += v
        else:
            loss_dict = None
        return [loss_dict, res_list[-1]]

    def calc_rec_loss(self, logits, targets):
        """Cross-entropy recognition loss.

        Args:
            logits: [minibatch, C, T], raw scores (not softmaxed).
            targets: LongTensor [minibatch, T].
        """
        return self.celoss_fn(logits, targets)

    def calc_eos_loc_loss(self, char_maps, target_lens, eps=1e-10):
        """NLL of the EOS attention mass landing on the last position."""
        max_tok = char_maps.shape[2]
        # Index of the EOS row per sample, expanded for gather.
        eos_idx = (target_lens - 1).contiguous().view(-1, 1, 1).expand(
            -1, 1, max_tok)
        eos_maps = torch.gather(char_maps, dim=1,
                                index=eos_idx).squeeze(1)  # (b, max_tok)
        loss = (eos_maps[:, -1] + eps).log().neg()
        return loss.mean()

    def calc_entropy(self, p: torch.Tensor, mask: torch.Tensor, eps=1e-10):
        """Mean normalized entropy of the distributions at valid positions.

        Args:
            p: probability distribution over the last dim, (..., L, C).
            mask: (..., L) float mask of valid positions.
        """
        p_nlog = (p + eps).log().neg()
        ent = p * p_nlog
        # Normalize by log(C + 1) so the entropy lies roughly in [0, 1].
        ent = ent.sum(-1) / math.log(p.size(-1) + 1)
        ent = (ent * mask).sum(-1) / (mask.sum(-1) + eps)  # (...)
        return ent.mean()

    def get_loss(self, model_output, labels, label_lens):
        """Combine rec / eos / entropy losses using self.coef weights."""
        labels = labels[:, :label_lens.max()]
        batch_size, max_len = labels.size()
        seq_range = torch.arange(
            0, max_len, device=labels.device).long().unsqueeze(0).expand(
                batch_size, max_len)
        seq_len = label_lens.unsqueeze(1).expand_as(seq_range)
        mask = (seq_range < seq_len).float()  # [batch_size, max_len]

        l_rec = self.calc_rec_loss(model_output['logits'].transpose(1, 2),
                                   labels)
        l_eos = self.calc_eos_loc_loss(model_output['char_maps'], label_lens)
        l_ent = self.calc_entropy(model_output['char_maps'], mask)

        loss = (l_rec * self.coef[0] + l_eos * self.coef[1] +
                l_ent * self.coef[2])
        return dict(
            loss=loss,
            l_rec=l_rec,
            l_eos=l_eos,
            l_ent=l_ent,
        )
@@ -0,0 +1,119 @@
1
+ import copy
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from .abinet_decoder import PositionAttention
7
+ from .nrtr_decoder import PositionalEncoding, TransformerBlock
8
+
9
+
10
class Trans(nn.Module):
    """A stack of self-attention transformer blocks over a 2-D feature
    map, optionally restricted by a location-derived attention mask."""

    def __init__(self, dim, nhead, dim_feedforward, dropout, num_layers):
        super().__init__()
        self.d_model = dim
        self.nhead = nhead

        self.pos_encoder = PositionalEncoding(dropout=0.0,
                                              dim=self.d_model,
                                              max_len=512)

        blocks = [
            TransformerBlock(
                dim,
                nhead,
                dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ]
        self.transformer = nn.ModuleList(blocks)

    def forward(self, feature, attn_map=None, use_mask=False):
        n, c, h, w = feature.shape
        # (n, c, h, w) -> (n, h*w, c) token sequence.
        tokens = feature.flatten(2).transpose(1, 2)

        location_mask = None
        if use_mask:
            _, t, h, w = attn_map.shape
            # Positions claimed by the same character (attention > 0.05)
            # may attend to each other; all other pairs get -inf.
            occupancy = (attn_map.view(n, t, -1).transpose(1, 2) >
                         0.05).type(torch.float)  # n,hw,t
            pair = occupancy.bmm(occupancy.transpose(1, 2))  # n, hw, hw
            location_mask = pair.new_zeros(
                (h * w, h * w)).masked_fill(pair > 0, float('-inf'))
            location_mask = location_mask.unsqueeze(1)  # n, 1, hw, hw

        tokens = self.pos_encoder(tokens)
        for blk in self.transformer:
            tokens = blk(tokens, self_mask=location_mask)
        out = tokens.transpose(1, 2).view(n, c, h, w)
        return out, location_mask
54
+
55
+
56
+ def _get_clones(module, N):
57
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
58
+
59
+
60
class LPVDecoder(nn.Module):
    """LPV decoder: repeated (transformer refine -> position attention ->
    classify) stages. Training returns the per-stage logits list; eval
    returns the softmax of the final stage only."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layer=3,
                 max_len=25,
                 use_mask=False,
                 dim_feedforward=1024,
                 nhead=8,
                 dropout=0.1,
                 trans_layer=2):
        super().__init__()
        self.use_mask = use_mask
        self.max_len = max_len
        # Prototype modules, deep-copied once per decoding stage.
        attn_proto = PositionAttention(max_length=max_len + 1,
                                       mode='nearest',
                                       in_channels=in_channels,
                                       num_channels=in_channels // 8)
        trans_proto = Trans(dim=in_channels,
                            nhead=nhead,
                            dim_feedforward=dim_feedforward,
                            dropout=dropout,
                            num_layers=trans_layer)
        cls_proto = nn.Linear(in_channels, out_channels - 2)

        self.attention = _get_clones(attn_proto, num_layer)
        self.trans = _get_clones(trans_proto, num_layer - 1)
        self.cls = _get_clones(cls_proto, num_layer)

    def _refine(self, i, features, attn_vecs, attn_scores_map):
        # One refinement stage: transformer pass, then position attention.
        features, _ = self.trans[i - 1](features,
                                        attn_scores_map,
                                        use_mask=self.use_mask)
        attn_vecs, attn_scores_map = self.attention[i](
            features, attn_vecs)  # (N, T, E), (N, T, H, W)
        return features, attn_vecs, attn_scores_map

    def forward(self, x, data=None):
        # In training, cap the sequence length by the batch's label lengths.
        max_len = self.max_len if data is None else data[1].max()
        features = x  # (N, E, H, W)

        attn_vecs, attn_scores_map = self.attention[0](features)
        attn_vecs = attn_vecs[:, :max_len + 1, :]

        if self.training:
            logits = [self.cls[0](attn_vecs)]  # (N, T, C)
            for i in range(1, len(self.attention)):
                features, attn_vecs, attn_scores_map = self._refine(
                    i, features, attn_vecs, attn_scores_map)
                logits.append(self.cls[i](attn_vecs))  # (N, T, C)
            return logits

        for i in range(1, len(self.attention)):
            features, attn_vecs, attn_scores_map = self._refine(
                i, features, attn_vecs, attn_scores_map)
        return F.softmax(self.cls[-1](attn_vecs), -1)