openocr-python 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +11 -0
- openocr/configs/det/dbnet/repvit_db.yml +173 -0
- openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
- openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
- openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
- openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
- openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
- openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
- openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
- openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
- openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
- openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
- openocr/configs/rec/busnet/vit_busnet.yml +104 -0
- openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
- openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
- openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
- openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
- openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
- openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
- openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
- openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
- openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
- openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
- openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
- openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
- openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
- openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
- openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
- openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
- openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
- openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
- openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
- openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
- openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
- openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
- openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
- openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
- openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
- openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
- openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
- openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
- openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
- openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
- openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
- openocr/configs/rec/nrtr/nrtr.yml +107 -0
- openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
- openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
- openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
- openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
- openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
- openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
- openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
- openocr/configs/rec/parseq/vit_parseq.yml +100 -0
- openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
- openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
- openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
- openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
- openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
- openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
- openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
- openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
- openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
- openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
- openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
- openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
- openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
- openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
- openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
- openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
- openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
- openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
- openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
- openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
- openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
- openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
- openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
- openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
- openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
- openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
- openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
- openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
- openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
- openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
- openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
- openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
- openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
- openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
- openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
- openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
- openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
- openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
- openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
- openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
- openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
- openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
- openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
- openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
- openocr/demo_gradio.py +128 -0
- openocr/opendet/modeling/__init__.py +11 -0
- openocr/opendet/modeling/backbones/__init__.py +14 -0
- openocr/opendet/modeling/backbones/repvit.py +340 -0
- openocr/opendet/modeling/base_detector.py +69 -0
- openocr/opendet/modeling/heads/__init__.py +14 -0
- openocr/opendet/modeling/heads/db_head.py +73 -0
- openocr/opendet/modeling/necks/__init__.py +14 -0
- openocr/opendet/modeling/necks/db_fpn.py +609 -0
- openocr/opendet/postprocess/__init__.py +18 -0
- openocr/opendet/postprocess/db_postprocess.py +273 -0
- openocr/opendet/preprocess/__init__.py +154 -0
- openocr/opendet/preprocess/crop_resize.py +121 -0
- openocr/opendet/preprocess/db_resize_for_test.py +135 -0
- openocr/openrec/losses/__init__.py +62 -0
- openocr/openrec/losses/abinet_loss.py +42 -0
- openocr/openrec/losses/ar_loss.py +23 -0
- openocr/openrec/losses/cam_loss.py +48 -0
- openocr/openrec/losses/cdistnet_loss.py +34 -0
- openocr/openrec/losses/ce_loss.py +68 -0
- openocr/openrec/losses/cppd_loss.py +77 -0
- openocr/openrec/losses/ctc_loss.py +33 -0
- openocr/openrec/losses/igtr_loss.py +12 -0
- openocr/openrec/losses/lister_loss.py +14 -0
- openocr/openrec/losses/lpv_loss.py +30 -0
- openocr/openrec/losses/mgp_loss.py +34 -0
- openocr/openrec/losses/parseq_loss.py +12 -0
- openocr/openrec/losses/robustscanner_loss.py +20 -0
- openocr/openrec/losses/seed_loss.py +46 -0
- openocr/openrec/losses/smtr_loss.py +12 -0
- openocr/openrec/losses/srn_loss.py +40 -0
- openocr/openrec/losses/visionlan_loss.py +58 -0
- openocr/openrec/metrics/__init__.py +19 -0
- openocr/openrec/metrics/rec_metric.py +270 -0
- openocr/openrec/metrics/rec_metric_gtc.py +58 -0
- openocr/openrec/metrics/rec_metric_long.py +142 -0
- openocr/openrec/metrics/rec_metric_mgp.py +93 -0
- openocr/openrec/modeling/__init__.py +11 -0
- openocr/openrec/modeling/base_recognizer.py +69 -0
- openocr/openrec/modeling/common.py +238 -0
- openocr/openrec/modeling/decoders/__init__.py +109 -0
- openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
- openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
- openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
- openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
- openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
- openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
- openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
- openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
- openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
- openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
- openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
- openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
- openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
- openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
- openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
- openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
- openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
- openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
- openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
- openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
- openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
- openocr/openrec/modeling/encoders/__init__.py +39 -0
- openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
- openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
- openocr/openrec/modeling/encoders/convnextv2.py +213 -0
- openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
- openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
- openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
- openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
- openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
- openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
- openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
- openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
- openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
- openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
- openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
- openocr/openrec/modeling/encoders/repvit.py +338 -0
- openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
- openocr/openrec/modeling/encoders/svtrnet.py +574 -0
- openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
- openocr/openrec/modeling/encoders/svtrv2.py +470 -0
- openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
- openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
- openocr/openrec/modeling/encoders/vit.py +120 -0
- openocr/openrec/modeling/transforms/__init__.py +15 -0
- openocr/openrec/modeling/transforms/aster_tps.py +262 -0
- openocr/openrec/modeling/transforms/moran.py +136 -0
- openocr/openrec/modeling/transforms/tps.py +246 -0
- openocr/openrec/optimizer/__init__.py +73 -0
- openocr/openrec/optimizer/lr.py +227 -0
- openocr/openrec/postprocess/__init__.py +72 -0
- openocr/openrec/postprocess/abinet_postprocess.py +37 -0
- openocr/openrec/postprocess/ar_postprocess.py +63 -0
- openocr/openrec/postprocess/ce_postprocess.py +43 -0
- openocr/openrec/postprocess/char_postprocess.py +108 -0
- openocr/openrec/postprocess/cppd_postprocess.py +42 -0
- openocr/openrec/postprocess/ctc_postprocess.py +119 -0
- openocr/openrec/postprocess/igtr_postprocess.py +100 -0
- openocr/openrec/postprocess/lister_postprocess.py +59 -0
- openocr/openrec/postprocess/mgp_postprocess.py +143 -0
- openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
- openocr/openrec/postprocess/smtr_postprocess.py +73 -0
- openocr/openrec/postprocess/srn_postprocess.py +80 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
- openocr/openrec/preprocess/__init__.py +173 -0
- openocr/openrec/preprocess/abinet_aug.py +473 -0
- openocr/openrec/preprocess/abinet_label_encode.py +36 -0
- openocr/openrec/preprocess/ar_label_encode.py +36 -0
- openocr/openrec/preprocess/auto_augment.py +1012 -0
- openocr/openrec/preprocess/cam_label_encode.py +141 -0
- openocr/openrec/preprocess/ce_label_encode.py +116 -0
- openocr/openrec/preprocess/char_label_encode.py +36 -0
- openocr/openrec/preprocess/cppd_label_encode.py +173 -0
- openocr/openrec/preprocess/ctc_label_encode.py +124 -0
- openocr/openrec/preprocess/ep_label_encode.py +38 -0
- openocr/openrec/preprocess/igtr_label_encode.py +360 -0
- openocr/openrec/preprocess/mgp_label_encode.py +95 -0
- openocr/openrec/preprocess/parseq_aug.py +150 -0
- openocr/openrec/preprocess/rec_aug.py +211 -0
- openocr/openrec/preprocess/resize.py +534 -0
- openocr/openrec/preprocess/smtr_label_encode.py +125 -0
- openocr/openrec/preprocess/srn_label_encode.py +37 -0
- openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
- openocr/tools/create_lmdb_dataset.py +118 -0
- openocr/tools/data/__init__.py +94 -0
- openocr/tools/data/collate_fn.py +100 -0
- openocr/tools/data/lmdb_dataset.py +142 -0
- openocr/tools/data/lmdb_dataset_test.py +166 -0
- openocr/tools/data/multi_scale_sampler.py +177 -0
- openocr/tools/data/ratio_dataset.py +217 -0
- openocr/tools/data/ratio_dataset_test.py +273 -0
- openocr/tools/data/ratio_dataset_tvresize.py +213 -0
- openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
- openocr/tools/data/ratio_sampler.py +190 -0
- openocr/tools/data/simple_dataset.py +263 -0
- openocr/tools/data/strlmdb_dataset.py +143 -0
- openocr/tools/engine/__init__.py +5 -0
- openocr/tools/engine/config.py +158 -0
- openocr/tools/engine/trainer.py +621 -0
- openocr/tools/eval_rec.py +41 -0
- openocr/tools/eval_rec_all_ch.py +184 -0
- openocr/tools/eval_rec_all_en.py +206 -0
- openocr/tools/eval_rec_all_long.py +119 -0
- openocr/tools/eval_rec_all_long_simple.py +122 -0
- openocr/tools/export_rec.py +118 -0
- openocr/tools/infer/onnx_engine.py +65 -0
- openocr/tools/infer/predict_rec.py +140 -0
- openocr/tools/infer/utility.py +234 -0
- openocr/tools/infer_det.py +449 -0
- openocr/tools/infer_e2e.py +462 -0
- openocr/tools/infer_e2e_parallel.py +184 -0
- openocr/tools/infer_rec.py +371 -0
- openocr/tools/train_rec.py +37 -0
- openocr/tools/utility.py +45 -0
- openocr/tools/utils/EN_symbol_dict.txt +94 -0
- openocr/tools/utils/__init__.py +0 -0
- openocr/tools/utils/ckpt.py +87 -0
- openocr/tools/utils/dict/ar_dict.txt +117 -0
- openocr/tools/utils/dict/arabic_dict.txt +161 -0
- openocr/tools/utils/dict/be_dict.txt +145 -0
- openocr/tools/utils/dict/bg_dict.txt +140 -0
- openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
- openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
- openocr/tools/utils/dict/devanagari_dict.txt +167 -0
- openocr/tools/utils/dict/en_dict.txt +63 -0
- openocr/tools/utils/dict/fa_dict.txt +136 -0
- openocr/tools/utils/dict/french_dict.txt +136 -0
- openocr/tools/utils/dict/german_dict.txt +143 -0
- openocr/tools/utils/dict/hi_dict.txt +162 -0
- openocr/tools/utils/dict/it_dict.txt +118 -0
- openocr/tools/utils/dict/japan_dict.txt +4399 -0
- openocr/tools/utils/dict/ka_dict.txt +153 -0
- openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
- openocr/tools/utils/dict/korean_dict.txt +3688 -0
- openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
- openocr/tools/utils/dict/latin_dict.txt +185 -0
- openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
- openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
- openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
- openocr/tools/utils/dict/mr_dict.txt +153 -0
- openocr/tools/utils/dict/ne_dict.txt +153 -0
- openocr/tools/utils/dict/oc_dict.txt +96 -0
- openocr/tools/utils/dict/pu_dict.txt +130 -0
- openocr/tools/utils/dict/rs_dict.txt +91 -0
- openocr/tools/utils/dict/rsc_dict.txt +134 -0
- openocr/tools/utils/dict/ru_dict.txt +125 -0
- openocr/tools/utils/dict/spin_dict.txt +68 -0
- openocr/tools/utils/dict/ta_dict.txt +128 -0
- openocr/tools/utils/dict/table_dict.txt +277 -0
- openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
- openocr/tools/utils/dict/table_structure_dict.txt +28 -0
- openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
- openocr/tools/utils/dict/te_dict.txt +151 -0
- openocr/tools/utils/dict/ug_dict.txt +114 -0
- openocr/tools/utils/dict/uk_dict.txt +142 -0
- openocr/tools/utils/dict/ur_dict.txt +137 -0
- openocr/tools/utils/dict/xi_dict.txt +110 -0
- openocr/tools/utils/dict90.txt +90 -0
- openocr/tools/utils/e2e_metric/Deteval.py +802 -0
- openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
- openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
- openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
- openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
- openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
- openocr/tools/utils/e2e_utils/visual.py +152 -0
- openocr/tools/utils/en_dict.txt +95 -0
- openocr/tools/utils/gen_label.py +68 -0
- openocr/tools/utils/ic15_dict.txt +36 -0
- openocr/tools/utils/logging.py +56 -0
- openocr/tools/utils/poly_nms.py +132 -0
- openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
- openocr/tools/utils/stats.py +58 -0
- openocr/tools/utils/utility.py +165 -0
- openocr/tools/utils/visual.py +117 -0
- openocr_python-0.0.2.dist-info/LICENCE +201 -0
- openocr_python-0.0.2.dist-info/METADATA +98 -0
- openocr_python-0.0.2.dist-info/RECORD +323 -0
- openocr_python-0.0.2.dist-info/WHEEL +5 -0
- openocr_python-0.0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from random import sample
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from .ctc_label_encode import BaseRecLabelEncode
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class VisionLANLabelEncode(BaseRecLabelEncode):
    """Convert between text-label and text-index for VisionLAN training.

    Besides the usual encoded label, each sample gets an "occluded"
    variant: one randomly chosen character is removed, and the remaining
    string, the removed character and its position are emitted so the
    language-aware branch can learn to recover it.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(VisionLANLabelEncode,
              self).__init__(max_text_length, character_dict_path,
                             use_space_char)
        # Map each character to its index in the vocabulary.
        self.dict = {char: idx for idx, char in enumerate(self.character)}

    def __call__(self, data):
        label_str = data['label']  # original string
        if len(label_str) <= 0:
            return None

        # Randomly pick exactly one character position to occlude.
        positions = list(range(len(label_str)))
        occluded_idx = sample(positions, 1)[0]
        occluded_char = label_str[occluded_idx]
        # Remaining string with the occluded character removed; this slice
        # form also covers the first-/last-position cases.
        remaining = label_str[:occluded_idx] + label_str[occluded_idx + 1:]

        data['label_res'] = remaining       # remaining string
        data['label_sub'] = occluded_char   # occluded character
        data['label_id'] = occluded_idx     # character index

        # Encode the full label; indices are shifted by +1 so 0 is padding.
        encoded = self.encode(label_str)
        if encoded is None:
            return None
        encoded = [i + 1 for i in encoded]
        data['length'] = np.array(len(encoded))
        data['label'] = np.array(encoded + [0] *
                                 (self.max_text_len + 1 - len(encoded)))

        def _encode_shifted(text):
            # Shift by +1 like the main label; empty list when unencodable.
            ids = self.encode(text)
            return [] if ids is None else [i + 1 for i in ids]

        res_ids = _encode_shifted(remaining)
        sub_ids = _encode_shifted(occluded_char)
        data['length_res'] = np.array(len(res_ids))
        data['length_sub'] = np.array(len(sub_ids))
        data['label_res'] = np.array(res_ids + [0] *
                                     (self.max_text_len - len(res_ids)))
        data['label_sub'] = np.array(sub_ids + [0] *
                                     (self.max_text_len - len(sub_ids)))
        return data
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import lmdb
|
|
3
|
+
import cv2
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
import numpy as np
|
|
6
|
+
import io
|
|
7
|
+
from PIL import Image
|
|
8
|
+
""" a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_datalist(data_dir, data_path, max_len):
    """Collect [image_path, label] pairs from annotation file(s).

    :param data_dir: dataset root directory
    :param data_path: one annotation file (or a list of them); each line is
        stored as 'path/to/img\tlabel'
    :param max_len: labels longer than this are skipped
    :return: list of [image_path, label]
    """
    if isinstance(data_path, list):
        # A list of annotation files: gather each one recursively.
        samples = []
        for single_path in data_path:
            samples.extend(get_datalist(data_dir, single_path, max_len))
        return samples

    samples = []
    with open(data_path, 'r', encoding='utf-8') as f:
        for raw in tqdm(f.readlines(), desc=f'load data from {data_path}'):
            # Some annotation files use a space instead of a tab after the
            # image name; normalize that before splitting.
            parts = (raw.strip('\n').replace('.jpg ', '.jpg\t').replace(
                '.png ', '.png\t').split('\t'))
            if len(parts) <= 1:
                continue
            img_path = os.path.join(data_dir, parts[0].strip(' '))
            label = parts[1]
            if len(label) > max_len:
                continue
            # Keep only samples whose image file exists and is non-empty.
            if os.path.exists(img_path) and os.path.getsize(img_path) > 0:
                samples.append([str(img_path), label])
    return samples
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def checkImageIsValid(imageBin):
    """Return True if *imageBin* holds a decodable, non-empty image.

    Args:
        imageBin: raw image bytes read from disk (or None).

    Returns:
        bool: False for missing, undecodable or zero-sized images.
    """
    if imageBin is None:
        return False
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    # BUG FIX: cv2.imdecode returns None for undecodable bytes; the
    # original then raised AttributeError on img.shape instead of
    # reporting the image as invalid.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def writeCache(env, cache):
    """Write every (key, value) pair of *cache* into the LMDB *env*
    inside a single write transaction."""
    with env.begin(write=True) as txn:
        for key, value in cache.items():
            txn.put(key, value)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def createDataset(data_list, outputPath, checkValid=True):
    """
    Create LMDB dataset for training and evaluation.

    ARGS:
        data_list  : list of [image_path, label] pairs
        outputPath : LMDB output path
        checkValid : if true, check the validity of every image
    """
    os.makedirs(outputPath, exist_ok=True)
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1
    for imagePath, label in tqdm(data_list,
                                 desc=f'make dataset, save to {outputPath}'):
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            try:
                if not checkImageIsValid(imageBin):
                    print('%s is not a valid image' % imagePath)
                    continue
            except Exception:
                continue
        # BUG FIX: read the size only after the validity check; the original
        # called Image.open(...).size first, so a corrupt file raised an
        # uncaught exception instead of being skipped.
        buf = io.BytesIO(imageBin)
        w, h = Image.open(buf).size

        imageKey = 'image-%09d'.encode() % cnt
        labelKey = 'label-%09d'.encode() % cnt
        whKey = 'wh-%09d'.encode() % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()
        cache[whKey] = (str(w) + '_' + str(h)).encode()

        # Flush periodically to bound memory use.
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'.encode()] = str(nSamples).encode()
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
if __name__ == '__main__':
    data_dir = './Union14M-L/'
    label_file_list = [
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_challenging.jsonl.txt',
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_easy.jsonl.txt',
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_hard.jsonl.txt',
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_medium.jsonl.txt',
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_normal.jsonl.txt'
    ]
    save_path_root = './Union14M-L-LMDB-Filtered/'

    for label_file in label_file_list:
        # One LMDB directory per annotation file, named after its stem.
        stem = label_file.split('/')[-1].split('.')[0]
        save_path = save_path_root + stem + '/'
        os.makedirs(save_path, exist_ok=True)
        print(save_path)
        train_samples = get_datalist(data_dir, label_file, 800)

        createDataset(train_samples, save_path)
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
|
5
|
+
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
|
6
|
+
|
|
7
|
+
import copy
|
|
8
|
+
|
|
9
|
+
from torch.utils.data import DataLoader, DistributedSampler
|
|
10
|
+
|
|
11
|
+
from tools.data.lmdb_dataset import LMDBDataSet
|
|
12
|
+
from tools.data.lmdb_dataset_test import LMDBDataSetTest
|
|
13
|
+
from tools.data.multi_scale_sampler import MultiScaleSampler
|
|
14
|
+
from tools.data.ratio_dataset import RatioDataSet
|
|
15
|
+
from tools.data.ratio_dataset_test import RatioDataSetTest
|
|
16
|
+
from tools.data.ratio_dataset_tvresize_test import RatioDataSetTVResizeTest
|
|
17
|
+
from tools.data.ratio_dataset_tvresize import RatioDataSetTVResize
|
|
18
|
+
from tools.data.ratio_sampler import RatioSampler
|
|
19
|
+
from tools.data.simple_dataset import MultiScaleDataSet, SimpleDataSet
|
|
20
|
+
from tools.data.strlmdb_dataset import STRLMDBDataSet
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
'build_dataloader',
|
|
24
|
+
'transform',
|
|
25
|
+
'create_operators',
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def build_dataloader(config, mode, logger, seed=None, epoch=3):
    """Build a torch DataLoader for the given mode.

    Args:
        config: full config dict; ``config[mode]`` must provide ``dataset``
            and ``loader`` sections.
        mode: one of 'Train', 'Eval' or 'Test'.
        logger: logger used for error reporting.
        seed: random seed forwarded to the dataset constructor.
        epoch: current epoch, forwarded to the dataset constructor.

    Returns:
        torch.utils.data.DataLoader
    """
    config = copy.deepcopy(config)

    support_dict = [
        'SimpleDataSet', 'LMDBDataSet', 'MultiScaleDataSet', 'STRLMDBDataSet',
        'LMDBDataSetTest', 'RatioDataSet', 'RatioDataSetTest',
        'RatioDataSetTVResize', 'RatioDataSetTVResizeTest'
    ]
    module_name = config[mode]['dataset']['name']
    assert module_name in support_dict, Exception(
        'DataSet only support {}/{}'.format(support_dict, module_name))
    assert mode in ['Train', 'Eval',
                    'Test'], 'Mode should be Train, Eval or Test.'

    # module_name is validated against the whitelist above, so eval() here
    # cannot execute arbitrary configuration input.
    dataset = eval(module_name)(config, mode, logger, seed, epoch=epoch)
    loader_config = config[mode]['loader']
    batch_size = loader_config['batch_size_per_card']
    drop_last = loader_config['drop_last']
    shuffle = loader_config['shuffle']
    num_workers = loader_config['num_workers']
    # BUG FIX: the original tested for the 'pin_memory' key but then read
    # loader_config['use_shared_memory'], raising KeyError whenever only
    # 'pin_memory' was configured. Honor 'pin_memory' first and keep
    # 'use_shared_memory' as a legacy fallback.
    if 'pin_memory' in loader_config:
        pin_memory = loader_config['pin_memory']
    else:
        pin_memory = loader_config.get('use_shared_memory', False)

    sampler = None
    batch_sampler = None
    if 'sampler' in config[mode]:
        config_sampler = config[mode]['sampler']
        sampler_name = config_sampler.pop('name')
        batch_sampler = eval(sampler_name)(dataset, **config_sampler)
    elif config['Global']['distributed'] and mode == 'Train':
        sampler = DistributedSampler(dataset=dataset, shuffle=shuffle)

    if 'collate_fn' in loader_config:
        # Import under a distinct name so the module is not shadowed by the
        # collator instance (the original rebound the module name).
        from . import collate_fn as collate_module

        collate = getattr(collate_module, loader_config['collate_fn'])()
    else:
        collate = None
    if batch_sampler is None:
        data_loader = DataLoader(
            dataset=dataset,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate,
            batch_size=batch_size,
            drop_last=drop_last,
        )
    else:
        # A batch_sampler yields whole batches, so batch_size / drop_last /
        # sampler must not be passed alongside it.
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate,
        )
    if len(data_loader) == 0:
        logger.error(
            f'No Images in {mode.lower()} dataloader, please ensure\n'
            '\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n'
            '\t2. The annotation file and path in the configuration file are provided normally.\n'
            '\t3. The BatchSize is large than images.')
        sys.exit()
    return data_loader
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import numbers
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DictCollator(object):
    """Collate a batch of dict samples into one dict.

    Values that are numpy arrays, torch tensors or plain numbers are
    stacked into a single torch.Tensor per key; all other values are
    returned as plain lists.
    """

    def __call__(self, batch):
        data_dict = defaultdict(list)
        to_tensor_keys = []
        for sample in batch:
            for k, v in sample.items():
                if isinstance(v, (np.ndarray, torch.Tensor, numbers.Number)):
                    if k not in to_tensor_keys:
                        to_tensor_keys.append(k)
                data_dict[k].append(v)
        for k in to_tensor_keys:
            values = data_dict[k]
            # BUG FIX: torch.from_numpy() rejects python lists (TypeError);
            # stack tensors directly, otherwise convert through numpy first.
            if isinstance(values[0], torch.Tensor):
                data_dict[k] = torch.stack(values, dim=0)
            else:
                data_dict[k] = torch.from_numpy(np.asarray(values))
        return data_dict
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ListCollator(object):
    """Collate a batch of list/tuple samples into one list.

    Positions holding numpy arrays, torch tensors or plain numbers are
    stacked into a single torch.Tensor; other positions stay as plain
    lists. The output preserves positional order.
    """

    def __call__(self, batch):
        data_dict = defaultdict(list)
        to_tensor_idxs = []
        for sample in batch:
            for idx, v in enumerate(sample):
                if isinstance(v, (np.ndarray, torch.Tensor, numbers.Number)):
                    if idx not in to_tensor_idxs:
                        to_tensor_idxs.append(idx)
                data_dict[idx].append(v)
        for idx in to_tensor_idxs:
            values = data_dict[idx]
            # BUG FIX: torch.from_numpy() rejects python lists (TypeError);
            # stack tensors directly, otherwise convert through numpy first.
            if isinstance(values[0], torch.Tensor):
                data_dict[idx] = torch.stack(values, dim=0)
            else:
                data_dict[idx] = torch.from_numpy(np.asarray(values))
        return list(data_dict.values())
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class SSLRotateCollate(object):
    """Concatenate per-sample rotation stacks along the batch axis.

    bach: [
        [(4*3xH*W), (4,)]
        [(4*3xH*W), (4,)]
        ...
    ]
    """

    def __call__(self, batch):
        columns = zip(*batch)
        output = [np.concatenate(column, axis=0) for column in columns]
        return output
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class DyMaskCollator(object):
    """
    batch: [
        image [batch_size, channel, maxHinbatch, maxWinbatch]
        image_mask [batch_size, channel, maxHinbatch, maxWinbatch]
        label [batch_size, maxLabelLen]
        label_mask [batch_size, maxLabelLen]
        ...
    ]
    """

    def __call__(self, batch):
        max_width, max_height, max_length = 0, 0, 0
        channel = batch[0][0].shape[0]
        kept_items = []
        for item in batch:
            img = item[0]
            # Skip items whose padded area would explode the batch. NOTE:
            # the test uses the *running* maxima, so the outcome depends on
            # item order — kept identical to the original behavior.
            if (img.shape[1] * max_width > 1600 * 320
                    or img.shape[2] * max_height > 1600 * 320):
                continue
            max_height = max(max_height, img.shape[1])
            max_width = max(max_width, img.shape[2])
            max_length = max(max_length, len(item[1]))
            kept_items.append(item)

        n = len(kept_items)
        images = np.zeros((n, channel, max_height, max_width),
                          dtype='float32')
        image_masks = np.zeros((n, 1, max_height, max_width), dtype='float32')
        labels = np.zeros((n, max_length), dtype='int64')
        label_masks = np.zeros((n, max_length), dtype='int64')

        # Copy each item into the top-left corner of its padded slot and
        # mark the valid region in the matching mask.
        for i, item in enumerate(kept_items):
            _, h, w = item[0].shape
            images[i][:, :h, :w] = item[0]
            image_masks[i][:, :h, :w] = 1
            length = len(item[1])
            labels[i][:length] = item[1]
            label_masks[i][:length] = 1

        return images, image_masks, labels, label_masks
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import cv2
|
|
4
|
+
import lmdb
|
|
5
|
+
import numpy as np
|
|
6
|
+
from torch.utils.data import Dataset
|
|
7
|
+
|
|
8
|
+
from openrec.preprocess import create_operators, transform
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LMDBDataSet(Dataset):
    """Recognition dataset backed by one or more LMDB databases.

    Walks ``data_dir`` for leaf directories, opens each as a read-only
    LMDB environment, and exposes the union of their samples through a
    flat ``(lmdb_idx, file_idx)`` lookup table. Samples are decoded and
    augmented by the operator pipeline declared in the config.
    """

    def __init__(self, config, mode, logger, seed=None, epoch=1):
        """Build the dataset from the ``config[mode]`` section.

        Args:
            config: full config dict with 'Global' and per-mode sections.
            mode: section key, e.g. 'Train' or 'Eval'.
            logger: logger used to report initialization progress.
            seed: unused; kept for interface compatibility.
            epoch: unused; kept for interface compatibility.
        """
        super(LMDBDataSet, self).__init__()

        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        # NOTE(review): removed a no-op expression statement that merely
        # looked up loader_config['batch_size_per_card'] without using it.
        data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
        logger.info(f'Initialize indexs of datasets: {data_dir}')
        # Flat (lmdb_idx, file_idx) table spanning all sub-datasets.
        self.data_idx_order_list = self.dataset_traversal()
        if self.do_shuffle:
            np.random.shuffle(self.data_idx_order_list)
        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)
        # Prefix of self.ops applied when loading extra samples for
        # operators that request them (see get_ext_data).
        self.ext_op_transform_idx = dataset_config.get(
            'ext_op_transform_idx', 1)

        ratio_list = dataset_config.get('ratio_list', [1.0])
        # Sub-datasets sampled at a ratio < 1 require a reset each epoch.
        self.need_reset = True in [x < 1 for x in ratio_list]

    def load_hierarchical_lmdb_dataset(self, data_dir):
        """Open every leaf directory under ``data_dir`` as an LMDB env.

        Returns:
            dict mapping a dense dataset index to a dict with keys
            'dirpath', 'env', 'txn' (long-lived read transaction) and
            'num_samples' (from the LMDB 'num-samples' record).
        """
        lmdb_sets = {}
        dataset_idx = 0
        for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
            # Leaf directories (no subdirectories) hold the LMDB files.
            if not dirnames:
                env = lmdb.open(
                    dirpath,
                    max_readers=32,
                    readonly=True,
                    lock=False,
                    readahead=False,
                    meminit=False,
                )
                txn = env.begin(write=False)
                num_samples = int(txn.get('num-samples'.encode()))
                lmdb_sets[dataset_idx] = {
                    'dirpath': dirpath,
                    'env': env,
                    'txn': txn,
                    'num_samples': num_samples,
                }
                dataset_idx += 1
        return lmdb_sets

    def dataset_traversal(self):
        """Build the flat sample lookup table.

        Returns:
            ndarray of shape (total_samples, 2); column 0 is the lmdb
            set index, column 1 the 1-based sample index within it
            (LMDB keys start at 1).
        """
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = 0
        for lno in range(lmdb_num):
            total_sample_num += self.lmdb_sets[lno]['num_samples']
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            tmp_sample_num = self.lmdb_sets[lno]['num_samples']
            end_idx = beg_idx + tmp_sample_num
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            data_idx_order_list[beg_idx:end_idx,
                                1] = list(range(tmp_sample_num))
            # Shift to 1-based indices to match the LMDB key scheme.
            data_idx_order_list[beg_idx:end_idx, 1] += 1
            beg_idx = beg_idx + tmp_sample_num
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode a raw image buffer to a BGR ndarray, or None on failure."""
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        if imgdata is None:
            return None
        imgori = cv2.imdecode(imgdata, 1)
        if imgori is None:
            return None
        return imgori

    def get_ext_data(self):
        """Load extra random samples for ops that declare ``ext_data_num``.

        Only the first ``ext_op_transform_idx`` operators are applied to
        these auxiliary samples (enough to decode them); unusable samples
        are skipped and re-drawn.
        """
        ext_data_num = 0
        for op in self.ops:
            if hasattr(op, 'ext_data_num'):
                ext_data_num = getattr(op, 'ext_data_num')
                break
        load_data_ops = self.ops[:self.ext_op_transform_idx]
        ext_data = []

        while len(ext_data) < ext_data_num:
            lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
                len(self))]
            lmdb_idx = int(lmdb_idx)
            file_idx = int(file_idx)
            sample_info = self.get_lmdb_sample_info(
                self.lmdb_sets[lmdb_idx]['txn'], file_idx)
            if sample_info is None:
                continue
            img, label = sample_info
            data = {'image': img, 'label': label}
            data = transform(data, load_data_ops)
            if data is None:
                continue
            ext_data.append(data)
        return ext_data

    def get_lmdb_sample_info(self, txn, index):
        """Fetch (raw image bytes, decoded label) for a 1-based index.

        Returns None when the label record is missing.
        """
        label_key = 'label-%09d'.encode() % index
        label = txn.get(label_key)
        if label is None:
            return None
        label = label.decode('utf-8')
        img_key = 'image-%09d'.encode() % index
        imgbuf = txn.get(img_key)
        return imgbuf, label

    def __getitem__(self, idx):
        """Return the transformed sample at ``idx``.

        Bad or untransformable samples are replaced by a random redraw.
        The redraw loop replaces the original unbounded recursion, which
        could exhaust the stack on datasets with many corrupt entries.
        """
        while True:
            lmdb_idx, file_idx = self.data_idx_order_list[idx]
            lmdb_idx = int(lmdb_idx)
            file_idx = int(file_idx)
            sample_info = self.get_lmdb_sample_info(
                self.lmdb_sets[lmdb_idx]['txn'], file_idx)
            if sample_info is not None:
                img, label = sample_info
                data = {'image': img, 'label': label}
                data['ext_data'] = self.get_ext_data()
                outs = transform(data, self.ops)
                if outs is not None:
                    return outs
            # Sample unusable: redraw a random index and try again.
            idx = np.random.randint(self.__len__())

    def __len__(self):
        """Total number of samples across all LMDB sub-datasets."""
        return self.data_idx_order_list.shape[0]
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import re
|
|
3
|
+
import unicodedata
|
|
4
|
+
|
|
5
|
+
import lmdb
|
|
6
|
+
from PIL import Image
|
|
7
|
+
from torch.utils.data import Dataset
|
|
8
|
+
|
|
9
|
+
from openrec.preprocess import create_operators, transform
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CharsetAdapter:
    """Transforms labels according to the target charset.

    When the charset is single-cased, labels are folded to that case;
    any character outside the charset is then stripped.
    """

    def __init__(self, target_charset) -> None:
        super().__init__()
        # A charset equal to its own lower/upper form is single-cased.
        self.lowercase_only = target_charset == target_charset.lower()
        self.uppercase_only = target_charset == target_charset.upper()
        # Matches every character NOT in the target charset.
        self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')

    def __call__(self, label):
        if self.lowercase_only:
            label = label.lower()
        elif self.uppercase_only:
            label = label.upper()
        # Strip characters the charset cannot represent.
        return self.unsupported.sub('', label)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class LMDBDataSetTest(Dataset):
    """Dataset interface to an LMDB database.

    It supports both labelled and unlabelled datasets. For unlabelled datasets,
    the image index itself is returned as the label. Unicode characters are
    normalized by default. Case-sensitivity is inferred from the charset.
    Labels are transformed according to the charset.
    """

    def __init__(self,
                 config,
                 mode,
                 logger,
                 seed=None,
                 epoch=1,
                 gpu_i=0,
                 max_label_len: int = 25,
                 min_image_dim: int = 0,
                 remove_whitespace: bool = True,
                 normalize_unicode: bool = True,
                 unlabelled: bool = False,
                 transform=None):
        dataset_config = config[mode]['dataset']
        global_config = config['Global']
        # NOTE(review): the max_label_len parameter is always overridden by
        # the config's max_text_length, so the default above is dead.
        max_label_len = global_config['max_text_length']
        self.root = dataset_config['data_dir']
        # LMDB environment is opened lazily through the `env` property.
        self._env = None
        self.unlabelled = unlabelled
        self.transform = transform
        self.labels = []
        self.filtered_index_list = []
        self.min_image_dim = min_image_dim
        # When True, labels are ASCII-normalized and charset-filtered.
        self.filter_label = dataset_config.get('filter_label',
                                               True)
        character_dict_path = global_config.get('character_dict_path', None)
        use_space_char = global_config.get('use_space_char', False)
        if character_dict_path is None:
            # Default charset: lowercase alphanumerics.
            char_test = '0123456789abcdefghijklmnopqrstuvwxyz'
        else:
            # Build the charset from the dictionary file, one char per line.
            char_test = ''
            with open(character_dict_path, 'rb') as fin:
                lines = fin.readlines()
                for line in lines:
                    line = line.decode('utf-8').strip('\n').strip('\r\n')
                    char_test += line
            if use_space_char:
                char_test += ' '
        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)
        # Pre-filter labels once; also fixes the dataset length.
        self.num_samples = self._preprocess_labels(char_test,
                                                   remove_whitespace,
                                                   normalize_unicode,
                                                   max_label_len,
                                                   min_image_dim)

    def __del__(self):
        # Close the lazily opened LMDB environment, if any.
        if self._env is not None:
            self._env.close()
            self._env = None

    def _create_env(self):
        """Open the LMDB environment at self.root read-only."""
        return lmdb.open(self.root,
                         max_readers=1,
                         readonly=True,
                         create=False,
                         readahead=False,
                         meminit=False,
                         lock=False)

    @property
    def env(self):
        # Lazily open (and cache) the LMDB environment on first access.
        if self._env is None:
            self._env = self._create_env()
        return self._env

    def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode,
                           max_label_len, min_image_dim):
        """Scan all labels once, filling self.labels / self.filtered_index_list.

        Returns the number of usable samples. For unlabelled data the raw
        LMDB sample count is returned and no filtering is performed.
        """
        charset_adapter = CharsetAdapter(charset)
        # A short-lived env is used here; the cached one serves __getitem__.
        with self._create_env() as env, env.begin() as txn:
            num_samples = int(txn.get('num-samples'.encode()))
            if self.unlabelled:
                return num_samples
            for index in range(num_samples):
                index += 1  # lmdb starts with 1
                label_key = f'label-{index:09d}'.encode()
                label = txn.get(label_key).decode()
                # Normally, whitespace is removed from the labels.
                if remove_whitespace:
                    label = ''.join(label.split())
                # Normalize unicode composites (if any) and convert to compatible ASCII characters
                if self.filter_label:
                    # NOTE(review): gated on filter_label rather than the
                    # normalize_unicode parameter, which is otherwise unused.
                    label = unicodedata.normalize('NFKD', label).encode(
                        'ascii', 'ignore').decode()
                # Filter by length before removing unsupported characters. The original label might be too long.
                if len(label) > max_label_len:
                    continue

                if self.filter_label:
                    label = charset_adapter(label)
                    # We filter out samples which don't contain any supported characters
                    if not label:
                        continue
                # Filter images that are too small.
                if min_image_dim > 0:
                    img_key = f'image-{index:09d}'.encode()
                    img = txn.get(img_key)
                    data = {'image': img, 'label': label}
                    outs = transform(data, self.ops)
                    if outs is None:
                        continue
                    buf = io.BytesIO(img)
                    w, h = Image.open(buf).size
                    if w < self.min_image_dim or h < self.min_image_dim:
                        continue
                # Keep the label and its 1-based LMDB index.
                self.labels.append(label)
                self.filtered_index_list.append(index)
        return len(self.labels)

    def __len__(self):
        # Number of filtered samples (or raw count when unlabelled).
        return self.num_samples

    def __getitem__(self, index):
        """Return the transformed sample at ``index``.

        For labelled data, ``index`` is remapped through
        filtered_index_list to the 1-based LMDB key.
        """
        if self.unlabelled:
            # NOTE(review): the raw 0-based index is used directly as the
            # LMDB key here, but keys appear to be 1-based elsewhere
            # (filtered_index_list stores index+1) — possible off-by-one;
            # confirm against the unlabelled LMDB layout.
            label = index
        else:
            label = self.labels[index]
            index = self.filtered_index_list[index]

        img_key = f'image-{index:09d}'.encode()
        with self.env.begin() as txn:
            img = txn.get(img_key)
        data = {'image': img, 'label': label}
        # Uses the module-level `transform`, not self.transform.
        outs = transform(data, self.ops)

        return outs
|