openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,439 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from openrec.modeling.common import Mlp
9
+
10
+
11
class NRTRDecoder(nn.Module):
    """Transformer decoder head for NRTR-style scene-text recognition.

    The architecture follows "Attention Is All You Need" (Vaswani et al.,
    NeurIPS 2017): an optional stack of self-attention encoder blocks over
    the visual features, followed by a stack of decoder blocks with causal
    self-attention and cross-attention into the encoder memory.

    Args:
        in_channels: width of the incoming feature sequence; reused as d_model.
        out_channels: vocabulary size including the special tokens.
        nhead: number of attention heads; defaults to ``d_model // 32``.
        num_encoder_layers: number of encoder blocks; 0 disables the encoder.
        beam_size: stored but unused here — decoding below is greedy.
        num_decoder_layers: number of decoder blocks.
        max_len: maximum number of decoding steps at inference time.
        attention_dropout_rate: dropout applied inside attention.
        residual_dropout_rate: dropout on residual branches and embeddings.
        scale_embedding: multiply token embeddings by sqrt(d_model) if True.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_encoder_layers=6,
        beam_size=0,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        super(NRTRDecoder, self).__init__()
        self.out_channels = out_channels
        # Special-token layout: padding is the last class id, BOS the one
        # before it, and EOS is id 0 (shared with the embedding padding_idx).
        self.ignore_index = out_channels - 1
        self.bos = out_channels - 2
        self.eos = 0
        self.max_len = max_len
        d_model = in_channels
        dim_feedforward = d_model * 4
        nhead = nhead if nhead is not None else d_model // 32
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding,
        )
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate, dim=d_model)

        if num_encoder_layers > 0:
            # Encoder blocks: self-attention only, no cross-attention.
            self.encoder = nn.ModuleList([
                TransformerBlock(
                    d_model,
                    nhead,
                    dim_feedforward,
                    attention_dropout_rate,
                    residual_dropout_rate,
                    with_self_attn=True,
                    with_cross_attn=False,
                ) for i in range(num_encoder_layers)
            ])
        else:
            self.encoder = None

        # Decoder blocks: causal self-attention plus cross-attention on memory.
        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model,
                nhead,
                dim_feedforward,
                attention_dropout_rate,
                residual_dropout_rate,
                with_self_attn=True,
                with_cross_attn=True,
            ) for i in range(num_decoder_layers)
        ])

        self.beam_size = beam_size
        self.d_model = d_model
        self.nhead = nhead
        # Output projection over out_channels - 2 classes: BOS and padding
        # are never predicted.
        self.tgt_word_prj = nn.Linear(d_model,
                                      self.out_channels - 2,
                                      bias=False)
        w0 = np.random.normal(0.0, d_model**-0.5,
                              (d_model, self.out_channels - 2)).astype(
                                  np.float32)
        self.tgt_word_prj.weight.data = torch.from_numpy(w0.transpose())
        # NOTE(review): apply() below runs _init_weights over every Linear,
        # so the normal init of tgt_word_prj above is immediately overwritten
        # by xavier_normal_ — presumably intentional; confirm.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Xavier init for all Linear layers in the module tree.
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward_train(self, src, tgt):
        """Teacher-forced decoding.

        Args:
            src: encoder features, shape (B, sN, C).
            tgt: token ids including BOS, shape (B, tN).

        Returns:
            Logits of shape (B, tN - 1, out_channels - 2).
        """
        tgt = tgt[:, :-1]  # inputs exclude the final position (shift-right)

        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)
        # NOTE(review): Tensor.get_device() returns -1 for CPU tensors, so
        # this path appears to assume CUDA inputs — confirm before CPU use.
        tgt_mask = self.generate_square_subsequent_mask(
            tgt.shape[1], device=src.get_device())

        if self.encoder is not None:
            src = self.positional_encoding(src)
            for encoder_layer in self.encoder:
                src = encoder_layer(src)
            memory = src  # B N C
        else:
            memory = src  # B N C
        for decoder_layer in self.decoder:
            tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
        output = tgt
        logit = self.tgt_word_prj(output)
        return logit

    def forward(self, src, data=None):
        """Dispatch to teacher-forced training or greedy inference.

        Args:
            src: encoder feature sequence, shape (B, sN, C).
            data: training-only; data[0] holds target token ids and data[1]
                the per-sample label lengths (used to trim padding).

        Returns:
            Training: logits (B, tN - 1, out_channels - 2).
            Inference: per-step softmax probabilities (B, steps, vocab - 2).
        """

        if self.training:
            max_len = data[1].max()
            # Keep BOS + max_len tokens + one terminator position.
            tgt = data[0][:, :2 + max_len]
            res = self.forward_train(src, tgt)
        else:
            res = self.forward_test(src)
        return res

    def forward_test(self, src):
        """Greedy autoregressive decoding for up to self.max_len steps."""
        bs = src.shape[0]
        if self.encoder is not None:
            src = self.positional_encoding(src)
            for encoder_layer in self.encoder:
                src = encoder_layer(src)
            memory = src  # B N C
        else:
            memory = src
        # Sequence buffer pre-filled with padding; position 0 is BOS.
        dec_seq = torch.full((bs, self.max_len + 1),
                             self.ignore_index,
                             dtype=torch.int64,
                             device=src.get_device())
        dec_seq[:, 0] = self.bos
        logits = []
        # Cross-attention maps of the last decoder layer, kept for
        # visualization/debugging.
        self.attn_maps = []
        for len_dec_seq in range(0, self.max_len):
            # Re-embed the prefix decoded so far (no KV cache).
            dec_seq_embed = self.embedding(
                dec_seq[:, :len_dec_seq + 1])
            dec_seq_embed = self.positional_encoding(dec_seq_embed)
            tgt_mask = self.generate_square_subsequent_mask(
                dec_seq_embed.shape[1], src.get_device())
            tgt = dec_seq_embed  # (B, len_dec_seq + 1, C)
            for decoder_layer in self.decoder:
                tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
            self.attn_maps.append(
                self.decoder[-1].cross_attn.attn_map[0][:, -1:, :])
            dec_output = tgt
            # Only the newest position is needed for the next prediction.
            dec_output = dec_output[:, -1:, :]

            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            if len_dec_seq < self.max_len:
                # Greedy decode: append the argmax token to the prefix.
                # NOTE(review): squeeze() drops the batch dim when bs == 1 —
                # looks fragile for single-sample batches; confirm.
                dec_seq[:, len_dec_seq + 1] = word_prob.squeeze().argmax(-1)
                # Early exit once every sequence contains at least one EOS.
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
        logits = torch.cat(logits, dim=1)
        return logits

    def generate_square_subsequent_mask(self, sz, device):
        """Generate a causal (upper-triangular) attention mask.

        Masked (future) positions hold float('-inf'); allowed positions hold
        0.0. Returned shape is (1, 1, sz, sz) for broadcasting over
        batch and heads.
        """
        mask = torch.zeros([sz, sz], dtype=torch.float32)
        mask_inf = torch.triu(
            torch.full((sz, sz), dtype=torch.float32, fill_value=-torch.inf),
            diagonal=1,
        )
        mask = mask + mask_inf
        return mask.unsqueeze(0).unsqueeze(0).to(device)
201
+
202
+
203
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    When ``self_attn`` is True a single fused QKV projection over the query
    is used; otherwise Q comes from the query and K/V from ``key`` (cross
    attention). In eval mode the post-softmax attention weights are kept on
    ``self.attn_map`` for inspection.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, self_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim
                ), 'embed_dim must be divisible by num_heads'
        # 1/sqrt(head_dim) scaling applied to the attention scores.
        self.scale = self.head_dim**-0.5
        self.self_attn = self_attn
        if self_attn:
            self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        else:
            self.q = nn.Linear(embed_dim, embed_dim)
            self.kv = nn.Linear(embed_dim, embed_dim * 2)
        self.attn_drop = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key=None, attn_mask=None):
        """Attend over `key` (or `query` itself) and return (B, qN, C).

        Args:
            query: (B, qN, C) query sequence.
            key: (B, kN, C) memory; required only for cross attention.
            attn_mask: additive mask broadcastable to (B, heads, qN, kN).
        """
        batch, q_len = query.shape[:2]
        heads, hdim = self.num_heads, self.head_dim

        if self.self_attn:
            # Fused projection -> (3, B, heads, qN, hdim), then split.
            fused = self.qkv(query).reshape(batch, q_len, 3, heads, hdim)
            q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        else:
            k_len = key.shape[1]
            q = (self.q(query)
                 .reshape(batch, q_len, heads, hdim)
                 .transpose(1, 2))
            kv_proj = (self.kv(key)
                       .reshape(batch, k_len, 2, heads, hdim)
                       .permute(2, 0, 3, 1, 4))
            k, v = kv_proj.unbind(0)

        scores = q.matmul(k.transpose(2, 3)) * self.scale
        if attn_mask is not None:
            scores = scores + attn_mask

        weights = F.softmax(scores, dim=-1)
        if not self.training:
            # Saved before dropout so visualizations see the true weights.
            self.attn_map = weights
        weights = self.attn_drop(weights)

        context = weights.matmul(v).transpose(1, 2)
        context = context.reshape(batch, q_len, self.embed_dim)
        return self.out_proj(context)
253
+
254
+
255
class TransformerBlock(nn.Module):
    """Post-norm transformer layer.

    Composes up to three residual sub-layers: optional self-attention,
    optional cross-attention into ``memory``, and an MLP. With
    ``with_cross_attn=False`` it behaves as an encoder layer; with both
    attention flags True it behaves as a decoder layer.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        with_self_attn=True,
        with_cross_attn=False,
        epsilon=1e-5,
    ):
        super(TransformerBlock, self).__init__()
        self.with_self_attn = with_self_attn
        if with_self_attn:
            self.self_attn = MultiheadAttention(d_model,
                                                nhead,
                                                dropout=attention_dropout_rate,
                                                self_attn=with_self_attn)
            self.norm1 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout1 = nn.Dropout(residual_dropout_rate)
        self.with_cross_attn = with_cross_attn
        if with_cross_attn:
            # Query/key split attention for attending into encoder memory.
            self.cross_attn = MultiheadAttention(
                d_model, nhead, dropout=attention_dropout_rate)
            self.norm2 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout2 = nn.Dropout(residual_dropout_rate)

        # Position-wise feed-forward network (ReLU, d_model -> 4*d_model).
        self.mlp = Mlp(
            in_features=d_model,
            hidden_features=dim_feedforward,
            act_layer=nn.ReLU,
            drop=residual_dropout_rate,
        )

        self.norm3 = nn.LayerNorm(d_model, eps=epsilon)

        self.dropout3 = nn.Dropout(residual_dropout_rate)

    def forward(self, tgt, memory=None, self_mask=None, cross_mask=None):
        """Run the enabled sub-layers over `tgt` (B, N, C)."""
        x = tgt
        if self.with_self_attn:
            attended = self.self_attn(x, attn_mask=self_mask)
            x = self.norm1(x + self.dropout1(attended))

        if self.with_cross_attn:
            crossed = self.cross_attn(x, key=memory, attn_mask=cross_mask)
            x = self.norm2(x + self.dropout2(crossed))
        return self.norm3(x + self.dropout3(self.mlp(x)))
