openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,238 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class GELU(nn.Module):
    """GELU activation wrapping ``torch.nn.functional.gelu``.

    The ``inplace`` flag is stored only for signature compatibility with
    the other activations in this file; ``F.gelu`` has no in-place
    variant, so the flag is never consulted in ``forward``.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return torch.nn.functional.gelu(x)
13
+
14
+
15
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``.

    With ``inplace=True`` the input tensor is overwritten via ``mul_``,
    saving an allocation; use the out-of-place path when the input must
    stay unmodified (e.g. it is needed elsewhere in the graph).
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        if not self.inplace:
            return x * torch.sigmoid(x)
        # sigmoid(x) is evaluated before the in-place multiply, so the
        # result equals the out-of-place branch.
        x.mul_(torch.sigmoid(x))
        return x
27
+
28
+
29
class Activation(nn.Module):
    """Factory module that wraps an activation selected by name.

    Args:
        act_type (str): case-insensitive activation name; one of
            'relu', 'relu6', 'sigmoid', 'hard_sigmoid', 'hard_swish',
            'leakyrelu', 'gelu', 'swish'.
        inplace (bool): forwarded to activations that support in-place
            computation; ignored by the rest.

    Raises:
        NotImplementedError: if ``act_type`` is not a supported name.
    """

    def __init__(self, act_type, inplace=True):
        super(Activation, self).__init__()
        act_type = act_type.lower()
        # Lazy constructors: nothing is instantiated (and the sibling
        # GELU/Swish classes are not touched) until a name matches.
        factories = {
            'relu': lambda: nn.ReLU(inplace=inplace),
            'relu6': lambda: nn.ReLU6(inplace=inplace),
            'sigmoid': nn.Sigmoid,
            'hard_sigmoid': lambda: nn.Hardsigmoid(inplace),
            'hard_swish': lambda: nn.Hardswish(inplace=inplace),
            'leakyrelu': lambda: nn.LeakyReLU(inplace=inplace),
            'gelu': lambda: GELU(inplace=inplace),
            'swish': lambda: Swish(inplace=inplace),
        }
        if act_type not in factories:
            # Same exception type as before, now with a useful message.
            raise NotImplementedError(
                f'activation {act_type!r} is not supported; '
                f'choose one of {sorted(factories)}')
        self.act = factories[act_type]()

    def forward(self, inputs):
        return self.act(inputs)
55
+
56
+
57
def drop_path(x,
              drop_prob: float = 0.0,
              training: bool = False,
              scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample, for use in the main path
    of residual blocks.

    Same as the DropConnect implementation for EfficientNet-style
    networks; the name 'drop path' is used because 'DropConnect' refers
    to a different form of dropout in a separate paper (see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956).
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over every remaining dim,
    # so the mask works for tensors of any rank, not just 2D ConvNets.
    mask_shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask


class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path`; active only in training
    mode (it follows ``self.training``)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
95
+
96
+
97
class Identity(nn.Module):
    """No-op module: returns its input unchanged (drop-in placeholder
    where an ``nn.Module`` is required)."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input
104
+
105
+
106
class Mlp(nn.Module):
    """Transformer feed-forward block: linear -> activation -> linear,
    with the same dropout applied after the activation and after the
    second projection.

    ``hidden_features`` and ``out_features`` default to ``in_features``
    when omitted (or falsy).
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
131
+
132
+
133
class Attention(nn.Module):
    """Standard multi-head self-attention over a (B, N, C) sequence.

    A fused linear layer produces Q, K and V; attention weights are
    softmax(Q K^T * scale) with optional dropout, applied to V, then
    projected back to the input dimension.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.0,
                 proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE: qk_scale can override the default 1/sqrt(head_dim)
        # scaling, kept for compatibility with previously trained weights.
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, seq_len, chans = x.shape
        head_dim = chans // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(batch, seq_len, chans)
        return self.proj_drop(self.proj(out))
