openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,815 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from torch.nn.init import ones_, trunc_normal_, zeros_
6
+
7
+ from openrec.modeling.common import DropPath, Identity, Mlp
8
+ from openrec.modeling.decoders.nrtr_decoder import Embeddings
9
+
10
+
11
class CrossAttention(nn.Module):
    """Multi-head attention whose queries and keys/values come from different sequences.

    Queries are projected from ``q``; keys and values are produced together by a
    single ``kv`` projection of the second sequence. In eval mode the
    post-softmax attention weights are cached on ``self.attn_map``.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        """
        Args:
            dim: channel width of both inputs and of the output.
            num_heads: number of attention heads.
            qkv_bias: add a bias to the q and kv projections.
            qk_scale: logit scale; a falsy value falls back to head_dim ** -0.5.
            attn_drop: dropout applied to the attention weights.
            proj_drop: dropout applied after the output projection.
        """
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale or per_head**-0.5

        # Creation order is kept so seeded initialisation stays reproducible.
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv, key_mask=None):
        """Attend from ``q`` over ``kv``.

        Args:
            q: query tokens, reshaped below as (batch, QN, C).
            kv: key/value tokens, unpacked as (batch, N, C); C must match q's width.
            key_mask: optional additive mask; after ``unsqueeze(1)`` (head axis) it
                is added to the (batch, heads, QN, N) logits. Presumably -inf marks
                masked keys — confirm at call sites.

        Returns:
            Tensor of shape (batch, QN, C).
        """
        kv_len, channels = kv.shape[1:]
        q_len = q.shape[1]
        heads = self.num_heads
        head_dim = channels // heads

        query = self.q(q).reshape([-1, q_len, heads, head_dim]).transpose(1, 2)
        query = query * self.scale
        # Split the fused projection into key and value: (2, B, heads, N, head_dim).
        packed = self.kv(kv).reshape([-1, kv_len, 2, heads, head_dim])
        key, value = packed.permute(2, 0, 3, 1, 4)

        scores = query.matmul(key.transpose(2, 3))
        if key_mask is not None:
            scores = scores + key_mask.unsqueeze(1)

        weights = F.softmax(scores, -1)
        if not self.training:
            # Exposed for inspection/visualisation at inference time.
            self.attn_map = weights
        weights = self.attn_drop(weights)

        mixed = weights.matmul(value).transpose(1, 2).reshape((-1, q_len, channels))
        mixed = self.proj(mixed)
        return self.proj_drop(mixed)
57
+
58
+
59
class EdgeDecoderLayer(nn.Module):
    """Edge-attention block: positions ``p`` attend over ``pv`` to pool ``cv``.

    Per head, weights ``softmax(P @ PV^T)`` pool the projected ``cv`` features.
    The pooled features are projected and added residually to ``p`` (post-norm),
    followed by a post-norm MLP residual block.

    Note: ``qkv_bias``, ``attn_drop`` and the computed ``self.scale`` are not used
    by ``forward`` (logits are not scaled). They are kept for configuration and
    checkpoint compatibility.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=(0.0, 0.0),
        act_layer=nn.GELU,
        norm_layer='nn.LayerNorm',
        epsilon=1e-6,
    ):
        """
        Args:
            dim: channel width of p, cv and pv.
            num_heads: number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``dim``.
            qkv_bias: unused (see class docstring).
            qk_scale: override for ``self.scale`` (stored, unused by forward).
            drop: dropout inside the MLP.
            attn_drop: unused (see class docstring).
            drop_path: stochastic-depth rates for the attention residual (index 0)
                and the MLP residual (index 1). A 1-item sequence applies its
                rate to both. Tuple default avoids a shared mutable default.
            act_layer: activation class for the MLP.
            norm_layer: normalisation class, or a string naming one (e.g.
                'nn.LayerNorm') that is resolved with ``eval``. Only pass
                trusted strings here.
            epsilon: eps for the norm layers.
        """
        super().__init__()

        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5

        dp_attn = drop_path[0]
        # Previously the MLP residual reused drop_path[0], silently ignoring
        # drop_path[1]. DropPath has no parameters, so checkpoints are unaffected.
        dp_mlp = drop_path[1] if len(drop_path) > 1 else drop_path[0]
        self.drop_path1 = DropPath(dp_attn) if dp_attn > 0.0 else Identity()
        self.drop_path2 = DropPath(dp_mlp) if dp_mlp > 0.0 else Identity()

        # NOTE(review): eval() on a config string; keep configs trusted.
        norm_cls = eval(norm_layer) if isinstance(norm_layer,
                                                  str) else norm_layer
        self.norm1 = norm_cls(dim, eps=epsilon)
        self.norm2 = norm_cls(dim, eps=epsilon)

        self.p = nn.Linear(dim, dim)
        self.cv = nn.Linear(dim, dim)
        self.pv = nn.Linear(dim, dim)

        self.dim = dim
        self.num_heads = num_heads
        self.p_proj = nn.Linear(dim, dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, p, cv, pv):
        """Update ``p`` by pooling ``cv`` with weights computed against ``pv``.

        Args:
            p: query tokens, reshaped as (batch, pN, dim).
            cv: value tokens, reshaped as (batch, vN, dim).
            pv: key tokens; must have the same length vN as ``cv``.

        Returns:
            Tensor shaped (batch, pN, dim).
        """
        pN = p.shape[1]
        vN = cv.shape[1]
        head_dim = self.dim // self.num_heads

        p1 = self.p(p).reshape([-1, pN, self.num_heads,
                                head_dim]).transpose(1, 2)
        cv1 = self.cv(cv).reshape([-1, vN, self.num_heads,
                                   head_dim]).transpose(1, 2)
        pv1 = self.pv(pv).reshape([-1, vN, self.num_heads,
                                   head_dim]).transpose(1, 2)

        edge = F.softmax(p1.matmul(pv1.transpose(2, 3)), -1)  # B h pN vN

        p_c = (edge @ cv1).transpose(1, 2).reshape((-1, pN, self.dim))

        x1 = self.norm1(p + self.drop_path1(self.p_proj(p_c)))
        return self.norm2(x1 + self.drop_path2(self.mlp(x1)))
