openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,170 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from torch.nn import init
5
+
6
+
7
class Embedding(nn.Module):
    """Project a flattened encoder feature map to a word-embedding-like vector.

    The per-sample (T, C) encoder output is flattened and mapped by a single
    linear layer to an ``embed_dim``-dimensional vector (used by the SEED
    semantic branch; ``mid_dim`` is stored but not read here).
    """

    def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
        super(Embedding, self).__init__()
        self.in_timestep = in_timestep
        self.in_planes = in_planes
        self.embed_dim = embed_dim
        self.mid_dim = mid_dim
        # Embed encoder output to a word-embedding-like representation.
        self.eEmbed = nn.Linear(in_timestep * in_planes, self.embed_dim)

    def forward(self, x):
        flat = x.flatten(1)  # (B, T * C)
        return self.eEmbed(flat)
23
+
24
+
25
class Attn_Rnn_Block(nn.Module):
    """One additive-attention + GRU decoding step.

    Each call scores the encoder features against the current hidden state
    (Bahdanau-style additive attention), feeds the attended context together
    with the previous label embedding through a single-layer GRU, and
    classifies the GRU output.
    """

    def __init__(self, featdim, hiddendim, embedding_dim, out_channels,
                 attndim):
        super(Attn_Rnn_Block, self).__init__()
        self.attndim = attndim
        self.embedding_dim = embedding_dim
        self.feat_embed = nn.Linear(featdim, attndim)
        self.hidden_embed = nn.Linear(hiddendim, attndim)
        self.attnfeat_embed = nn.Linear(attndim, 1)
        self.gru = nn.GRU(input_size=featdim + self.embedding_dim,
                          hidden_size=hiddendim,
                          batch_first=True)
        self.fc = nn.Linear(hiddendim, out_channels)
        self.init_weights()

    def init_weights(self):
        """Small-normal init for the attention projections, zero biases."""
        for layer in (self.hidden_embed, self.attnfeat_embed):
            init.normal_(layer.weight, std=0.01)
            init.constant_(layer.bias, 0)

    def _attn(self, feat, h_state):
        """Return attention weights of shape (B, 1, T) over the features."""
        batch, steps, _ = feat.shape
        proj_feat = self.feat_embed(feat)
        proj_hidden = self.hidden_embed(h_state.squeeze(0)).unsqueeze(1)
        proj_hidden = proj_hidden.expand(batch, steps, self.attndim)
        scores = self.attnfeat_embed(torch.tanh(proj_feat + proj_hidden))
        weights = F.softmax(scores.squeeze(-1), dim=1)
        return weights.unsqueeze(1)  # (B, 1, T)

    def forward(self, feat, h_state, label_input):
        weights = self._attn(feat, h_state)
        context = (weights @ feat).squeeze(1)  # attended feature, (B, featdim)
        gru_in = torch.cat([label_input, context], 1).unsqueeze(1)
        output, h_state = self.gru(gru_in, h_state)
        return self.fc(output), h_state