168
+
169
+
170
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Applies multi-head self-attention and an MLP, each preceded by layer
    normalization and wrapped in a residual connection with optional
    stochastic depth on the residual branch.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim,
                              num_heads=num_heads,
                              qkv_bias=qkv_bias,
                              qk_scale=qk_scale,
                              attn_drop=attn_drop,
                              proj_drop=drop)
        # Stochastic depth on both residual branches; identity when the
        # rate is zero.
        self.drop_path = (DropPath(drop_path)
                          if drop_path > 0.0 else nn.Identity())
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
209
+
210
+
211
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    Splits an image into non-overlapping patches with a strided
    convolution and flattens them into a (B, num_patches, embed_dim)
    sequence.

    Args:
        img_size: expected input (height, width); defaults to [32, 128].
        patch_size: patch (height, width); defaults to [4, 4].
        in_chans: number of input channels.
        embed_dim: embedding dimension per patch.
    """

    def __init__(self,
                 img_size=None,
                 patch_size=None,
                 in_chans=3,
                 embed_dim=768):
        super().__init__()
        # None-sentinel instead of mutable list defaults (shared across
        # calls); the materialized values match the original defaults.
        img_size = [32, 128] if img_size is None else img_size
        patch_size = [4, 4] if patch_size is None else patch_size
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] //
                                                        patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        """Project (B, C, H, W) input to (B, num_patches, embed_dim)."""
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
@@ -0,0 +1,109 @@
1
+ import torch.nn as nn
2
+
3
+ __all__ = ['build_decoder']
4
+
5
+
6
def build_decoder(config):
    """Instantiate a recognition decoder from a config dict.

    ``config`` must contain a ``name`` key (popped here) naming one of
    the supported decoder classes; the remaining entries are forwarded
    to that class's constructor as keyword arguments.

    Raises:
        ValueError: if ``name`` is not a supported decoder.
    """
    # rec decoder (imported lazily to keep module import cheap)
    from .abinet_decoder import ABINetDecoder
    from .aster_decoder import ASTERDecoder
    from .cdistnet_decoder import CDistNetDecoder
    from .cppd_decoder import CPPDDecoder
    from .rctc_decoder import RCTCDecoder
    from .ctc_decoder import CTCDecoder
    from .dan_decoder import DANDecoder
    from .igtr_decoder import IGTRDecoder
    from .lister_decoder import LISTERDecoder
    from .lpv_decoder import LPVDecoder
    from .mgp_decoder import MGPDecoder
    from .nrtr_decoder import NRTRDecoder
    from .parseq_decoder import PARSeqDecoder
    from .robustscanner_decoder import RobustScannerDecoder
    from .sar_decoder import SARDecoder
    from .smtr_decoder import SMTRDecoder
    from .smtr_decoder_nattn import SMTRDecoderNumAttn
    from .srn_decoder import SRNDecoder
    from .visionlan_decoder import VisionLANDecoder
    from .matrn_decoder import MATRNDecoder
    from .cam_decoder import CAMDecoder
    from .ote_decoder import OTEDecoder
    from .bus_decoder import BUSDecoder

    # Explicit name -> class mapping; avoids eval() on config-supplied
    # strings and survives running under `python -O` (the original
    # `assert` guard would have been stripped).
    support_dict = {
        'CTCDecoder': CTCDecoder,
        'NRTRDecoder': NRTRDecoder,
        'CPPDDecoder': CPPDDecoder,
        'ABINetDecoder': ABINetDecoder,
        'CDistNetDecoder': CDistNetDecoder,
        'VisionLANDecoder': VisionLANDecoder,
        'PARSeqDecoder': PARSeqDecoder,
        'IGTRDecoder': IGTRDecoder,
        'SMTRDecoder': SMTRDecoder,
        'LPVDecoder': LPVDecoder,
        'SARDecoder': SARDecoder,
        'RobustScannerDecoder': RobustScannerDecoder,
        'SRNDecoder': SRNDecoder,
        'ASTERDecoder': ASTERDecoder,
        'RCTCDecoder': RCTCDecoder,
        'LISTERDecoder': LISTERDecoder,
        'GTCDecoder': GTCDecoder,
        'SMTRDecoderNumAttn': SMTRDecoderNumAttn,
        'MATRNDecoder': MATRNDecoder,
        'MGPDecoder': MGPDecoder,
        'DANDecoder': DANDecoder,
        'CAMDecoder': CAMDecoder,
        'OTEDecoder': OTEDecoder,
        'BUSDecoder': BUSDecoder,
    }

    module_name = config.pop('name')
    if module_name not in support_dict:
        raise ValueError(
            'decoder only support {}'.format(list(support_dict)))
    return support_dict[module_name](**config)
46
+
47
+
48
class GTCDecoder(nn.Module):
    """Paired decoder head: a CTC decoder plus a guiding "GTC" decoder.

    When ``infer_gtc`` is true, ``out_channels`` is a pair giving the
    (gtc, ctc) output sizes and both sub-decoders are built; otherwise
    only the CTC decoder is constructed. At inference (eval mode,
    ``infer_gtc`` false) only the CTC prediction is returned.
    """

    def __init__(self,
                 in_channels,
                 gtc_decoder,
                 ctc_decoder,
                 detach=True,
                 infer_gtc=False,
                 out_channels=0,
                 **kwargs):
        super().__init__()
        self.detach = detach
        self.infer_gtc = infer_gtc
        if infer_gtc:
            # out_channels holds (gtc charset size, ctc charset size).
            gtc_decoder['out_channels'] = out_channels[0]
            ctc_decoder['out_channels'] = out_channels[1]
            gtc_decoder['in_channels'] = in_channels
            ctc_decoder['in_channels'] = in_channels
            self.gtc_decoder = build_decoder(gtc_decoder)
        else:
            ctc_decoder['in_channels'] = in_channels
            ctc_decoder['out_channels'] = out_channels
        self.ctc_decoder = build_decoder(ctc_decoder)

    def forward(self, x, data=None):
        # Optionally stop CTC-head gradients from reaching the encoder.
        ctc_in = x.detach() if self.detach else x
        ctc_pred = self.ctc_decoder(ctc_in, data=data)
        if not (self.training or self.infer_gtc):
            return ctc_pred
        # NOTE(review): self.gtc_decoder is only created when
        # infer_gtc=True, yet this path also runs during training —
        # confirm that training configs always set infer_gtc.
        gtc_pred = self.gtc_decoder(x.flatten(2).transpose(1, 2),
                                    data=data)
        return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}
81
+
82
+
83
class GTCDecoderTwo(nn.Module):
    """GTC decoder variant fed by two separate feature maps.

    ``forward`` expects ``x`` to be a (ctc_features, gtc_features) pair;
    both sub-decoders are always built, with ``out_channels`` giving the
    (gtc, ctc) output sizes. At inference (eval mode, ``infer_gtc``
    false) only the CTC prediction is returned.
    """

    def __init__(self,
                 in_channels,
                 gtc_decoder,
                 ctc_decoder,
                 infer_gtc=False,
                 out_channels=0,
                 **kwargs):
        super().__init__()
        self.infer_gtc = infer_gtc
        gtc_decoder['in_channels'] = in_channels
        gtc_decoder['out_channels'] = out_channels[0]
        ctc_decoder['in_channels'] = in_channels
        ctc_decoder['out_channels'] = out_channels[1]
        self.gtc_decoder = build_decoder(gtc_decoder)
        self.ctc_decoder = build_decoder(ctc_decoder)

    def forward(self, x, data=None):
        x_ctc, x_gtc = x
        ctc_pred = self.ctc_decoder(x_ctc, data=data)
        if not (self.training or self.infer_gtc):
            return ctc_pred
        gtc_pred = self.gtc_decoder(x_gtc.flatten(2).transpose(1, 2),
                                    data=data)
        return {'gtc_pred': gtc_pred, 'ctc_pred': ctc_pred}
@@ -0,0 +1,283 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from openrec.modeling.decoders.nrtr_decoder import PositionalEncoding, TransformerBlock
6
+
7
+
8
class BCNLanguage(nn.Module):
    """Bidirectional cloze network language model (ABINet).

    Refines a character-probability sequence by cross-attending a learned
    positional query stream against projected token embeddings.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_layers=4,
        dim_feedforward=2048,
        dropout=0.0,
        max_length=25,
        detach=True,
        num_classes=37,
    ):
        super().__init__()
        self.d_model = d_model
        self.detach = detach
        # +1 for the additional stop token.
        self.max_length = max_length + 1

        self.proj = nn.Linear(num_classes, d_model, False)
        self.token_encoder = PositionalEncoding(dropout=0.1,
                                                dim=d_model,
                                                max_len=self.max_length)
        self.pos_encoder = PositionalEncoding(dropout=0,
                                              dim=d_model,
                                              max_len=self.max_length)
        blocks = [
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ]
        self.decoder = nn.ModuleList(blocks)

        self.cls = nn.Linear(d_model, num_classes)

    def forward(self, tokens, lengths):
        """
        Args:
            tokens: (N, T, C) where T is length, N is batch size and C is
                the number of classes.
            lengths: (N,) valid lengths used to build the attention mask.

        Returns:
            Tuple of (N, T, E) refined features and (N, T, C) class logits.
        """
        if self.detach:
            # Cut the gradient path back into the vision branch.
            tokens = tokens.detach()
        embed = self.token_encoder(self.proj(tokens))  # (N, T, E)
        cross_mask = _get_mask(lengths, self.max_length)  # (N, 1, T, T)
        query = self.pos_encoder(embed.new_zeros(*embed.shape))
        for layer in self.decoder:
            query = layer(query, embed, cross_mask=cross_mask)

        logits = self.cls(query)  # (N, T, C)
        return query, logits