126
+
127
+
128
class DecoderLayer(nn.Module):
    """Pre-norm transformer block: cross-attention from ``q`` to ``kv``, then an MLP.

    Both sub-blocks are residual and share the same stochastic-depth module.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer='nn.LayerNorm',
        epsilon=1e-6,
    ):
        """
        Args:
            dim: channel width.
            num_heads: heads for the cross-attention.
            mlp_ratio: MLP hidden width as a multiple of ``dim``.
            qkv_bias, qk_scale, attn_drop: forwarded to ``CrossAttention``.
            drop: dropout for the attention projection and the MLP.
            drop_path: stochastic-depth rate; 0 disables it.
            act_layer: activation class for the MLP.
            norm_layer: string resolved with ``eval`` (e.g. 'nn.LayerNorm').
            epsilon: eps for the norm layers.
        """
        super().__init__()
        # Separate norms for the query stream and the key/value stream.
        self.norm1 = eval(norm_layer)(dim, eps=epsilon)
        self.normkv = eval(norm_layer)(dim, eps=epsilon)

        self.mixer = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = Identity()

        self.norm2 = eval(norm_layer)(dim, eps=epsilon)

        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, q, kv, key_mask=None):
        """Return ``q`` refined by attending over ``kv`` (same shape as ``q``).

        ``key_mask`` is passed unchanged to ``CrossAttention``.
        """
        attended = self.mixer(self.norm1(q), self.normkv(kv), key_mask)
        hidden = q + self.drop_path(attended)
        return hidden + self.drop_path(self.mlp(self.norm2(hidden)))
176
+
177
+
178
class CMFFLayer(nn.Module):
    """Cross-modal fusion layer over concatenated question/prompt/visual tokens.

    The three streams are concatenated along the sequence axis and refined by:
      1. cross-attention of all tokens over the prompt tokens (optionally masked),
      2. cross-attention of all tokens over the visual part of the sequence,
      3. an MLP,
    each pre-normed and residual. The result is split back into the three streams.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        epsilon=1e-6,
    ):
        """
        Args:
            dim: channel width of all three streams.
            num_heads: heads for both cross-attentions.
            mlp_ratio: MLP hidden width as a multiple of ``dim``.
            qkv_bias, qk_scale, attn_drop: forwarded to ``CrossAttention``.
            drop: dropout for attention projections and the MLP.
            drop_path: stochastic-depth rate shared by all three residuals.
            act_layer: activation class for the MLP.
            epsilon: eps for the LayerNorms.
        """
        super().__init__()
        self.normq1 = nn.LayerNorm(dim, eps=epsilon)
        self.normkv1 = nn.LayerNorm(dim, eps=epsilon)
        self.images_to_question_cross_attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.normq2 = nn.LayerNorm(dim, eps=epsilon)
        self.normkv2 = nn.LayerNorm(dim, eps=epsilon)
        self.question_to_images_cross_attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.normmlp = nn.LayerNorm(dim, eps=epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, question_f, prompt_f, visual_f, mask=None):
        """Fuse the three streams and return their updated versions.

        Args:
            question_f: question tokens, (batch, Lq, dim).
            prompt_f: prompt tokens, (batch, Lp, dim).
            visual_f: visual tokens, (batch, Lv, dim).
            mask: optional additive key mask for attention over the prompt,
                forwarded to ``CrossAttention``.

        Returns:
            ``(question_f_updated, prompt_f_updated, visual_f_updated)``, each with
            the same sequence length as its input.
        """
        len_q = question_f.shape[1]
        len_p = prompt_f.shape[1]
        len_v = visual_f.shape[1]

        query_add = torch.concat([question_f, prompt_f, visual_f], 1)

        query_add = query_add + self.drop_path(
            self.images_to_question_cross_attn(self.normq1(query_add),
                                               self.normkv1(prompt_f), mask))
        # Index from the front: the old negative slice (-len_v:) selected the
        # whole sequence when len_v == 0.
        visual_part = query_add[:, len_q + len_p:, :]
        query_add = query_add + self.drop_path(
            self.question_to_images_cross_attn(self.normq2(query_add),
                                               self.normkv2(visual_part)))
        query_updated = query_add + self.drop_path(
            self.mlp(self.normmlp(query_add)))

        question_f_updated, prompt_f_updated, visual_f_updated = torch.split(
            query_updated, [len_q, len_p, len_v], dim=1)

        return question_f_updated, prompt_f_updated, visual_f_updated