71
+
72
+
73
class ASTERDecoder(nn.Module):
    """Attention-GRU sequence decoder for ASTER-style text recognition.

    Special-token layout implied by the construction below:
    ``bos = out_channels - 2``, ``eos = 0``, ``padding = out_channels - 1``;
    the classifier head predicts ``out_channels - 2`` classes.

    Args:
        in_channels: channel dim of the encoder feature sequence ``feat``.
        out_channels: symbol-set size including the special tokens.
        embedding_dim: dim of the label embedding fed to the GRU.
        hiddendim: GRU hidden-state size.
        attndim: dim of the additive-attention space.
        max_len: maximum text length; one extra step is reserved for EOS.
        seed: if True, also produce a 300-d semantic embedding that
            initializes the GRU state (SEED training scheme).
        time_step: encoder time steps; only used by the SEED embedder.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 embedding_dim=256,
                 hiddendim=256,
                 attndim=256,
                 max_len=25,
                 seed=False,
                 time_step=32,
                 **kwargs):
        super(ASTERDecoder, self).__init__()
        self.num_classes = out_channels
        self.bos = out_channels - 2
        self.eos = 0
        self.padding_idx = out_channels - 1
        self.seed = seed
        if seed:
            self.embeder = Embedding(
                in_timestep=time_step,
                in_planes=in_channels,
            )
        self.word_embedding = nn.Embedding(self.num_classes,
                                           embedding_dim,
                                           padding_idx=self.padding_idx)

        self.attndim = attndim
        self.hiddendim = hiddendim
        self.max_seq_len = max_len + 1

        self.featdim = in_channels

        self.attn_rnn_block = Attn_Rnn_Block(
            featdim=self.featdim,
            hiddendim=hiddendim,
            embedding_dim=embedding_dim,
            out_channels=out_channels - 2,
            attndim=attndim,
        )
        # SEED semantic embeddings are 300-d; map them to the GRU state size.
        self.embed_fc = nn.Linear(300, self.hiddendim)

    def get_initial_state(self, embed, tile_times=1):
        """Map a (N, 300) semantic embedding to the initial GRU state (1, N, H).

        ``tile_times > 1`` repeats each sample's state (used for beam-like
        expansion by callers).
        """
        assert embed.shape[1] == 300
        state = self.embed_fc(embed)  # N * sDim
        if tile_times != 1:
            state = state.unsqueeze(1)
            trans_state = state.transpose(0, 1)
            state = trans_state.tile([tile_times, 1, 1])
            trans_state = state.transpose(0, 1)
            state = trans_state.reshape(-1, self.hiddendim)
        state = state.unsqueeze(0)  # 1 * N * sDim
        return state

    def forward(self, feat, data=None):
        """Decode a character sequence from encoder features.

        Args:
            feat: (B, T, in_channels) encoder output.
            data: training-time ground truth; data[0] holds label ids and
                data[1] label lengths (read only when ``self.training``).

        Returns:
            Training: logits of shape (B, steps, out_channels - 2), or
            ``[embedding_vectors, logits]`` when ``seed`` is enabled.
            Inference: softmax probabilities of the same shape.
        """
        b = feat.size(0)
        if self.seed:
            embedding_vectors = self.embeder(feat)
            h_state = self.get_initial_state(embedding_vectors)
        else:
            h_state = torch.zeros(1, b, self.hiddendim).to(feat.device)
        outputs = []
        if self.training:
            # Teacher forcing: feed ground-truth label embeddings.
            label = data[0]
            label_embedding = self.word_embedding(label)  # [B,25,256]
            tokens = label_embedding[:, 0, :]
            max_len = data[1].max() + 1
        else:
            # Greedy decoding starts from the BOS embedding.
            tokens = torch.full([b, 1],
                                self.bos,
                                device=feat.device,
                                dtype=torch.long)
            tokens = self.word_embedding(tokens.squeeze(1))
            max_len = self.max_seq_len
        pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
        outputs.append(pred)

        # Bug fix: use feat.device instead of feat.get_device() -- the latter
        # returns -1 for CPU tensors, which breaks CPU inference.
        dec_seq = torch.full((feat.shape[0], max_len),
                             self.padding_idx,
                             dtype=torch.int64,
                             device=feat.device)
        dec_seq[:, :1] = torch.argmax(pred, dim=-1)
        for i in range(1, max_len):
            if not self.training:
                max_idx = torch.argmax(pred, dim=-1).squeeze(1)
                tokens = self.word_embedding(max_idx)
                dec_seq[:, i] = max_idx
                # Stop early once every sequence has emitted EOS.
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
            else:
                tokens = label_embedding[:, i, :]
            pred, h_state = self.attn_rnn_block(feat, h_state, tokens)
            outputs.append(pred)
        preds = torch.cat(outputs, 1)
        if self.seed and self.training:
            return [embedding_vectors, preds]
        return preds if self.training else F.softmax(preds, -1)
@@ -0,0 +1,133 @@
1
+ """This code is refer from:
2
+ https://github.com/jjwei66/BUSNet
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from .nrtr_decoder import PositionalEncoding, TransformerBlock
10
+ from .abinet_decoder import _get_mask, _get_length
11
+
12
+
13
class BUSDecoder(nn.Module):
    """BUSNet decoder: cross-attention decoding over vision and language cues.

    Three decoding passes share one positional query sequence:
      1. vision-only  (language slots replaced by the learned ``l_mask`` token),
      2. language-only (vision slots replaced by the learned ``v_mask`` token),
      3. joint vision + language.

    Args:
        in_channels: feature dim of the incoming visual sequence; used as the
            transformer model dim ``d_model``.
        out_channels: number of character classes.
        nhead: attention heads per transformer block.
        num_layers: number of cross-attention decoder blocks.
        dim_feedforward: FFN width inside each block.
        dropout: attention/residual dropout rate of the decoder blocks.
        max_length: max text length (one extra slot is added for the stop token).
        ignore_index: label id mapped to 0 before one-hot encoding in
            pretraining mode.
        pretraining: if True during training, the language branch is fed
            one-hot ground-truth tokens instead of the vision predictions.
        detach: stored but not read in this block -- TODO confirm intended use.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 ignore_index=100,
                 pretraining=False,
                 detach=True):
        super().__init__()
        d_model = in_channels
        self.ignore_index = ignore_index
        self.pretraining = pretraining
        self.d_model = d_model
        self.detach = detach
        self.max_length = max_length + 1  # additional stop token
        self.out_channels = out_channels
        # --------------------------------------------------------------------------
        # decoder specifics
        # Projects (soft or one-hot) token distributions into the model space.
        self.proj = nn.Linear(out_channels, d_model, False)
        self.token_encoder = PositionalEncoding(dropout=0.1,
                                                dim=d_model,
                                                max_len=self.max_length)
        self.pos_encoder = PositionalEncoding(dropout=0.1,
                                              dim=d_model,
                                              max_len=self.max_length)

        # Cross-attention-only blocks: the query attends to the fused
        # vision/language memory built in forward().
        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=False,
                with_cross_attn=True,
            ) for i in range(num_layers)
        ])

        # Learned placeholder tokens that stand in for the missing modality.
        v_mask = torch.empty((1, 1, d_model))
        l_mask = torch.empty((1, 1, d_model))
        self.v_mask = nn.Parameter(v_mask)
        self.l_mask = nn.Parameter(l_mask)
        torch.nn.init.uniform_(self.v_mask, -0.001, 0.001)
        torch.nn.init.uniform_(self.l_mask, -0.001, 0.001)

        # Learned modality embeddings added to vision / language features.
        v_embeding = torch.empty((1, 1, d_model))
        l_embeding = torch.empty((1, 1, d_model))
        self.v_embeding = nn.Parameter(v_embeding)
        self.l_embeding = nn.Parameter(l_embeding)
        torch.nn.init.uniform_(self.v_embeding, -0.001, 0.001)
        torch.nn.init.uniform_(self.l_embeding, -0.001, 0.001)
        self.cls = nn.Linear(d_model, out_channels)

    def forward_decoder(self, q, x, mask=None):
        """Run the query through every decoder block and classify.

        Args:
            q: (N, T, E) positional query sequence.
            x: (N, L+T, E) fused memory the query cross-attends to.
            mask: optional cross-attention mask.
        Returns:
            (N, T, C) logits.
        """
        for decoder_layer in self.decoder:
            q = decoder_layer(q, x, cross_mask=mask)
        output = q  # (N, T, E)
        logits = self.cls(output)  # (N, T, C)
        return logits

    def forward(self, img_feat, data=None):
        """Decode text from visual features.

        Args:
            img_feat: (N, L, E) visual feature sequence -- assumed E == d_model;
                TODO confirm against the encoder.
            data: training labels; data[0] = label ids, data[-1] = lengths
                (read only when training with ``pretraining`` enabled).
        Returns:
            Training: dict with 'align' (joint), 'lang' and 'vision' logits.
            Inference: softmax of the joint logits, (N, T, C).
        """
        img_feat = img_feat + self.v_embeding
        B, L, C = img_feat.shape

        # --------------------------------------------------------------------------
        # decoder procedure
        T = self.max_length
        zeros = img_feat.new_zeros((B, T, C))
        zeros_len = img_feat.new_zeros(B)
        # Query is pure positional encoding (content-free).
        query = self.pos_encoder(zeros)

        # 1. vision decode
        v_embed = torch.cat((img_feat, self.l_mask.repeat(B, T, 1)),
                            dim=1)  # v
        padding_mask = _get_mask(
            self.max_length + zeros_len,
            self.max_length)  # mask padding beyond the token length  # B, maxlen maxlen
        v_mask = torch.zeros((1, 1, self.max_length, L),
                             device=img_feat.device).tile([B, 1, 1,
                                                           1])  # maxlen L
        mask = torch.cat((v_mask, padding_mask), 3)
        v_logits = self.forward_decoder(query, v_embed, mask=mask)

        # 2. language decode
        if self.training and self.pretraining:
            # Use one-hot ground truth; ignore_index entries are mapped to 0.
            tgt = torch.where(data[0] == self.ignore_index, 0, data[0])
            tokens = F.one_hot(tgt, num_classes=self.out_channels)
            tokens = tokens.float()
            lengths = data[-1]
        else:
            # Use the vision branch's soft predictions as language input.
            tokens = torch.softmax(v_logits, dim=-1)
            lengths = _get_length(v_logits)
        # Cut the gradient so the language branch does not backprop into
        # the vision predictions.
        tokens = tokens.detach()
        token_embed = self.proj(tokens)  # (N, T, E)
        token_embed = self.token_encoder(token_embed)  # (T, N, E)
        token_embed = token_embed + self.l_embeding

        padding_mask = _get_mask(lengths,
                                 self.max_length)  # mask padding beyond the token length
        mask = torch.cat((v_mask, padding_mask), 3)
        l_embed = torch.cat((self.v_mask.repeat(B, L, 1), token_embed), dim=1)
        l_logits = self.forward_decoder(query, l_embed, mask=mask)

        # 3. vision language decode
        vl_embed = torch.cat((img_feat, token_embed), dim=1)
        vl_logits = self.forward_decoder(query, vl_embed, mask=mask)

        if self.training:
            return {'align': [vl_logits], 'lang': l_logits, 'vision': v_logits}
        else:
            return F.softmax(vl_logits, -1)
@@ -0,0 +1,43 @@
1
+ import torch.nn as nn
2
+
3
+ from .nrtr_decoder import NRTRDecoder
4
+
5
+
6
class CAMDecoder(nn.Module):
    """Thin wrapper that runs the NRTR transformer decoder on CAM features.

    The CAM pipeline passes a dict of intermediate tensors; this head decodes
    the 'refined_feat' entry and stores the result under 'rec_output'.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_encoder_layers=6,
        beam_size=0,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        super().__init__()
        # All sequence modeling is delegated to the standard NRTR decoder.
        self.decoder = NRTRDecoder(
            in_channels=in_channels,
            out_channels=out_channels,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            beam_size=beam_size,
            num_decoder_layers=num_decoder_layers,
            max_len=max_len,
            attention_dropout_rate=attention_dropout_rate,
            residual_dropout_rate=residual_dropout_rate,
            scale_embedding=scale_embedding,
        )

    def forward(self, x, data=None):
        decoded = self.decoder(x['refined_feat'], data=data)
        x['rec_output'] = decoded
        # Training keeps the full feature dict (presumably consumed by the
        # CAM loss -- confirm against cam_loss.py); inference only needs the
        # recognition output.
        return x if self.training else decoded