306
+
307
+
308
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding (Vaswani et al., 2017).

    A fixed table of sine/cosine values — sin on even channels, cos on odd
    channels, with geometrically spaced frequencies — is added to the input
    embeddings, followed by dropout.

    Args:
        dropout: dropout probability applied after adding the encoding.
        dim: embedding dimension (must match the input's last axis).
        max_len: longest sequence length supported by the table.

    Examples:
        >>> pos_encoder = PositionalEncoding(0.1, 512)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        # Inverse frequencies: 10000^(-2i/dim) for each even channel index i.
        inv_freq = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        table = torch.zeros([max_len, dim])
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        # Shape (1, max_len, dim) so it broadcasts over the batch axis.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add positional encodings to `x`.

        Args:
            x: tensor of shape (batch, seq_len, dim).

        Returns:
            Tensor of the same shape with encodings added and dropout applied.
        """
        return self.dropout(x + self.pe[:, :x.shape[1], :])
354
+
355
+
356
class PositionalEncoding_2d(nn.Module):
    """2-D positional encoding gated by global pooled features.

    A shared 1-D sinusoidal table provides row and column encodings; each is
    modulated by a learned linear gate computed from the globally
    average-pooled input, then both are broadcast-added over the feature map.
    The result is flattened to a sequence.

    Args:
        dropout: dropout probability applied to the output sequence.
        dim: channel/embedding dimension (must equal the input's C axis).
        max_len: longest height/width supported by the sinusoidal table.
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding_2d, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros([max_len, dim])
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        # Inverse frequencies 10000^(-2i/dim), as in the 1-D encoding.
        div_term = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Stored as (max_len, 1, dim) so a position slice can broadcast
        # against the (1, B, dim) gate tensors computed in forward().
        pe = torch.permute(torch.unsqueeze(pe, 0), [1, 0, 2])
        self.register_buffer('pe', pe)

        # Content-dependent gates for the width/height encodings; weights
        # start at 1.0 so the gates initially pass the pooled features.
        self.avg_pool_1 = nn.AdaptiveAvgPool2d((1, 1))
        self.linear1 = nn.Linear(dim, dim)
        self.linear1.weight.data.fill_(1.0)
        self.avg_pool_2 = nn.AdaptiveAvgPool2d((1, 1))
        self.linear2 = nn.Linear(dim, dim)
        self.linear2.weight.data.fill_(1.0)

    def forward(self, x):
        """Add gated 2-D positional encodings and flatten to a sequence.

        NOTE(review): the indexing below treats x as a (B, C, H, W) feature
        map (C == dim), not the (seq, batch, dim) layout the 1-D sibling
        class documents — confirm against callers.

        Returns:
            Tensor of shape (H*W, B, C), with dropout applied.
        """
        # Column (width) encoding, gated by pooled features:
        # (W, 1, dim) * (1, B, dim) -> (W, B, dim) -> (B, dim, 1, W).
        w_pe = self.pe[:x.shape[-1], :]
        # NOTE(review): squeeze() drops the batch axis when B == 1; the
        # subsequent broadcasts still line up, but verify for that case.
        w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
        w_pe = w_pe * w1
        w_pe = torch.permute(w_pe, [1, 2, 0])
        w_pe = torch.unsqueeze(w_pe, 2)

        # Row (height) encoding, same construction -> (B, dim, H, 1).
        h_pe = self.pe[:x.shape[-2], :]
        w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
        h_pe = h_pe * w2
        h_pe = torch.permute(h_pe, [1, 2, 0])
        h_pe = torch.unsqueeze(h_pe, 3)

        # Broadcast-add both encodings, then flatten (H, W) -> H*W and move
        # the sequence axis first: (B, C, H, W) -> (H*W, B, C).
        x = x + w_pe + h_pe
        x = torch.permute(
            torch.reshape(x,
                          [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
            [2, 0, 1],
        )

        return self.dropout(x)
424
+
425
+
426
+ class Embeddings(nn.Module):
427
+
428
+ def __init__(self, d_model, vocab, padding_idx=None, scale_embedding=True):
429
+ super(Embeddings, self).__init__()
430
+ self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
431
+ self.embedding.weight.data.normal_(mean=0.0, std=d_model**-0.5)
432
+ self.d_model = d_model
433
+ self.scale_embedding = scale_embedding
434
+
435
+ def forward(self, x):
436
+ if self.scale_embedding:
437
+ x = self.embedding(x)
438
+ return x * math.sqrt(self.d_model)
439
+ return self.embedding(x)
@@ -0,0 +1,205 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from torch.nn.init import ones_, trunc_normal_, zeros_
5
+
6
+ from .nrtr_decoder import TransformerBlock, Embeddings
7
+
8
+
9
class CPA(nn.Module):
    """Character Position Attention.

    Pools the input sequence into one global feature, then expands it into
    ``max_len + 1`` position-aware character queries via a learned
    positional embedding and a soft per-channel attention.
    """

    def __init__(self, dim, max_len=25):
        super(CPA, self).__init__()

        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.fc3 = nn.Linear(dim, dim)
        # One learned query per output position (max_len chars + eos).
        self.pos_embed = nn.Parameter(torch.zeros([1, max_len + 1, dim],
                                                  dtype=torch.float32),
                                      requires_grad=True)
        trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, feat):
        # feat: B, L, Dim
        feat = feat.mean(1).unsqueeze(1)  # global feature: B, 1, Dim
        x = self.fc1(feat) + self.pos_embed  # B, max_len+1, dim
        # torch.tanh replaces the deprecated F.tanh alias (same function,
        # no DeprecationWarning).
        x = F.softmax(self.fc2(torch.tanh(x)), -1)  # channel attention
        x = self.fc3(feat * x)  # B, max_len+1, dim
        return x
29
+
30
+
31
class ARDecoder(nn.Module):
    """Autoregressive Transformer decoder head.

    Runs teacher-forced training (``forward_train``) and greedy decoding
    at inference time (``forward_test``). Vocabulary layout: index 0 is
    <eos>/<pad>, ``out_channels - 2`` is <bos>, ``out_channels - 1`` is
    the ignore index.

    Args:
        in_channels: feature dim of ``src`` (used as ``d_model``).
        out_channels: vocabulary size including <bos> and ignore index.
        nhead: attention heads; defaults to one head per 32 channels.
        num_decoder_layers: number of self-attention decoder blocks.
        max_len: maximum decoded sequence length (excluding <bos>).
        attention_dropout_rate / residual_dropout_rate: dropout rates.
        scale_embedding: scale token embeddings by sqrt(d_model).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        super(ARDecoder, self).__init__()
        self.out_channels = out_channels
        self.ignore_index = out_channels - 1
        self.bos = out_channels - 2
        self.eos = 0
        self.max_len = max_len
        d_model = in_channels
        dim_feedforward = d_model * 4
        # Default: one attention head per 32 channels.
        nhead = nhead if nhead is not None else d_model // 32
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding,
        )
        self.pos_embed = nn.Parameter(torch.zeros([1, max_len + 1, d_model],
                                                  dtype=torch.float32),
                                      requires_grad=True)
        trunc_normal_(self.pos_embed, std=0.02)
        # Self-attention-only blocks (no cross attention): the visual
        # features are added to the token embeddings instead.
        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model,
                nhead,
                dim_feedforward,
                attention_dropout_rate,
                residual_dropout_rate,
                with_self_attn=True,
                with_cross_attn=False,
            ) for i in range(num_decoder_layers)
        ])

        # <bos> and the ignore index are never predicted, hence "- 2".
        self.tgt_word_prj = nn.Linear(d_model,
                                      self.out_channels - 2,
                                      bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward_train(self, src, tgt):
        # Drop the last target token: position t predicts token t+1.
        tgt = tgt[:, :-1]

        tgt = self.embedding(
            tgt) + src[:, :tgt.shape[1]] + self.pos_embed[:, :tgt.shape[1]]
        # BUGFIX: use src.device instead of src.get_device() —
        # get_device() returns -1 for CPU tensors, which is not a valid
        # device and broke CPU execution.
        tgt_mask = self.generate_square_subsequent_mask(
            tgt.shape[1], device=src.device)

        for decoder_layer in self.decoder:
            tgt = decoder_layer(tgt, self_mask=tgt_mask)
        output = tgt
        logit = self.tgt_word_prj(output)
        return logit

    def forward(self, src, data=None):
        """src: visual queries (B, max_len+1, dim); data: [labels, lengths]."""
        if self.training:
            max_len = data[1].max()
            tgt = data[0][:, :2 + max_len]
            res = self.forward_train(src, tgt)
        else:
            res = self.forward_test(src)
        return res

    def forward_test(self, src):
        bs = src.shape[0]
        src = src + self.pos_embed
        # Fill with the ignore index, then seed position 0 with <bos>.
        dec_seq = torch.full((bs, self.max_len + 1),
                             self.ignore_index,
                             dtype=torch.int64,
                             device=src.device)  # BUGFIX: was get_device()
        dec_seq[:, 0] = self.bos
        logits = []
        for len_dec_seq in range(0, self.max_len):
            dec_seq_embed = self.embedding(
                dec_seq[:, :len_dec_seq + 1])  # N dim 26+10 # </s> 012 a
            dec_seq_embed = dec_seq_embed + src[:, :len_dec_seq + 1]
            # BUGFIX: src.device instead of src.get_device() (see above).
            tgt_mask = self.generate_square_subsequent_mask(
                dec_seq_embed.shape[1], src.device)
            tgt = dec_seq_embed  # bs, 3, dim #bos, a, b, c, ... eos
            for decoder_layer in self.decoder:
                tgt = decoder_layer(tgt, self_mask=tgt_mask)
            dec_output = tgt
            dec_output = dec_output[:, -1:, :]
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            if len_dec_seq < self.max_len:
                # greedy decode. add the next token index to the target input
                dec_seq[:, len_dec_seq + 1] = word_prob.squeeze(1).argmax(-1)
            # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
            if (dec_seq == self.eos).any(dim=-1).all():
                break
        logits = torch.cat(logits, dim=1)
        return logits

    def generate_square_subsequent_mask(self, sz, device):
        """Generate a square causal mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked
        positions are filled with float(0.0). Returned with two leading
        broadcast dims: [1, 1, sz, sz].
        """
        mask = torch.zeros([sz, sz], dtype=torch.float32)
        mask_inf = torch.triu(
            torch.full((sz, sz), dtype=torch.float32, fill_value=-torch.inf),
            diagonal=1,
        )
        mask = mask + mask_inf
        return mask.unsqueeze(0).unsqueeze(0).to(device)
154
+
155
+
156
class OTEDecoder(nn.Module):
    """OTE recognition head.

    Aggregates backbone features into per-position character queries with
    CPA, then classifies them either in parallel with a single linear
    layer (``ar=False``) or autoregressively with ``ARDecoder``
    (``ar=True``).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 max_len=25,
                 num_heads=None,
                 ar=False,
                 num_decoder_layers=1,
                 **kwargs):
        super(OTEDecoder, self).__init__()

        # <bos> and the ignore index are excluded from the classifier
        # output, hence "- 2".  # none + 26 + 10
        self.out_channels = out_channels - 2
        dim = in_channels
        self.dim = dim
        self.max_len = max_len + 1  # max_len + eos

        self.cpa = CPA(dim=dim, max_len=max_len)
        self.ar = ar
        if ar:
            self.ar_decoder = ARDecoder(in_channels=dim,
                                        out_channels=out_channels,
                                        nhead=num_heads,
                                        num_decoder_layers=num_decoder_layers,
                                        max_len=max_len)
        else:
            self.fc = nn.Linear(dim, self.out_channels)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Standard ViT-style init: truncated normal for linear weights,
        # zeros/ones for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed'}

    def forward(self, x, data=None):
        queries = self.cpa(x)
        if self.ar:
            return self.ar_decoder(queries, data=data)
        logits = self.fc(queries)  # B, 26, 37
        if self.training:
            # Truncate to the longest label in the batch (+1 for eos).
            logits = logits[:, :data[1].max() + 1]
        return logits