245
+
246
+
247
+ class IGTRDecoder(nn.Module):
248
+
249
+ def __init__(self,
250
+ in_channels,
251
+ dim,
252
+ out_channels,
253
+ num_layer=2,
254
+ drop_path_rate=0.1,
255
+ max_len=25,
256
+ vis_seq=50,
257
+ ch=False,
258
+ ar=False,
259
+ refine_iter=0,
260
+ quesall=True,
261
+ next_pred=False,
262
+ ds=False,
263
+ pos2d=False,
264
+ check_search=False,
265
+ max_size=[8, 32],
266
+ **kwargs):
267
+ super(IGTRDecoder, self).__init__()
268
+
269
+ self.out_channels = out_channels
270
+ self.dim = dim
271
+ self.max_len = max_len + 3 # max_len + eos + bos
272
+ self.ch = ch
273
+ self.char_embed = Embeddings(d_model=dim,
274
+ vocab=self.out_channels,
275
+ scale_embedding=True)
276
+ self.ignore_index = out_channels - 1
277
+ self.ar = ar
278
+ self.refine_iter = refine_iter
279
+ self.bos = self.out_channels - 2
280
+ self.eos = 0
281
+ self.next_pred = next_pred
282
+ self.quesall = quesall
283
+ self.check_search = check_search
284
+ dpr = np.linspace(0, drop_path_rate, num_layer + 2)
285
+
286
+ self.cmff_decoder = nn.ModuleList([
287
+ CMFFLayer(dim=dim,
288
+ num_heads=dim // 32,
289
+ mlp_ratio=4.0,
290
+ qkv_bias=True,
291
+ drop_path=dpr[i]) for i in range(num_layer)
292
+ ])
293
+
294
+ self.answer_to_question_layer = DecoderLayer(dim=dim,
295
+ num_heads=dim // 32,
296
+ mlp_ratio=4.0,
297
+ qkv_bias=True,
298
+ drop_path=dpr[-2])
299
+ self.answer_to_image_layer = DecoderLayer(dim=dim,
300
+ num_heads=dim // 32,
301
+ mlp_ratio=4.0,
302
+ qkv_bias=True,
303
+ drop_path=dpr[-1])
304
+
305
+ self.char_pos_embed = nn.Parameter(torch.zeros([self.max_len, dim],
306
+ dtype=torch.float32),
307
+ requires_grad=True)
308
+ self.appear_num_embed = nn.Parameter(torch.zeros([self.max_len, dim],
309
+ dtype=torch.float32),
310
+ requires_grad=True)
311
+ self.ds = ds
312
+ self.pos2d = pos2d
313
+ if not ds:
314
+ self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
315
+ dtype=torch.float32),
316
+ requires_grad=True)
317
+ trunc_normal_(self.vis_pos_embed, std=0.02)
318
+ elif pos2d:
319
+ pos_embed = torch.zeros([1, max_size[0] * max_size[1], dim],
320
+ dtype=torch.float32)
321
+ trunc_normal_(pos_embed, mean=0, std=0.02)
322
+ self.vis_pos_embed = nn.Parameter(
323
+ pos_embed.transpose(1, 2).reshape(1, dim, max_size[0],
324
+ max_size[1]),
325
+ requires_grad=True,
326
+ )
327
+ self.prompt_pos_embed = nn.Parameter(torch.zeros([1, 6, dim],
328
+ dtype=torch.float32),
329
+ requires_grad=True)
330
+
331
+ self.answer_query = nn.Parameter(torch.zeros([1, 1, dim],
332
+ dtype=torch.float32),
333
+ requires_grad=True)
334
+ self.norm_pred = nn.LayerNorm(dim, eps=1e-6)
335
+ self.ques1_head = nn.Linear(dim, self.out_channels - 2)
336
+ self.ques2_head = nn.Linear(dim, self.max_len, bias=False)
337
+ self.ques3_head = nn.Linear(dim, self.max_len - 1)
338
+ self.ques4_head = nn.Linear(dim, self.max_len - 1)
339
+ trunc_normal_(self.char_pos_embed, std=0.02)
340
+ trunc_normal_(self.appear_num_embed, std=0.02)
341
+ trunc_normal_(self.answer_query, std=0.02)
342
+ trunc_normal_(self.prompt_pos_embed, std=0.02)
343
+ self.apply(self._init_weights)
344
+
345
+ def _init_weights(self, m):
346
+ if isinstance(m, nn.Linear):
347
+ trunc_normal_(m.weight, std=0.02)
348
+ if isinstance(m, nn.Linear) and m.bias is not None:
349
+ zeros_(m.bias)
350
+ elif isinstance(m, nn.LayerNorm):
351
+ zeros_(m.bias)
352
+ ones_(m.weight)
353
+
354
+ @torch.jit.ignore
355
+ def no_weight_decay(self):
356
+ return {
357
+ 'char_pos_embed', 'vis_pos_embed', 'appear_num_embed',
358
+ 'answer_query', 'char_embed'
359
+ }
360
+
361
+ def question_encoder(self, targets, train_i):
362
+ (
363
+ prompt_pos_idx,
364
+ prompt_char_idx,
365
+ ques_pos_idx,
366
+ ques1_answer,
367
+ ques2_char_idx,
368
+ ques2_answer,
369
+ ques4_char_num,
370
+ ques_len,
371
+ ques2_len,
372
+ prompt_len,
373
+ ) = targets
374
+ max_ques_len = torch.max(ques_len)
375
+ max_ques2_len = torch.max(ques2_len)
376
+ max_prompt_len = torch.max(prompt_len)
377
+ if self.next_pred and (train_i == 2 or train_i == 3):
378
+ prompt_pos = self.prompt_pos_embed
379
+ prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
380
+ else:
381
+ prompt_pos = F.embedding(
382
+ prompt_pos_idx[:, :max_prompt_len], self.char_pos_embed
383
+ ) # bs lp [ 0, 4, 3, 12, 12, 12, 12, 12, 12, 12, 12]
384
+ prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
385
+ prompt_char = self.char_embed(prompt_char_idx) # bs lp
386
+
387
+ prompt = prompt_pos + prompt_char
388
+ mask_1234 = torch.where(prompt_char_idx == self.ignore_index,
389
+ float('-inf'), 0)
390
+
391
+ ques1 = F.embedding(ques_pos_idx[:, :max_ques_len],
392
+ self.char_pos_embed) # bs lq1 dim
393
+ ques1_answer = ques1_answer[:, :max_ques_len]
394
+ if self.quesall or train_i == 0:
395
+ ques2_char = self.char_embed(ques2_char_idx[:, :max_ques2_len, 1])
396
+ ques2 = ques2_char + F.embedding(ques2_char_idx[:, :max_ques2_len,
397
+ 0],
398
+ self.char_pos_embed) # bs lq2 dim
399
+ ques2_answer = ques2_answer[:, :max_ques2_len]
400
+ ques2_head = F.embedding(ques2_char_idx[:, :max_ques2_len, 0],
401
+ self.ques2_head.weight)
402
+ ques4_char = self.char_embed(ques1_answer)
403
+ ques4_ap_num = F.embedding(ques4_char_num[:, :max_ques_len],
404
+ self.appear_num_embed)
405
+ ques4 = ques4_char + ques4_ap_num
406
+ ques4_answer = ques_pos_idx[:, :max_ques_len]
407
+
408
+ return (
409
+ prompt,
410
+ ques1,
411
+ ques2,
412
+ ques2_head,
413
+ ques4,
414
+ ques1_answer,
415
+ ques2_answer,
416
+ ques4_answer,
417
+ mask_1234.unsqueeze(1),
418
+ )
419
+ else:
420
+ return prompt, ques1, ques1_answer, mask_1234.unsqueeze(1)
421
+
422
def forward(self, x, data=None):
    """Route to the training or inference path.

    Args:
        x: encoder features, passed through unchanged.
        data: training targets, forwarded only when ``self.training`` is true.

    Returns:
        ``self.forward_train(x, data)`` in training mode, otherwise
        ``self.forward_test(x)``.
    """
    if not self.training:
        return self.forward_test(x)
    return self.forward_train(x, data)