@@ -0,0 +1,334 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from openrec.modeling.decoders.nrtr_decoder import Embeddings, PositionalEncoding, TransformerBlock # , Beam
6
+ from openrec.modeling.decoders.visionlan_decoder import Transformer_Encoder
7
+
8
+
9
def generate_square_subsequent_mask(sz):
    """Build an additive causal attention mask of shape (sz, sz).

    Entry (i, j) is 0.0 for j <= i (visible) and float('-inf') for j > i
    (future position), so adding the mask to attention logits blocks
    attention to not-yet-decoded positions.
    """
    full = torch.full((sz, sz), float('-inf'))
    # Keep -inf strictly above the diagonal; everything else becomes 0.0.
    return torch.triu(full, diagonal=1)
17
+
18
+
19
class SEM_Pre(nn.Module):
    """Semantic branch: embed target token ids and add positional encoding.

    Also produces the causal mask matching the sequence length, so the
    decoder cannot attend to future tokens.
    """

    def __init__(
        self,
        d_model=512,
        dst_vocab_size=40,
        residual_dropout_rate=0.1,
    ):
        super(SEM_Pre, self).__init__()
        self.embedding = Embeddings(d_model=d_model, vocab=dst_vocab_size)

        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate,
            dim=d_model,
        )

    def forward(self, tgt):
        """Return (embedded sequence, causal mask) for token ids ``tgt``."""
        embedded = self.positional_encoding(self.embedding(tgt))
        causal_mask = generate_square_subsequent_mask(
            embedded.shape[1]).to(embedded.device)
        return embedded, causal_mask