66
+
67
+
68
def encoder_layer(in_c, out_c, k=3, s=2, p=1):
    """Return a Conv2d -> BatchNorm2d -> ReLU block (stride-2 by default)."""
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, k, s, p),
        nn.BatchNorm2d(out_c),
        nn.ReLU(True),
    )
71
+
72
+
73
class DecoderUpsample(nn.Module):
    """Resize a feature map to ``size`` then apply Conv-BN-ReLU."""

    def __init__(self, in_c, out_c, k=3, s=1, p=1, mode='nearest') -> None:
        super().__init__()
        # 'nearest' interpolation does not accept align_corners, so keep it
        # None in that case and True for the other modes.
        self.align_corners = True if mode != 'nearest' else None
        self.mode = mode
        self.w = nn.Sequential(
            nn.Conv2d(in_c, out_c, k, s, p),
            nn.BatchNorm2d(out_c),
            nn.ReLU(True),
        )

    def forward(self, x, size):
        resized = F.interpolate(x,
                                size=size,
                                mode=self.mode,
                                align_corners=self.align_corners)
        return self.w(resized)
92
+
93
+
94
class PositionAttention(nn.Module):
    """Positional attention head (ABINet).

    Keys come from a small encoder/decoder pyramid with skip connections
    over the visual features, queries from positional encodings, and values
    are the raw features themselves.
    """

    def __init__(self,
                 max_length,
                 in_channels=512,
                 num_channels=64,
                 mode='nearest',
                 **kwargs):
        super().__init__()
        self.max_length = max_length
        self.k_encoder = nn.Sequential(
            encoder_layer(in_channels, num_channels, s=(1, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
        )
        self.k_decoder = nn.ModuleList([
            DecoderUpsample(num_channels, num_channels, mode=mode),
            DecoderUpsample(num_channels, num_channels, mode=mode),
            DecoderUpsample(num_channels, num_channels, mode=mode),
            DecoderUpsample(num_channels, in_channels, mode=mode),
        ])

        self.pos_encoder = PositionalEncoding(dropout=0,
                                              dim=in_channels,
                                              max_len=max_length)
        self.project = nn.Linear(in_channels, in_channels)

    def forward(self, x, query=None):
        N, E, H, W = x.size()
        k = x
        v = x  # values stay the untouched input features

        # Key path: run the encoder, remembering each pre-stage spatial size
        # and each stage output for the skip connections.
        skips, sizes = [], []
        for stage in self.k_encoder:
            sizes.append(k.shape[2:])
            k = stage(k)
            skips.append(k)
        n_dec = len(self.k_decoder)
        for i in range(n_dec - 1):
            # Upsample back to the mirrored size and fuse the skip feature.
            k = self.k_decoder[i](k, size=sizes[-(i + 1)])
            k = k + skips[n_dec - 2 - i]
        k = self.k_decoder[-1](k, size=sizes[0])  # (N, E, H, W)

        # Query path: positional encodings over zeros (or a caller-supplied
        # query), then a linear projection.
        # TODO q=f(q,k)
        if query is None:
            query = x.new_zeros((N, self.max_length, E))  # (N, T, E)
        q = self.project(self.pos_encoder(query))  # (N, T, E)

        # Scaled dot-product attention over the spatial positions.
        scores = torch.bmm(q, k.flatten(2, 3)) / (E**0.5)  # (N, T, H*W)
        attn = F.softmax(scores, dim=-1)

        # (N, E, H, W) -> (N, H, W, E) -> (N, H*W, E)
        v = v.permute(0, 2, 3, 1).view(N, -1, E)
        attn_vecs = torch.bmm(attn, v)  # (N, T, E)
        return attn_vecs, attn.view(N, -1, H, W)
153
+
154
+
155
class ABINetDecoder(nn.Module):
    """ABINet recognition decoder: vision branch + iterative language refinement.

    A transformer encoder plus positional attention produces per-character
    visual features; when ``iter_size > 0``, a BCN language model and a
    gated alignment layer iteratively refine the predictions.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=3,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 iter_size=3,
                 **kwargs):
        super().__init__()
        # +1 for the additional stop token.
        self.max_length = max_length + 1
        d_model = in_channels
        self.pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
        # Self-attention-only transformer blocks over the flattened feature map.
        self.encoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_layers)
        ])
        self.decoder = PositionAttention(
            max_length=self.max_length,  # additional stop token
            in_channels=d_model,
            num_channels=d_model // 8,
            mode='nearest',
        )
        self.out_channels = out_channels
        self.cls = nn.Linear(d_model, self.out_channels)
        self.iter_size = iter_size
        if iter_size > 0:
            # Language branch is only built when iterative refinement is on.
            self.language = BCNLanguage(
                d_model=d_model,
                nhead=nhead,
                num_layers=4,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                max_length=max_length,
                num_classes=self.out_channels,
            )
            # alignment: gated fusion of vision and language features.
            self.w_att_align = nn.Linear(2 * d_model, d_model)
            self.cls_align = nn.Linear(d_model, self.out_channels)

    def forward(self, x, data=None):
        """Decode visual features into character logits.

        Args:
            x: (bs, c, h, w) backbone feature map.
            data: unused here; kept for the decoder interface.

        Returns:
            Training: dict with per-iteration 'align' and 'lang' logits plus
            the 'vision' logits. Inference: softmaxed final logits.
        """
        # bs, c, h, w
        x = x.permute([0, 2, 3, 1])  # bs, h, w, c
        _, H, W, C = x.shape
        # assert H % 8 == 0 and W % 16 == 0, 'The height and width should be multiples of 8 and 16.'
        feature = x.flatten(1, 2)  # bs, h*w, c
        feature = self.pos_encoder(feature)  # bs, h*w, c
        for encoder_layer in self.encoder:
            feature = encoder_layer(feature)
        # bs, h*w, c
        feature = feature.reshape([-1, H, W, C]).permute(0, 3, 1,
                                                         2)  # bs, c, h, w
        v_feature, _ = self.decoder(feature)  # (bs[N], T, E)
        vis_logits = self.cls(v_feature)  # (bs[N], T, num_classes)
        align_lengths = _get_length(vis_logits)
        align_logits = vis_logits
        all_l_res, all_a_res = [], []
        # Iteratively feed the current predictions through the language
        # model and fuse them back with the vision features.
        for _ in range(self.iter_size):
            tokens = F.softmax(align_logits, dim=-1)
            lengths = torch.clamp(
                align_lengths, 2,
                self.max_length)  # TODO: move to language model
            l_feature, l_logits = self.language(tokens, lengths)

            # alignment: sigmoid gate decides the vision/language mix.
            all_l_res.append(l_logits)
            fuse = torch.cat((l_feature, v_feature), -1)
            f_att = torch.sigmoid(self.w_att_align(fuse))
            output = f_att * v_feature + (1 - f_att) * l_feature
            align_logits = self.cls_align(output)

            align_lengths = _get_length(align_logits)
            all_a_res.append(align_logits)
        if self.training:
            return {
                'align': all_a_res,
                'lang': all_l_res,
                'vision': vis_logits
            }
        else:
            return F.softmax(align_logits, -1)
246
+
247
+
248
+ def _get_length(logit):
249
+ """Greed decoder to obtain length from logit."""
250
+ out = logit.argmax(dim=-1) == 0
251
+ non_zero_mask = out.int() != 0
252
+ mask_max_values, mask_max_indices = torch.max(non_zero_mask.int(), dim=-1)
253
+ mask_max_indices[mask_max_values == 0] = -1
254
+ out = mask_max_indices + 1
255
+ return out
256
+
257
+
258
+ def _get_mask(length, max_length):
259
+ """Generate a square mask for the sequence.
260
+
261
+ The masked positions are filled with float('-inf'). Unmasked positions are
262
+ filled with float(0.0).
263
+ """
264
+ length = length.unsqueeze(-1)
265
+ N = length.size(0)
266
+ grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
267
+ zero_mask = torch.zeros([N, max_length],
268
+ dtype=torch.float32,
269
+ device=length.device)
270
+ inf_mask = torch.full([N, max_length],
271
+ float('-inf'),
272
+ dtype=torch.float32,
273
+ device=length.device)
274
+ diag_mask = torch.diag(
275
+ torch.full([max_length],
276
+ float('-inf'),
277
+ dtype=torch.float32,
278
+ device=length.device),
279
+ diagonal=0,
280
+ )
281
+ mask = torch.where(grid >= length, inf_mask, zero_mask)
282
+ mask = mask.unsqueeze(1) + diag_mask
283
+ return mask.unsqueeze(1)