427
+
428
def forward_test(self, x):
    """Decode characters from visual features at inference time.

    Positional information is first added to the visual features according
    to ``self.ds`` / ``self.pos2d``. Decoding then follows one of four modes
    selected by instance flags:

    * ``not self.ar and self.check_search``: greedy step-by-step decoding.
      At step ``j`` every candidate character (``out_channels - 2`` of them)
      is scored against row ``j`` of ``self.ques2_head.weight``.
    * ``not self.ar``: parallel decoding of every position after the BOS
      slot, conditioned only on the BOS prompt, scored with ``ques1_head``.
    * ``self.ar and self.next_pred``: next-token decoding of at most 70
      tokens. The prompt is the BOS token plus a sliding window of the (at
      most) 5 most recent predictions. This mode returns
      ``[pred_ids, pred_probs]`` immediately.
    * otherwise: greedy autoregressive decoding scored with ``ques1_head``.

    When ``self.refine_iter > 0`` (in any mode but next_pred), the greedy
    predictions are re-checked ``refine_iter`` times. A sigmoid score is
    computed from ``ques2_head``, and positions up to the first EOS that
    score below 0.9 are re-predicted with those positions masked out of the
    prompt.

    Args:
        x: visual features. When ``self.ds`` and ``self.pos2d`` are set it
            is 4-D and is flattened to ``(B, H*W, C)``; in the other cases it
            is used as is (presumably already ``(B, N, C)`` — TODO confirm).

    Returns:
        ``[pred_idxs, pred_probs]`` in next_pred mode or when
        ``self.refine_iter > 0``; otherwise ``softmax(logits, -1)``.
    """
    if not self.ds:
        visual_f = x + self.vis_pos_embed
    elif self.pos2d:
        x = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
        visual_f = x.flatten(2).transpose(1, 2)
    else:
        visual_f = x
    bs = x.shape[0]
    # ``x.get_device()`` returns -1 for CPU tensors, which torch rejects as a
    # ``device=`` argument; ``x.device`` works on CPU and CUDA alike.
    device = x.device
    prompt_bos = self.char_embed(
        torch.full([bs, 1], self.bos, dtype=torch.long, device=device)
    ) + self.char_pos_embed[:1, :].unsqueeze(0)  # BOS prompt
    ques_all = torch.tile(self.char_pos_embed.unsqueeze(0), (bs, 1, 1))
    if not self.ar:
        if self.check_search:
            tgt_in = torch.full((bs, self.max_len),
                                self.ignore_index,
                                dtype=torch.long,
                                device=device)
            tgt_in[:, 0] = self.bos
            logits = []
            for j in range(1, self.max_len):
                visual_f_check = visual_f
                ques_check_i = ques_all[:, j:j + 1, :] + self.char_embed(
                    torch.arange(self.out_channels - 2,
                                 device=device)).unsqueeze(0)
                prompt_check = ques_all[:, :j] + self.char_embed(
                    tgt_in[:, :j])
                # Hide prompt positions from the first EOS onwards.
                mask = torch.where(
                    (tgt_in[:, :j] == self.eos).int().cumsum(-1) > 0,
                    float('-inf'), 0)
                for layer in self.cmff_decoder:
                    ques_check_i, prompt_check, visual_f_check = layer(
                        ques_check_i, prompt_check, visual_f_check,
                        mask.unsqueeze(1))
                answer_query_i = self.answer_to_question_layer(
                    ques_check_i, prompt_check, mask.unsqueeze(1))
                answer_pred_i = self.norm_pred(
                    self.answer_to_image_layer(answer_query_i,
                                               visual_f_check))
                # Score every candidate against the position-j row of the
                # ques2 head; p_i is (bs, 1, num_candidates).
                fc_2 = self.ques2_head.weight[j:j + 1].unsqueeze(0)
                fc_2 = fc_2.tile([bs, 1, 1])
                p_i = fc_2 @ answer_pred_i.transpose(1, 2)
                logits.append(p_i)
                if j < self.max_len - 1:
                    # Greedy decode. Index the singleton step dim instead of
                    # squeeze(), which would also drop a size-1 candidate
                    # dim and make argmax run over the batch.
                    tgt_in[:, j] = p_i[:, 0].argmax(-1)

                    # Stop once every sequence contains at least one EOS.
                    if (tgt_in == self.eos).any(dim=-1).all():
                        break
            logits = torch.cat(logits, dim=1)
        else:
            ques_pd = ques_all[:, 1:, :]
            prompt_pd = prompt_bos
            visual_f_pd = visual_f
            for layer in self.cmff_decoder:
                ques_pd, prompt_pd, visual_f_pd = layer(
                    ques_pd, prompt_pd, visual_f_pd)
            answer_query_pd = self.answer_to_question_layer(
                ques_pd, prompt_pd)
            answer_feats_pd = self.norm_pred(
                self.answer_to_image_layer(answer_query_pd, visual_f_pd))
            logits = self.ques1_head(answer_feats_pd)
    elif self.next_pred:
        # First token: predict position 1 from the BOS prompt alone.
        ques_pd_1 = ques_all[:, 1:2, :]
        prompt_pd = prompt_bos
        visual_f_pd = visual_f
        for layer in self.cmff_decoder:
            ques_pd_1, prompt_pd, visual_f_pd = layer(
                ques_pd_1, prompt_pd, visual_f_pd)
        answer_query_pd = self.answer_to_question_layer(ques_pd_1, prompt_pd)
        answer_feats_pd = self.norm_pred(
            self.answer_to_image_layer(answer_query_pd, visual_f_pd))
        logits_pd_1 = self.ques1_head(answer_feats_pd)

        # Subsequent tokens always query the second-to-last position
        # embedding, prompted by BOS plus the recent predictions.
        ques_next = self.char_pos_embed[-2:-1, :].unsqueeze(0).tile(
            [bs, 1, 1])
        prompt_next_bos = (self.char_embed(
            torch.full([bs, 1], self.bos, dtype=torch.long, device=device))
                           + self.prompt_pos_embed[:, :1, :])
        pred_prob, pred_id = F.softmax(logits_pd_1, -1).max(-1)
        pred_prob_list = [pred_prob]
        pred_id_list = [pred_id]
        for _ in range(1, 70):
            prompt_next_1 = self.char_embed(
                pred_id) + self.prompt_pos_embed[:, -pred_id.shape[1]:, :]
            prompt_next = torch.concat([prompt_next_bos, prompt_next_1], 1)
            ques_next_i = ques_next
            visual_f_i = visual_f
            for layer in self.cmff_decoder:
                # Thread the updated visual features through the stack, as
                # every other decoding path and training do.
                ques_next_i, prompt_next, visual_f_i = layer(
                    ques_next_i, prompt_next, visual_f_i)
            answer_query_next_i = self.answer_to_question_layer(
                ques_next_i, prompt_next)
            answer_feats_next_i = self.norm_pred(
                self.answer_to_image_layer(answer_query_next_i, visual_f_i))
            logits_next_i = self.ques1_head(answer_feats_next_i)
            pred_prob_i, pred_id_i = F.softmax(logits_next_i, -1).max(-1)
            pred_prob_list.append(pred_prob_i)
            pred_id_list.append(pred_id_i)
            if (torch.concat(pred_id_list,
                             1) == self.eos).any(dim=-1).all():
                break
            # Keep a sliding window of at most 5 previous predictions.
            if pred_id.shape[1] >= 5:
                pred_id = torch.concat([pred_id[:, 1:], pred_id_i], 1)
            else:
                pred_id = torch.concat([pred_id, pred_id_i], 1)
        return [
            torch.concat(pred_id_list, 1),
            torch.concat(pred_prob_list, 1)
        ]
    else:
        tgt_in = torch.full((bs, self.max_len),
                            self.ignore_index,
                            dtype=torch.long,
                            device=device)
        tgt_in[:, 0] = self.bos
        logits = []
        for j in range(1, self.max_len):
            visual_f_ar = visual_f
            ques_i = ques_all[:, j:j + 1, :]
            prompt_ar = ques_all[:, :j] + self.char_embed(tgt_in[:, :j])
            # Hide prompt positions from the first EOS onwards.
            mask = torch.where(
                (tgt_in[:, :j] == self.eos).int().cumsum(-1) > 0,
                float('-inf'), 0)
            for layer in self.cmff_decoder:
                ques_i, prompt_ar, visual_f_ar = layer(
                    ques_i, prompt_ar, visual_f_ar, mask.unsqueeze(1))
            answer_query_i = self.answer_to_question_layer(
                ques_i, prompt_ar, mask.unsqueeze(1))
            answer_pred_i = self.norm_pred(
                self.answer_to_image_layer(answer_query_i, visual_f_ar))
            # The next-token distribution sits at the single query position.
            p_i = self.ques1_head(answer_pred_i)
            logits.append(p_i)
            if j < self.max_len - 1:
                # Greedy decode (see the check_search branch for why the
                # step dim is indexed rather than squeezed).
                tgt_in[:, j] = p_i[:, 0].argmax(-1)

                # Stop once every sequence contains at least one EOS.
                if (tgt_in == self.eos).any(dim=-1).all():
                    break
        logits = torch.cat(logits, dim=1)

    if self.refine_iter > 0:
        pred_probs, pred_idxs = F.softmax(logits, -1).max(-1)
        for _ in range(self.refine_iter):
            # True up to and including the first EOS.
            mask_check = (pred_idxs == self.eos).int().cumsum(-1) <= 1

            # Verification pass: score each predicted character.
            ques_check_all = self.char_embed(
                pred_idxs) + ques_all[:, 1:pred_idxs.shape[1] + 1, :]
            prompt_check = prompt_bos
            visual_f_check = visual_f
            ques_check = ques_check_all
            for layer in self.cmff_decoder:
                ques_check, prompt_check, visual_f_check = layer(
                    ques_check, prompt_check, visual_f_check)
            answer_query_check = self.answer_to_question_layer(
                ques_check, prompt_check)
            answer_pred_check = self.norm_pred(
                self.answer_to_image_layer(answer_query_check,
                                           visual_f_check))
            ques2_head = self.ques2_head.weight[1:pred_idxs.shape[1] + 1, :]
            ques2_head = torch.tile(ques2_head.unsqueeze(0), [bs, 1, 1])
            answer2_pred = answer_pred_check.matmul(
                ques2_head.transpose(1, 2))
            diag_mask = torch.eye(answer2_pred.shape[1],
                                  device=device).unsqueeze(0).tile(
                                      [bs, 1, 1])
            # Keep only position i scored against head row i.
            answer2_pred = torch.sigmoid(
                (answer2_pred * diag_mask).sum(-1)) * mask_check

            check_result = answer2_pred < 0.9

            # Refinement pass: re-predict low-confidence positions with those
            # positions (and everything from the first EOS on) masked out.
            prompt_refine = torch.concat([prompt_bos, ques_check_all], 1)
            mask_refine = torch.where(
                check_result, float('-inf'), 0) + torch.where(
                    (pred_idxs == self.eos).int().cumsum(-1) < 1, 0,
                    float('-inf'))
            mask_refine = torch.concat(
                [torch.zeros([bs, 1], device=device), mask_refine],
                1).unsqueeze(1)
            ques_refine = ques_all[:, 1:pred_idxs.shape[1] + 1, :]
            visual_f_refine = visual_f
            for layer in self.cmff_decoder:
                ques_refine, prompt_refine, visual_f_refine = layer(
                    ques_refine, prompt_refine, visual_f_refine, mask_refine)
            answer_query_refine = self.answer_to_question_layer(
                ques_refine, prompt_refine, mask_refine)
            answer_pred_refine = self.norm_pred(
                self.answer_to_image_layer(answer_query_refine,
                                           visual_f_refine))
            answer_refine = self.ques1_head(answer_pred_refine)
            refine_probs, refine_idxs = F.softmax(answer_refine, -1).max(-1)
            pred_idxs_refine = torch.where(check_result, refine_idxs,
                                           pred_idxs)
            pred_idxs = torch.where(mask_check, pred_idxs_refine, pred_idxs)
            pred_probs_refine = torch.where(check_result, refine_probs,
                                            pred_probs)
            pred_probs = torch.where(mask_check, pred_probs_refine,
                                     pred_probs)

        return [pred_idxs, pred_probs]

    return F.softmax(logits, -1)