40
+
41
+
42
class POS_Pre(nn.Module):
    """Positional branch: sinusoidal encodings refined by a small FFN.

    The input tensor only supplies shape/device/dtype; its values are not
    used (the encoding is applied to a zero tensor of the same shape).
    """

    def __init__(
        self,
        d_model=512,
    ):
        super(POS_Pre, self).__init__()
        self.pos_encoding = PositionalEncoding(
            dropout=0.1,
            dim=d_model,
        )
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_model)

        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, tgt):
        """Build position features shaped like ``tgt`` (values unused)."""
        encoded = self.pos_encoding(tgt.new_zeros(tgt.shape))
        # Residual feed-forward refinement followed by layer norm.
        refined = self.linear2(F.relu(self.linear1(encoded)))
        return self.norm2(encoded + refined)
65
+
66
+
67
class DSF(nn.Module):
    """Dynamic Shared Fusion: sigmoid-gated blend of two feature streams."""

    def __init__(self, d_model, fusion_num):
        super(DSF, self).__init__()
        # Projects the concatenated streams down to a per-element gate.
        self.w_att = nn.Linear(fusion_num * d_model, d_model)

    def forward(self, l_feature, v_feature):
        """Fuse language and vision features.

        Args:
            l_feature: (N, T, E) language-side features.
            v_feature: (N, T, E) vision-side features, same shape.

        Returns:
            (N, T, E) blend: the gate favours ``v_feature`` where it is
            close to 1 and ``l_feature`` where it is close to 0.
        """
        gate = torch.sigmoid(
            self.w_att(torch.cat((l_feature, v_feature), dim=-1)))
        return gate * v_feature + (1.0 - gate) * l_feature
