openocr-python 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +11 -0
- openocr/configs/det/dbnet/repvit_db.yml +173 -0
- openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
- openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
- openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
- openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
- openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
- openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
- openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
- openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
- openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
- openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
- openocr/configs/rec/busnet/vit_busnet.yml +104 -0
- openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
- openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
- openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
- openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
- openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
- openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
- openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
- openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
- openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
- openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
- openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
- openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
- openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
- openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
- openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
- openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
- openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
- openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
- openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
- openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
- openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
- openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
- openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
- openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
- openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
- openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
- openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
- openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
- openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
- openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
- openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
- openocr/configs/rec/nrtr/nrtr.yml +107 -0
- openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
- openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
- openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
- openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
- openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
- openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
- openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
- openocr/configs/rec/parseq/vit_parseq.yml +100 -0
- openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
- openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
- openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
- openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
- openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
- openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
- openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
- openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
- openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
- openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
- openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
- openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
- openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
- openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
- openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
- openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
- openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
- openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
- openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
- openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
- openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
- openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
- openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
- openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
- openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
- openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
- openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
- openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
- openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
- openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
- openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
- openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
- openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
- openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
- openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
- openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
- openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
- openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
- openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
- openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
- openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
- openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
- openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
- openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
- openocr/demo_gradio.py +128 -0
- openocr/opendet/modeling/__init__.py +11 -0
- openocr/opendet/modeling/backbones/__init__.py +14 -0
- openocr/opendet/modeling/backbones/repvit.py +340 -0
- openocr/opendet/modeling/base_detector.py +69 -0
- openocr/opendet/modeling/heads/__init__.py +14 -0
- openocr/opendet/modeling/heads/db_head.py +73 -0
- openocr/opendet/modeling/necks/__init__.py +14 -0
- openocr/opendet/modeling/necks/db_fpn.py +609 -0
- openocr/opendet/postprocess/__init__.py +18 -0
- openocr/opendet/postprocess/db_postprocess.py +273 -0
- openocr/opendet/preprocess/__init__.py +154 -0
- openocr/opendet/preprocess/crop_resize.py +121 -0
- openocr/opendet/preprocess/db_resize_for_test.py +135 -0
- openocr/openrec/losses/__init__.py +62 -0
- openocr/openrec/losses/abinet_loss.py +42 -0
- openocr/openrec/losses/ar_loss.py +23 -0
- openocr/openrec/losses/cam_loss.py +48 -0
- openocr/openrec/losses/cdistnet_loss.py +34 -0
- openocr/openrec/losses/ce_loss.py +68 -0
- openocr/openrec/losses/cppd_loss.py +77 -0
- openocr/openrec/losses/ctc_loss.py +33 -0
- openocr/openrec/losses/igtr_loss.py +12 -0
- openocr/openrec/losses/lister_loss.py +14 -0
- openocr/openrec/losses/lpv_loss.py +30 -0
- openocr/openrec/losses/mgp_loss.py +34 -0
- openocr/openrec/losses/parseq_loss.py +12 -0
- openocr/openrec/losses/robustscanner_loss.py +20 -0
- openocr/openrec/losses/seed_loss.py +46 -0
- openocr/openrec/losses/smtr_loss.py +12 -0
- openocr/openrec/losses/srn_loss.py +40 -0
- openocr/openrec/losses/visionlan_loss.py +58 -0
- openocr/openrec/metrics/__init__.py +19 -0
- openocr/openrec/metrics/rec_metric.py +270 -0
- openocr/openrec/metrics/rec_metric_gtc.py +58 -0
- openocr/openrec/metrics/rec_metric_long.py +142 -0
- openocr/openrec/metrics/rec_metric_mgp.py +93 -0
- openocr/openrec/modeling/__init__.py +11 -0
- openocr/openrec/modeling/base_recognizer.py +69 -0
- openocr/openrec/modeling/common.py +238 -0
- openocr/openrec/modeling/decoders/__init__.py +109 -0
- openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
- openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
- openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
- openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
- openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
- openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
- openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
- openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
- openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
- openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
- openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
- openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
- openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
- openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
- openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
- openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
- openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
- openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
- openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
- openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
- openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
- openocr/openrec/modeling/encoders/__init__.py +39 -0
- openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
- openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
- openocr/openrec/modeling/encoders/convnextv2.py +213 -0
- openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
- openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
- openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
- openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
- openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
- openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
- openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
- openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
- openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
- openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
- openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
- openocr/openrec/modeling/encoders/repvit.py +338 -0
- openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
- openocr/openrec/modeling/encoders/svtrnet.py +574 -0
- openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
- openocr/openrec/modeling/encoders/svtrv2.py +470 -0
- openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
- openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
- openocr/openrec/modeling/encoders/vit.py +120 -0
- openocr/openrec/modeling/transforms/__init__.py +15 -0
- openocr/openrec/modeling/transforms/aster_tps.py +262 -0
- openocr/openrec/modeling/transforms/moran.py +136 -0
- openocr/openrec/modeling/transforms/tps.py +246 -0
- openocr/openrec/optimizer/__init__.py +73 -0
- openocr/openrec/optimizer/lr.py +227 -0
- openocr/openrec/postprocess/__init__.py +72 -0
- openocr/openrec/postprocess/abinet_postprocess.py +37 -0
- openocr/openrec/postprocess/ar_postprocess.py +63 -0
- openocr/openrec/postprocess/ce_postprocess.py +43 -0
- openocr/openrec/postprocess/char_postprocess.py +108 -0
- openocr/openrec/postprocess/cppd_postprocess.py +42 -0
- openocr/openrec/postprocess/ctc_postprocess.py +119 -0
- openocr/openrec/postprocess/igtr_postprocess.py +100 -0
- openocr/openrec/postprocess/lister_postprocess.py +59 -0
- openocr/openrec/postprocess/mgp_postprocess.py +143 -0
- openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
- openocr/openrec/postprocess/smtr_postprocess.py +73 -0
- openocr/openrec/postprocess/srn_postprocess.py +80 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
- openocr/openrec/preprocess/__init__.py +173 -0
- openocr/openrec/preprocess/abinet_aug.py +473 -0
- openocr/openrec/preprocess/abinet_label_encode.py +36 -0
- openocr/openrec/preprocess/ar_label_encode.py +36 -0
- openocr/openrec/preprocess/auto_augment.py +1012 -0
- openocr/openrec/preprocess/cam_label_encode.py +141 -0
- openocr/openrec/preprocess/ce_label_encode.py +116 -0
- openocr/openrec/preprocess/char_label_encode.py +36 -0
- openocr/openrec/preprocess/cppd_label_encode.py +173 -0
- openocr/openrec/preprocess/ctc_label_encode.py +124 -0
- openocr/openrec/preprocess/ep_label_encode.py +38 -0
- openocr/openrec/preprocess/igtr_label_encode.py +360 -0
- openocr/openrec/preprocess/mgp_label_encode.py +95 -0
- openocr/openrec/preprocess/parseq_aug.py +150 -0
- openocr/openrec/preprocess/rec_aug.py +211 -0
- openocr/openrec/preprocess/resize.py +534 -0
- openocr/openrec/preprocess/smtr_label_encode.py +125 -0
- openocr/openrec/preprocess/srn_label_encode.py +37 -0
- openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
- openocr/tools/create_lmdb_dataset.py +118 -0
- openocr/tools/data/__init__.py +94 -0
- openocr/tools/data/collate_fn.py +100 -0
- openocr/tools/data/lmdb_dataset.py +142 -0
- openocr/tools/data/lmdb_dataset_test.py +166 -0
- openocr/tools/data/multi_scale_sampler.py +177 -0
- openocr/tools/data/ratio_dataset.py +217 -0
- openocr/tools/data/ratio_dataset_test.py +273 -0
- openocr/tools/data/ratio_dataset_tvresize.py +213 -0
- openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
- openocr/tools/data/ratio_sampler.py +190 -0
- openocr/tools/data/simple_dataset.py +263 -0
- openocr/tools/data/strlmdb_dataset.py +143 -0
- openocr/tools/engine/__init__.py +5 -0
- openocr/tools/engine/config.py +158 -0
- openocr/tools/engine/trainer.py +621 -0
- openocr/tools/eval_rec.py +41 -0
- openocr/tools/eval_rec_all_ch.py +184 -0
- openocr/tools/eval_rec_all_en.py +206 -0
- openocr/tools/eval_rec_all_long.py +119 -0
- openocr/tools/eval_rec_all_long_simple.py +122 -0
- openocr/tools/export_rec.py +118 -0
- openocr/tools/infer/onnx_engine.py +65 -0
- openocr/tools/infer/predict_rec.py +140 -0
- openocr/tools/infer/utility.py +234 -0
- openocr/tools/infer_det.py +449 -0
- openocr/tools/infer_e2e.py +462 -0
- openocr/tools/infer_e2e_parallel.py +184 -0
- openocr/tools/infer_rec.py +371 -0
- openocr/tools/train_rec.py +37 -0
- openocr/tools/utility.py +45 -0
- openocr/tools/utils/EN_symbol_dict.txt +94 -0
- openocr/tools/utils/__init__.py +0 -0
- openocr/tools/utils/ckpt.py +87 -0
- openocr/tools/utils/dict/ar_dict.txt +117 -0
- openocr/tools/utils/dict/arabic_dict.txt +161 -0
- openocr/tools/utils/dict/be_dict.txt +145 -0
- openocr/tools/utils/dict/bg_dict.txt +140 -0
- openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
- openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
- openocr/tools/utils/dict/devanagari_dict.txt +167 -0
- openocr/tools/utils/dict/en_dict.txt +63 -0
- openocr/tools/utils/dict/fa_dict.txt +136 -0
- openocr/tools/utils/dict/french_dict.txt +136 -0
- openocr/tools/utils/dict/german_dict.txt +143 -0
- openocr/tools/utils/dict/hi_dict.txt +162 -0
- openocr/tools/utils/dict/it_dict.txt +118 -0
- openocr/tools/utils/dict/japan_dict.txt +4399 -0
- openocr/tools/utils/dict/ka_dict.txt +153 -0
- openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
- openocr/tools/utils/dict/korean_dict.txt +3688 -0
- openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
- openocr/tools/utils/dict/latin_dict.txt +185 -0
- openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
- openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
- openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
- openocr/tools/utils/dict/mr_dict.txt +153 -0
- openocr/tools/utils/dict/ne_dict.txt +153 -0
- openocr/tools/utils/dict/oc_dict.txt +96 -0
- openocr/tools/utils/dict/pu_dict.txt +130 -0
- openocr/tools/utils/dict/rs_dict.txt +91 -0
- openocr/tools/utils/dict/rsc_dict.txt +134 -0
- openocr/tools/utils/dict/ru_dict.txt +125 -0
- openocr/tools/utils/dict/spin_dict.txt +68 -0
- openocr/tools/utils/dict/ta_dict.txt +128 -0
- openocr/tools/utils/dict/table_dict.txt +277 -0
- openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
- openocr/tools/utils/dict/table_structure_dict.txt +28 -0
- openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
- openocr/tools/utils/dict/te_dict.txt +151 -0
- openocr/tools/utils/dict/ug_dict.txt +114 -0
- openocr/tools/utils/dict/uk_dict.txt +142 -0
- openocr/tools/utils/dict/ur_dict.txt +137 -0
- openocr/tools/utils/dict/xi_dict.txt +110 -0
- openocr/tools/utils/dict90.txt +90 -0
- openocr/tools/utils/e2e_metric/Deteval.py +802 -0
- openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
- openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
- openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
- openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
- openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
- openocr/tools/utils/e2e_utils/visual.py +152 -0
- openocr/tools/utils/en_dict.txt +95 -0
- openocr/tools/utils/gen_label.py +68 -0
- openocr/tools/utils/ic15_dict.txt +36 -0
- openocr/tools/utils/logging.py +56 -0
- openocr/tools/utils/poly_nms.py +132 -0
- openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
- openocr/tools/utils/stats.py +58 -0
- openocr/tools/utils/utility.py +165 -0
- openocr/tools/utils/visual.py +117 -0
- openocr_python-0.0.2.dist-info/LICENCE +201 -0
- openocr_python-0.0.2.dist-info/METADATA +98 -0
- openocr_python-0.0.2.dist-info/RECORD +323 -0
- openocr_python-0.0.2.dist-info/WHEEL +5 -0
- openocr_python-0.0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,749 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RobustScannerDecoder(nn.Module):
    """Decoder wrapper for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Reduces the backbone feature channels with a 1x1-conv encoder
    (``ChannelReductionEncoder``) and feeds both the raw feature and the
    reduced feature to the hybrid/position ``Decoder`` defined elsewhere in
    this module.

    Args:
        out_channels (int): Size of the output symbol set, laid out as
            90 characters + unknown + start + padding (start index is
            ``out_channels - 2``, padding is ``out_channels - 1``, end is 0).
        in_channels (int): Channels of the incoming backbone feature map.
        enc_outchannles (int): Channels after the 1x1 reduction encoder.
        hybrid_dec_rnn_layers (int): RNN layers of the hybrid decoder.
        hybrid_dec_dropout (float): Dropout rate of the hybrid decoder RNN.
        position_dec_rnn_layers (int): RNN layers of the position decoder.
        max_len (int): Maximum text length; one extra step is reserved
            (``max_text_length = max_len + 1``).
        mask (bool): If True, per-sample valid ratios are read from
            ``data[-1]`` and used to mask padded widths during attention.
        encode_value (bool): Use the encoder output (instead of the raw
            feature) as the attention value.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            in_channels,
            enc_outchannles=128,
            hybrid_dec_rnn_layers=2,
            hybrid_dec_dropout=0,
            position_dec_rnn_layers=2,
            max_len=25,
            mask=True,
            encode_value=False,
            **kwargs):
        super(RobustScannerDecoder, self).__init__()

        # Special-token index layout derived from the symbol-set size.
        start_idx = out_channels - 2
        padding_idx = out_channels - 1
        end_idx = 0
        # encoder module
        self.encoder = ChannelReductionEncoder(in_channels=in_channels,
                                               out_channels=enc_outchannles)
        self.max_text_length = max_len + 1
        self.mask = mask
        # decoder module (``Decoder`` is defined later in this file)
        self.decoder = Decoder(
            num_classes=out_channels,
            dim_input=in_channels,
            dim_model=enc_outchannles,
            hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers,
            hybrid_decoder_dropout=hybrid_dec_dropout,
            position_decoder_rnn_layers=position_dec_rnn_layers,
            max_len=max_len + 1,
            start_idx=start_idx,
            mask=mask,
            padding_idx=padding_idx,
            end_idx=end_idx,
            encode_value=encode_value)

    def forward(self, inputs, data=None):
        '''Run channel reduction and then the hybrid/position decoder.

        ``data`` is indexed below as: ``data[0]`` = labels,
        ``data[1]`` = label lengths, ``data[-1]`` = valid width ratios.
        (The original comment listed a different order; the indexing in
        this method is authoritative.)
        '''
        out_enc = self.encoder(inputs)
        bs = out_enc.shape[0]
        valid_ratios = None
        # One position index per decoding step, broadcast over the batch.
        word_positions = torch.arange(0,
                                      self.max_text_length,
                                      device=inputs.device).unsqueeze(0).tile(
                                          [bs, 1])

        if self.mask:
            valid_ratios = data[-1]

        if self.training:
            # Teacher forcing: crop labels/positions to the longest label
            # in the batch (+1 for the end token).
            max_len = data[1].max()
            label = data[0][:, :1 + max_len]  # label
            final_out = self.decoder(inputs, out_enc, label, valid_ratios,
                                     word_positions[:, :1 + max_len])
        if not self.training:
            # Inference: autoregressive decoding over the full length.
            final_out = self.decoder(inputs,
                                     out_enc,
                                     label=None,
                                     valid_ratios=valid_ratios,
                                     word_positions=word_positions,
                                     train_mode=False)
        return final_out
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class BaseDecoder(nn.Module):
|
|
78
|
+
|
|
79
|
+
def __init__(self, **kwargs):
|
|
80
|
+
super().__init__()
|
|
81
|
+
|
|
82
|
+
def forward_train(self, feat, out_enc, targets, img_metas):
|
|
83
|
+
raise NotImplementedError
|
|
84
|
+
|
|
85
|
+
def forward_test(self, feat, out_enc, img_metas):
|
|
86
|
+
raise NotImplementedError
|
|
87
|
+
|
|
88
|
+
def forward(self,
|
|
89
|
+
feat,
|
|
90
|
+
out_enc,
|
|
91
|
+
label=None,
|
|
92
|
+
valid_ratios=None,
|
|
93
|
+
word_positions=None,
|
|
94
|
+
train_mode=True):
|
|
95
|
+
self.train_mode = train_mode
|
|
96
|
+
|
|
97
|
+
if train_mode:
|
|
98
|
+
return self.forward_train(feat, out_enc, label, valid_ratios,
|
|
99
|
+
word_positions)
|
|
100
|
+
return self.forward_test(feat, out_enc, valid_ratios, word_positions)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class ChannelReductionEncoder(nn.Module):
    """Change the channel number with a one by one convolutional layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(ChannelReductionEncoder, self).__init__()

        self.layer = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=1,
                               stride=1,
                               padding=0)
        # Xavier-normal init of the 1x1 conv weight. The original built a
        # separate Parameter from torch.empty behind a dead
        # ``use_xavier_normal = 1`` flag and then assigned it over the conv
        # weight; initializing in place is equivalent and removes the dead
        # code. The bias keeps its default initialization, as before.
        nn.init.xavier_normal_(self.layer.weight, gain=1.0)

    def forward(self, feat):
        """
        Args:
            feat (Tensor): Image features with the shape of
                :math:`(N, C_{in}, H, W)`.

        Returns:
            Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`.
        """
        return self.layer(feat)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def masked_fill(x, mask, value):
    """Return a copy of ``x`` with ``value`` where ``mask`` is True.

    Args:
        x (Tensor): Source tensor (not modified).
        mask (BoolTensor): Same (broadcastable) shape as ``x``.
        value: Fill value for masked positions.

    Returns:
        Tensor: ``x`` with masked positions replaced by ``value``.

    Bug fix: ``torch.full`` takes ``dtype`` as a keyword-only argument, so
    the original positional ``torch.full(x.shape, value, x.dtype)`` raised a
    TypeError. The device is also carried over so CUDA inputs work.
    """
    y = torch.full(x.shape, value, dtype=x.dtype, device=x.device)
    return torch.where(mask, y, x)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class DotProductAttentionLayer(nn.Module):
    """Scaled dot-product attention used by the RobustScanner decoders.

    Args:
        dim_model (int, optional): When given, scores are scaled by
            ``dim_model ** -0.5``; otherwise no scaling is applied.
    """

    def __init__(self, dim_model=None):
        super().__init__()

        # 1/sqrt(d) scaling when a model dimension is supplied.
        if dim_model is not None:
            self.scale = dim_model**-0.5
        else:
            self.scale = 1.

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value``.

        Shapes: query (N, C, Tq), key (N, C, Tk), value (N, Cv, Tk),
        mask (N, Tk) bool where True marks positions to ignore.
        Returns (N, Cv, Tq).
        """
        # (N, C, Tq) -> (N, Tq, C) so scores come out as (N, Tq, Tk).
        scores = torch.matmul(query.permute(0, 2, 1), key) * self.scale

        if mask is not None:
            batch, key_len = mask.size()
            # Broadcast the key mask across all query steps.
            scores = scores.masked_fill(mask.view(batch, 1, key_len),
                                        float('-inf'))

        attn = F.softmax(scores, dim=2)
        # (N, Tq, Tk) @ (N, Tk, Cv) -> (N, Tq, Cv)
        context = torch.matmul(attn, value.transpose(1, 2))
        # Back to channel-first layout: (N, Cv, Tq).
        return context.permute(0, 2, 1).contiguous()
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class SequenceAttentionDecoder(BaseDecoder):
    """Sequence attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        padding_idx (int): The index of `<PAD>`.
        dropout (float): Dropout rate.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 dropout=0,
                 return_feature=False,
                 encode_value=False):
        super().__init__()

        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.max_seq_len = max_seq_len
        self.start_idx = start_idx
        self.mask = mask

        # Token embedding; <PAD> rows stay zero via padding_idx.
        self.embedding = nn.Embedding(self.num_classes,
                                      self.dim_model,
                                      padding_idx=padding_idx)

        # LSTM over the embedded target sequence produces the queries.
        self.sequence_layer = nn.LSTM(input_size=dim_model,
                                      hidden_size=dim_model,
                                      num_layers=rnn_layers,
                                      batch_first=True,
                                      dropout=dropout)

        self.attention_layer = DotProductAttentionLayer()

        self.prediction = None
        if not self.return_feature:
            # C - 1 outputs: <PAD> is never predicted (see class Warning).
            pred_num_classes = num_classes - 1
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def forward_train(self, feat, out_enc, targets, valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets (Tensor): a tensor of shape :math:`(N, T)`. Each element is the index of a
                character.
            valid_ratios (Tensor): valid length ratio of img.
        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """

        tgt_embedding = self.embedding(targets)

        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, len_q, c_q = tgt_embedding.shape
        assert c_q == self.dim_model
        assert len_q <= self.max_seq_len

        query, _ = self.sequence_layer(tgt_embedding)

        # (N, T, D_m) -> (N, D_m, T), the layout the attention layer expects.
        query = query.permute(0, 2, 1).contiguous()

        # Flatten the spatial grid into the key/value sequence of length H*W.
        key = out_enc.view(n, c_enc, h * w)

        if self.encode_value:
            value = key
        else:
            value = feat.view(n, c_feat, h * w)

        mask = None
        if valid_ratios is not None:
            # True marks padded (invalid) width columns per sample.
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        # Back to (N, T, C) for the prediction head.
        attn_out = attn_out.permute(0, 2, 1).contiguous()

        if self.return_feature:
            return attn_out

        out = self.prediction(attn_out)

        return out

    def forward_test(self, feat, out_enc, valid_ratios):
        """Greedy autoregressive decoding for inference.

        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): valid length ratio of img.

        Returns:
            Tensor: The output logit sequence tensor of shape
            :math:`(N, T, C-1)`.
        """
        batch_size = feat.shape[0]

        # Seed the whole sequence with <SOS>; steps are filled greedily.
        decode_sequence = (torch.ones((batch_size, self.max_seq_len),
                                      dtype=torch.int64,
                                      device=feat.device) * self.start_idx)

        outputs = []
        for i in range(self.max_seq_len):
            step_out = self.forward_test_step(feat, out_enc, decode_sequence,
                                              i, valid_ratios)
            outputs.append(step_out)
            # Feed the argmax prediction back in as the next step's input.
            max_idx = torch.argmax(step_out, dim=1, keepdim=False)
            if i < self.max_seq_len - 1:
                decode_sequence[:, i + 1] = max_idx

        outputs = torch.stack(outputs, 1)

        return outputs

    def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
                          valid_ratios):
        """Run one decoding step and return its class probabilities.

        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            decode_sequence (Tensor): Shape :math:`(N, T)`. The tensor that
                stores history decoding result.
            current_step (int): Current decoding step.
            valid_ratios (Tensor): valid length ratio of img

        Returns:
            Tensor: Shape :math:`(N, C-1)`. The logit tensor of predicted
            tokens at current time step.
        """

        embed = self.embedding(decode_sequence)

        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, _, c_q = embed.shape
        assert c_q == self.dim_model

        query, _ = self.sequence_layer(embed)
        query = query.transpose(1, 2)
        key = torch.reshape(out_enc, (n, c_enc, h * w))
        if self.encode_value:
            value = key
        else:
            value = torch.reshape(feat, (n, c_feat, h * w))

        mask = None
        if valid_ratios is not None:
            # True marks padded (invalid) width columns per sample.
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        # [n, c, l]
        attn_out = self.attention_layer(query, key, value, mask)
        # Keep only the column for the step being decoded.
        out = attn_out[:, :, current_step]

        if self.return_feature:
            return out

        out = self.prediction(out)
        # NOTE: softmax is applied here, so this step returns probabilities
        # (the docstring above says "logit" — the values are post-softmax).
        out = F.softmax(out, dim=-1)

        return out
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
class PositionAwareLayer(nn.Module):
|
|
384
|
+
|
|
385
|
+
def __init__(self, dim_model, rnn_layers=2):
|
|
386
|
+
super().__init__()
|
|
387
|
+
|
|
388
|
+
self.dim_model = dim_model
|
|
389
|
+
|
|
390
|
+
self.rnn = nn.LSTM(input_size=dim_model,
|
|
391
|
+
hidden_size=dim_model,
|
|
392
|
+
num_layers=rnn_layers,
|
|
393
|
+
batch_first=True)
|
|
394
|
+
|
|
395
|
+
self.mixer = nn.Sequential(
|
|
396
|
+
nn.Conv2d(dim_model, dim_model, kernel_size=3, stride=1,
|
|
397
|
+
padding=1), nn.ReLU(True),
|
|
398
|
+
nn.Conv2d(dim_model, dim_model, kernel_size=3, stride=1,
|
|
399
|
+
padding=1))
|
|
400
|
+
|
|
401
|
+
def forward(self, img_feature):
|
|
402
|
+
n, c, h, w = img_feature.shape
|
|
403
|
+
rnn_input = img_feature.permute(0, 2, 3, 1).contiguous()
|
|
404
|
+
rnn_input = rnn_input.view(n * h, w, c)
|
|
405
|
+
rnn_output, _ = self.rnn(rnn_input)
|
|
406
|
+
rnn_output = rnn_output.view(n, h, w, c)
|
|
407
|
+
rnn_output = rnn_output.permute(0, 3, 1, 2).contiguous()
|
|
408
|
+
|
|
409
|
+
out = self.mixer(rnn_output)
|
|
410
|
+
return out
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
class PositionAttentionDecoder(BaseDecoder):
    """Position attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers in the position aware module.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 mask=True,
                 return_feature=False,
                 encode_value=False):
        super().__init__()

        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.max_seq_len = max_seq_len
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.mask = mask

        # One learned embedding per decoding position (0..max_seq_len).
        self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)

        self.position_aware_module = PositionAwareLayer(
            self.dim_model, rnn_layers)

        self.attention_layer = DotProductAttentionLayer()

        self.prediction = None
        if not self.return_feature:
            pred_num_classes = num_classes - 1
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def _get_position_index(self, length, batch_size):
        """Return an int64 tensor of shape ``(batch_size, length)`` whose
        rows are the position ids ``0 .. length - 1``.

        Fix: the previous implementation called
        ``torch.range(0, length, step=1, dtype='int64')``, which raises
        ``TypeError`` in PyTorch (``dtype`` must be a ``torch.dtype`` object,
        not a string) and — even if it ran — ``torch.range`` is deprecated and
        includes the endpoint, yielding ``length + 1`` entries.
        ``torch.arange`` gives the half-open range that matches the
        ``(N, T)`` position tensors consumed by :meth:`forward_train`.
        """
        # NOTE(review): assuming the exclusive range [0, length) was intended
        # (Paddle-style `arange` semantics) — confirm against callers.
        position_index = torch.arange(0, length, dtype=torch.int64)
        # All rows are identical, so build once and repeat instead of
        # stacking batch_size copies in a Python loop.
        return position_index.unsqueeze(0).repeat(batch_size, 1)

    def forward_train(self, feat, out_enc, targets, valid_ratios,
                      position_index):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of a
                character.
            valid_ratios (Tensor): valid length ratio of img.
            position_index (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it will be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, len_q = targets.shape
        assert len_q <= self.max_seq_len

        position_out_enc = self.position_aware_module(out_enc)

        # Queries come from position embeddings only (no character content).
        query = self.embedding(position_index)
        query = query.permute(0, 2, 1).contiguous()
        key = position_out_enc.view(n, c_enc, h * w)
        if self.encode_value:
            value = out_enc.view(n, c_enc, h * w)
        else:
            value = feat.view(n, c_feat, h * w)

        # Mask out columns beyond each image's valid width.
        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        attn_out = attn_out.permute(0, 2, 1).contiguous()

        if self.return_feature:
            return attn_out

        return self.prediction(attn_out)

    def forward_test(self, feat, out_enc, valid_ratios, position_index):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): valid length ratio of img
            position_index (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input

        position_out_enc = self.position_aware_module(out_enc)

        query = self.embedding(position_index)
        query = query.permute(0, 2, 1).contiguous()
        key = position_out_enc.view(n, c_enc, h * w)
        if self.encode_value:
            value = torch.reshape(out_enc, (n, c_enc, h * w))
        else:
            value = torch.reshape(feat, (n, c_feat, h * w))

        # Mask out columns beyond each image's valid width.
        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        attn_out = attn_out.transpose(1, 2)  # [n, len_q, dim_v]

        if self.return_feature:
            return attn_out

        return self.prediction(attn_out)
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
class RobustScannerFusionLayer(nn.Module):
|
|
581
|
+
|
|
582
|
+
def __init__(self, dim_model, dim=-1):
|
|
583
|
+
super(RobustScannerFusionLayer, self).__init__()
|
|
584
|
+
|
|
585
|
+
self.dim_model = dim_model
|
|
586
|
+
self.dim = dim
|
|
587
|
+
self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2)
|
|
588
|
+
|
|
589
|
+
def forward(self, x0, x1):
|
|
590
|
+
assert x0.shape == x1.shape
|
|
591
|
+
fusion_input = torch.concat((x0, x1), self.dim)
|
|
592
|
+
output = self.linear_layer(fusion_input)
|
|
593
|
+
output = F.glu(output, self.dim)
|
|
594
|
+
|
|
595
|
+
return output
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
class Decoder(BaseDecoder):
    """Decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Runs the hybrid (content) branch and the position branch, fuses their
    glimpses with a gated fusion layer, and projects to character classes.

    Args:
        num_classes (int): Number of output classes :math:`C`.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        padding_idx (int): The index of `<PAD>`.
        end_idx (int): The index of `<EOS>`, used for early stopping at
            inference time.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Note:
        Unlike the mmocr original (whose output size is :math:`C - 1`), the
        prediction layer here outputs all :math:`C` classes — see the
        ``pred_num_classes = num_classes`` line in ``__init__``.
    """

    def __init__(self,
                 num_classes=None,
                 dim_input=512,
                 dim_model=128,
                 hybrid_decoder_rnn_layers=2,
                 hybrid_decoder_dropout=0,
                 position_decoder_rnn_layers=2,
                 max_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 end_idx=0,
                 encode_value=False):
        super().__init__()
        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.max_seq_len = max_len
        self.encode_value = encode_value
        self.start_idx = start_idx
        self.padding_idx = padding_idx
        self.end_idx = end_idx
        self.mask = mask

        # init hybrid decoder: content-based glimpses conditioned on history.
        self.hybrid_decoder = SequenceAttentionDecoder(
            num_classes=num_classes,
            rnn_layers=hybrid_decoder_rnn_layers,
            dim_input=dim_input,
            dim_model=dim_model,
            max_seq_len=max_len,
            start_idx=start_idx,
            mask=mask,
            padding_idx=padding_idx,
            dropout=hybrid_decoder_dropout,
            encode_value=encode_value,
            return_feature=True)

        # init position decoder: glimpses driven purely by position cues.
        self.position_decoder = PositionAttentionDecoder(
            num_classes=num_classes,
            rnn_layers=position_decoder_rnn_layers,
            dim_input=dim_input,
            dim_model=dim_model,
            max_seq_len=max_len,
            mask=mask,
            encode_value=encode_value,
            return_feature=True)

        # Both branches return features (return_feature=True), so fusion and
        # the final class projection happen here.
        self.fusion_module = RobustScannerFusionLayer(
            self.dim_model if encode_value else dim_input)

        pred_num_classes = num_classes
        self.prediction = nn.Linear(dim_model if encode_value else dim_input,
                                    pred_num_classes)

    def forward_train(self, feat, out_enc, target, valid_ratios,
                      word_positions):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            target (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of a
                character.
            valid_ratios (Tensor):
            word_positions (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C)`.
        """

        hybrid_glimpse = self.hybrid_decoder.forward_train(
            feat, out_enc, target, valid_ratios)
        position_glimpse = self.position_decoder.forward_train(
            feat, out_enc, target, valid_ratios, word_positions)

        fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse)

        out = self.prediction(fusion_out)

        return out

    def forward_test(self, feat, out_enc, valid_ratios, word_positions):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor):
            word_positions (Tensor): The position of each word.
        Returns:
            Tensor: Character probabilities (softmax applied) of shape
            :math:`(N, T', C)` where :math:`T' \\le T` because decoding stops
            early once every sequence has emitted ``end_idx``.
        """
        seq_len = self.max_seq_len
        batch_size = feat.shape[0]

        # History buffer, initialised with <SOS>; filled greedily below.
        decode_sequence = (torch.ones(
            (batch_size, seq_len), dtype=torch.int64, device=feat.device) *
                           self.start_idx)

        # Position glimpses do not depend on decoded characters, so they are
        # computed once for all steps.
        position_glimpse = self.position_decoder.forward_test(
            feat, out_enc, valid_ratios, word_positions)

        outputs = []
        for i in range(seq_len):
            hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
                feat, out_enc, decode_sequence, i, valid_ratios)

            fusion_out = self.fusion_module(hybrid_glimpse_step,
                                            position_glimpse[:, i, :])

            char_out = self.prediction(fusion_out)
            char_out = F.softmax(char_out, -1)
            outputs.append(char_out)
            max_idx = torch.argmax(char_out, dim=1, keepdim=False)
            if i < seq_len - 1:
                decode_sequence[:, i + 1] = max_idx
                # NOTE(review): with the defaults start_idx == end_idx == 0
                # the <SOS>-filled buffer satisfies this immediately and
                # decoding stops after one step; configs presumably set
                # distinct indices — confirm.
                if (decode_sequence == self.end_idx).any(dim=-1).all():
                    break
        outputs = torch.stack(outputs, 1)

        return outputs
|