653
+
654
def forward_train(self, x, targets=None):
    """Compute the multi-question training losses for one batch.

    Args:
        x: visual features. When ``self.ds`` and ``self.pos2d`` are set it is
            4-D and is flattened to ``(B, H*W, C)``; otherwise it is used as
            is (presumably ``(B, N, C)`` — TODO confirm).
        targets: indexable batch of label tensors. ``targets[7]`` holds the
            question-3 data: when ``self.ch`` is set, ``[..., 0]`` are
            character ids and ``[..., 1]`` the answers; otherwise the whole
            tensor is the answer. Entries 1-6 and 8-11 are stacked per
            question group along dim 1 and are fed, one group at a time, to
            ``self.question_encoder``. ``targets[0]`` is not read here.

    Returns:
        ``[loss, logits]``. ``loss`` is a dict with keys ``'loss'`` (the sum)
        and ``'loss1'`` .. ``'loss4'``. Each of the four is a token-count
        weighted average over the groups that produced it. Questions 2-4
        are only trained on groups using the full question set
        (``self.quesall`` or the first group). ``logits`` is the question-1
        prediction of the first group (``None`` if there are no groups).
    """
    bs = x.shape[0]
    # ``x.get_device()`` returns -1 for CPU tensors, which torch factory
    # functions reject; ``x.device`` works on CPU and CUDA alike.
    device = x.device
    answer_token = torch.tile(self.answer_query, (bs, 1, 1))
    if self.ch:
        ques3 = self.char_embed(targets[7][:, :, 0]) + answer_token
        ques3_answer = targets[7][:, :, 1]
    else:
        ques3 = self.char_embed(
            torch.arange(self.out_channels - 2, device=device)
        ).unsqueeze(0) + answer_token
        ques3_answer = targets[7]
    loss1_list = []
    loss2_list = []
    loss3_list = []
    loss4_list = []
    sampler1_num = 0
    sampler2_num = 0
    sampler3_num = 0
    sampler4_num = 0
    if not self.ds:
        visual_f = x + self.vis_pos_embed
    elif self.pos2d:
        x = x + self.vis_pos_embed[:, :, :x.shape[2], :x.shape[3]]
        visual_f = x.flatten(2).transpose(1, 2)
    else:
        visual_f = x
    # Assigned from the first question group.
    logits = None
    per_group_targets = zip(
        targets[1].transpose(0, 1),
        targets[2].transpose(0, 1),
        targets[3].transpose(0, 1),
        targets[4].transpose(0, 1),
        targets[5].transpose(0, 1),
        targets[6].transpose(0, 1),
        targets[8].transpose(0, 1),
        targets[9].transpose(0, 1),
        targets[10].transpose(0, 1),
        targets[11].transpose(0, 1),
    )
    for train_i, target_ in enumerate(per_group_targets):
        # target_ = (prompt_pos_idx, prompt_char_idx, ques_pos_idx,
        #            ques1_answer, ques2_char_idx, ques2_answer,
        #            ques4_char_num, ques_len, ques2_len, prompt_len)
        visual_f_1234 = visual_f
        if self.quesall or train_i == 0:
            (
                prompt,
                ques1,
                ques2,
                ques2_head,
                ques4,
                ques1_answer,
                ques2_answer,
                ques4_answer,
                mask_1234,
            ) = self.question_encoder(target_, train_i)
            prompt_1234 = prompt
            # All four question types share one decoder pass.
            ques_1234 = torch.concat([ques1, ques2, ques3, ques4], 1)
            for layer in self.cmff_decoder:
                ques_1234, prompt_1234, visual_f_1234 = layer(
                    ques_1234, prompt_1234, visual_f_1234, mask_1234)
            answer_query_1234 = self.answer_to_question_layer(
                ques_1234, prompt_1234, mask_1234)
            answer_feats_1234 = self.norm_pred(
                self.answer_to_image_layer(answer_query_1234, visual_f_1234))

            # Split the joint output back into the four question segments.
            len1 = ques1.shape[1]
            len12 = len1 + ques2.shape[1]
            len4 = ques4.shape[1]
            answer_feats_1 = answer_feats_1234[:, :len1, :]
            answer_feats_2 = answer_feats_1234[:, len1:len12, :]
            answer_feats_3 = answer_feats_1234[:, len12:-len4, :]
            answer_feats_4 = answer_feats_1234[:, -len4:, :]

            # Question 1: character classification per queried position.
            answer1_pred = self.ques1_head(answer_feats_1)
            if train_i == 0:
                logits = answer1_pred

            n = (ques1_answer != self.ignore_index).sum().item()
            loss1 = n * F.cross_entropy(
                answer1_pred.flatten(0, 1),
                ques1_answer.flatten(0, 1),
                ignore_index=self.ignore_index,
                reduction='mean',
            )
            sampler1_num += n
            loss1_list.append(loss1)

            # Question 2: binary score of query i against head row i.
            answer2_pred = answer_feats_2.matmul(ques2_head.transpose(1, 2))
            diag_mask = torch.eye(answer2_pred.shape[1],
                                  device=device).unsqueeze(0).tile(
                                      [bs, 1, 1])
            answer2_pred = (answer2_pred * diag_mask).sum(-1)

            ques2_answer = ques2_answer.flatten(0, 1)
            non_pad_mask = torch.not_equal(ques2_answer, self.ignore_index)
            n = non_pad_mask.sum().item()
            ques2_answer = torch.where(ques2_answer == self.ignore_index, 0,
                                       ques2_answer)
            loss2_none = F.binary_cross_entropy_with_logits(
                answer2_pred.flatten(0, 1), ques2_answer, reduction='none')
            loss2 = n * loss2_none.masked_select(non_pad_mask).mean()
            sampler2_num += n
            loss2_list.append(loss2)

            # Question 3.
            answer3_pred = self.ques3_head(answer_feats_3)
            n = (ques3_answer != self.ignore_index).sum().item()
            # NOTE(review): unlike loss1, no ignore_index is passed here, so
            # entries equal to self.ignore_index (excluded from n) still
            # contribute to the mean — confirm this is intended.
            loss3 = n * F.cross_entropy(answer3_pred.flatten(0, 1),
                                        ques3_answer.flatten(0, 1),
                                        reduction='mean')
            sampler3_num += n
            loss3_list.append(loss3)

            # Question 4: predict the position index; max_len - 1 is padding.
            answer4_pred = self.ques4_head(answer_feats_4)
            n = (ques4_answer != self.max_len - 1).sum().item()
            loss4 = n * F.cross_entropy(
                answer4_pred.flatten(0, 1),
                ques4_answer.flatten(0, 1),
                ignore_index=self.max_len - 1,
                reduction='mean',
            )
            sampler4_num += n
            loss4_list.append(loss4)
        else:
            # Later groups (without quesall) only train question 1.
            prompt, ques1, ques1_answer, mask_1234 = self.question_encoder(
                target_, train_i)
            prompt_1234 = prompt
            for layer in self.cmff_decoder:
                ques1, prompt_1234, visual_f_1234 = layer(
                    ques1, prompt_1234, visual_f_1234, mask_1234)
            answer_query_1 = self.answer_to_question_layer(
                ques1, prompt_1234, mask_1234)
            answer_feats_1 = self.norm_pred(
                self.answer_to_image_layer(answer_query_1, visual_f_1234))
            answer1_pred = self.ques1_head(answer_feats_1)
            n = (ques1_answer != self.ignore_index).sum().item()
            loss1 = n * F.cross_entropy(
                answer1_pred.flatten(0, 1),
                ques1_answer.flatten(0, 1),
                ignore_index=self.ignore_index,
                reduction='mean',
            )
            sampler1_num += n
            loss1_list.append(loss1)

    loss_list = [
        sum(loss1_list) / sampler1_num,
        sum(loss2_list) / sampler2_num,
        sum(loss3_list) / sampler3_num,
        sum(loss4_list) / sampler4_num,
    ]
    loss = {
        'loss': sum(loss_list),
        'loss1': loss_list[0],
        'loss2': loss_list[1],
        'loss3': loss_list[2],
        'loss4': loss_list[3],
    }
    return [loss, logits]