86
+
87
+
88
class MDCDP(nn.Module):
    r"""Multi-Domain Character Distance Perception block stack.

    Each layer refines the positional stream ``pos`` in three steps:
    SAE (self-attention on ``pos``), CBI (cross-attention of ``pos`` into
    the visual and semantic streams), and DSF (gated fusion of the two
    cross-attended results back into ``pos``).
    """

    def __init__(self, d_model, n_head, d_inner, num_layers):
        super(MDCDP, self).__init__()

        self.num_layers = num_layers

        # Step 1 (SAE): self-attention blocks over the positional stream.
        self.layers_pos = nn.ModuleList([
            TransformerBlock(d_model, n_head, d_inner)
            for _ in range(num_layers)
        ])

        def _cross_attn_stack():
            # Cross-attention-only blocks (self-attention disabled).
            return nn.ModuleList([
                TransformerBlock(
                    d_model,
                    n_head,
                    d_inner,
                    with_self_attn=False,
                    with_cross_attn=True,
                ) for _ in range(num_layers)
            ])

        # Step 2 (CBI): pos -> vision and pos -> semantics branches.
        self.layers2 = _cross_attn_stack()
        self.layers3 = _cross_attn_stack()

        # Step 3 (DSF): dynamic shared fusion of the two branches.
        self.dynamic_shared_fusion = DSF(d_model, 2)

    def forward(
        self,
        sem,
        vis,
        pos,
        tgt_mask=None,
        memory_mask=None,
    ):
        """Iteratively refine ``pos`` against ``vis`` and ``sem``."""
        for sae, cbi_vis, cbi_sem in zip(self.layers_pos, self.layers2,
                                         self.layers3):
            # SAE: self-attention enhancement of the positional stream.
            pos = sae(pos, self_mask=tgt_mask)
            # CBI-V: attend into the visual features.
            pos_vis = cbi_vis(pos, vis, cross_mask=memory_mask)
            # CBI-S: attend into the semantic features (causal mask).
            pos_sem = cbi_sem(pos, sem, cross_mask=tgt_mask)
            # DSF: gated fusion of the two views.
            pos = self.dynamic_shared_fusion(pos_vis, pos_sem)
        return pos
161
+
162
+
163
class ConvBnRelu(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU with 'same'-style padding.

    Padding is derived per dimension as ``kernel_size // 2``, so odd
    kernel sizes preserve spatial dims when stride is 1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        conv=nn.Conv2d,
        stride=2,
        inplace=True,
    ):
        """
        Args:
            in_channels: number of input feature channels.
            out_channels: number of output feature channels.
            kernel_size: int or (kh, kw) tuple; a bare int is expanded to
                a square kernel.
            conv: convolution layer class to instantiate.
            stride: stride passed to the convolution.
            inplace: whether ReLU operates in place.
        """
        super().__init__()
        # Accept a bare int as a square kernel; the original per-dimension
        # comprehension raised TypeError on ints (an int is not iterable).
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        p_size = [k // 2 for k in kernel_size]
        self.conv = conv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=p_size,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        """Apply conv, batch norm, then ReLU to an (N, C, H, W) tensor."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
192
+
193
+
194
class CDistNetDecoder(nn.Module):
    """CDistNet-style decoder.

    Optionally encodes the visual features with a Transformer encoder,
    then decodes characters by fusing a semantic branch (SEM_Pre over
    target tokens), the visual features, and a positional branch
    (POS_Pre) inside the MDCDP stack, followed by a linear projection
    over the vocabulary.

    Special token ids derived from ``out_channels`` (the vocab size):
    <pad> = vocab - 1 (also ignore_index), <bos> = vocab - 2, <eos> = 0.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 n_head=None,
                 num_encoder_blocks=3,
                 num_decoder_blocks=3,
                 beam_size=0,
                 max_len=25,
                 residual_dropout_rate=0.1,
                 add_conv=False,
                 **kwargs):
        """
        Args:
            in_channels: feature dim of the visual features (model dim).
            out_channels: vocabulary size including special tokens.
            n_head: attention heads; defaults to in_channels // 32.
            num_encoder_blocks: Transformer encoder depth; 0 disables it.
            num_decoder_blocks: MDCDP stack depth.
            beam_size: > 0 routes inference to (unimplemented) beam search.
            max_len: maximum decoded sequence length.
            residual_dropout_rate: dropout inside the semantic branch.
            add_conv: prepend a (1, 3)-kernel, stride-(1, 2) ConvBnRelu
                that halves the feature-map width.
        """
        super(CDistNetDecoder, self).__init__()
        dst_vocab_size = out_channels
        self.ignore_index = dst_vocab_size - 1  # <pad>
        self.bos = dst_vocab_size - 2
        self.eos = 0
        self.beam_size = beam_size
        self.max_len = max_len
        self.add_conv = add_conv
        d_model = in_channels
        dim_feedforward = d_model * 4
        # Default: one attention head per 32 model channels.
        n_head = n_head if n_head is not None else d_model // 32

        if add_conv:
            self.convbnrelu = ConvBnRelu(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=(1, 3),
                stride=(1, 2),
            )
        if num_encoder_blocks > 0:
            self.positional_encoding = PositionalEncoding(
                dropout=0.1,
                dim=d_model,
            )
            self.trans_encoder = Transformer_Encoder(
                n_layers=num_encoder_blocks,
                n_head=n_head,
                d_model=d_model,
                d_inner=dim_feedforward,
            )
        else:
            self.trans_encoder = None
        self.semantic_branch = SEM_Pre(
            d_model=d_model,
            dst_vocab_size=dst_vocab_size,
            residual_dropout_rate=residual_dropout_rate,
        )
        self.positional_branch = POS_Pre(d_model=d_model)

        self.mdcdp = MDCDP(d_model, n_head, dim_feedforward // 2,
                           num_decoder_blocks)
        self._reset_parameters()

        # Created AFTER _reset_parameters() so the custom normal init
        # below is not overwritten by the xavier pass.
        self.tgt_word_prj = nn.Linear(
            d_model, dst_vocab_size - 2,
            bias=False)  # We don't predict <bos> nor <pad>
        self.tgt_word_prj.weight.data.normal_(mean=0.0, std=d_model**-0.5)

    def forward(self, x, data=None):
        """Run the decoder.

        Args:
            x: visual feature map, expected (N, C, H, W); flattened to
                (N, H*W, C) before encoding.
            data: training-time labels; data[0] looks like the padded
                target id sequence (starting with <bos>) and data[1] the
                label lengths — TODO confirm against the dataloader.

        Returns:
            Training: raw logits over the target positions.
            Inference: per-step softmax probabilities (greedy decode).
        """
        if self.add_conv:
            x = self.convbnrelu(x)
        # x = rearrange(x, "b c h w -> b (w h) c")
        x = x.flatten(2).transpose(1, 2)
        if self.trans_encoder is not None:
            x = self.positional_encoding(x)
            vis_feat = self.trans_encoder(x, src_mask=None)
        else:
            vis_feat = x
        if self.training:
            # Trim targets to the longest label in the batch (+1 for <bos>).
            max_len = data[1].max()
            tgt = data[0][:, :1 + max_len]
            res = self.forward_train(vis_feat, tgt)
        else:
            if self.beam_size > 0:
                res = self.forward_beam(vis_feat)
            else:
                res = self.forward_test(vis_feat)
        return res

    def forward_train(self, vis_feat, tgt):
        """Teacher-forced decoding; returns logits (N, T, vocab - 2)."""
        sem_feat, sem_mask = self.semantic_branch(tgt)
        pos_feat = self.positional_branch(sem_feat)
        output = self.mdcdp(
            sem_feat,
            vis_feat,
            pos_feat,
            tgt_mask=sem_mask,
            memory_mask=None,
        )

        logit = self.tgt_word_prj(output)
        return logit

    def forward_test(self, vis_feat):
        """Greedy autoregressive decoding; returns stepwise softmax probs."""
        bs = vis_feat.size(0)

        # Sequence buffer pre-filled with <pad>; position 0 holds <bos>.
        dec_seq = torch.full(
            (bs, self.max_len + 1),
            self.ignore_index,
            dtype=torch.int64,
            device=vis_feat.device,
        )
        dec_seq[:, 0] = self.bos
        logits = []
        for len_dec_seq in range(0, self.max_len):
            sem_feat, sem_mask = self.semantic_branch(dec_seq[:, :len_dec_seq +
                                                              1])
            pos_feat = self.positional_branch(sem_feat)
            output = self.mdcdp(
                sem_feat,
                vis_feat,
                pos_feat,
                tgt_mask=sem_mask,
                memory_mask=None,
            )

            # Only the newest position is needed for the next prediction.
            dec_output = output[:, -1:, :]

            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            # NOTE(review): this guard is always true since len_dec_seq
            # ranges over [0, max_len).
            if len_dec_seq < self.max_len:
                # greedy decode. add the next token index to the target input
                dec_seq[:, len_dec_seq + 1] = word_prob.squeeze(1).argmax(-1)
                # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
        logits = torch.cat(logits, dim=1)
        return logits

    def forward_beam(self, x):
        """Translation work in one batch."""
        # to do
        # NOTE(review): beam search is not implemented; this returns None
        # even though forward() routes here when beam_size > 0.

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)