openocr-python 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +11 -0
- openocr/configs/det/dbnet/repvit_db.yml +173 -0
- openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
- openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
- openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
- openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
- openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
- openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
- openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
- openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
- openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
- openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
- openocr/configs/rec/busnet/vit_busnet.yml +104 -0
- openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
- openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
- openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
- openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
- openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
- openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
- openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
- openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
- openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
- openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
- openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
- openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
- openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
- openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
- openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
- openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
- openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
- openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
- openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
- openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
- openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
- openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
- openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
- openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
- openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
- openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
- openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
- openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
- openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
- openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
- openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
- openocr/configs/rec/nrtr/nrtr.yml +107 -0
- openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
- openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
- openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
- openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
- openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
- openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
- openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
- openocr/configs/rec/parseq/vit_parseq.yml +100 -0
- openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
- openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
- openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
- openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
- openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
- openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
- openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
- openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
- openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
- openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
- openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
- openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
- openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
- openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
- openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
- openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
- openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
- openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
- openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
- openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
- openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
- openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
- openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
- openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
- openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
- openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
- openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
- openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
- openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
- openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
- openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
- openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
- openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
- openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
- openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
- openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
- openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
- openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
- openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
- openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
- openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
- openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
- openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
- openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
- openocr/demo_gradio.py +128 -0
- openocr/opendet/modeling/__init__.py +11 -0
- openocr/opendet/modeling/backbones/__init__.py +14 -0
- openocr/opendet/modeling/backbones/repvit.py +340 -0
- openocr/opendet/modeling/base_detector.py +69 -0
- openocr/opendet/modeling/heads/__init__.py +14 -0
- openocr/opendet/modeling/heads/db_head.py +73 -0
- openocr/opendet/modeling/necks/__init__.py +14 -0
- openocr/opendet/modeling/necks/db_fpn.py +609 -0
- openocr/opendet/postprocess/__init__.py +18 -0
- openocr/opendet/postprocess/db_postprocess.py +273 -0
- openocr/opendet/preprocess/__init__.py +154 -0
- openocr/opendet/preprocess/crop_resize.py +121 -0
- openocr/opendet/preprocess/db_resize_for_test.py +135 -0
- openocr/openrec/losses/__init__.py +62 -0
- openocr/openrec/losses/abinet_loss.py +42 -0
- openocr/openrec/losses/ar_loss.py +23 -0
- openocr/openrec/losses/cam_loss.py +48 -0
- openocr/openrec/losses/cdistnet_loss.py +34 -0
- openocr/openrec/losses/ce_loss.py +68 -0
- openocr/openrec/losses/cppd_loss.py +77 -0
- openocr/openrec/losses/ctc_loss.py +33 -0
- openocr/openrec/losses/igtr_loss.py +12 -0
- openocr/openrec/losses/lister_loss.py +14 -0
- openocr/openrec/losses/lpv_loss.py +30 -0
- openocr/openrec/losses/mgp_loss.py +34 -0
- openocr/openrec/losses/parseq_loss.py +12 -0
- openocr/openrec/losses/robustscanner_loss.py +20 -0
- openocr/openrec/losses/seed_loss.py +46 -0
- openocr/openrec/losses/smtr_loss.py +12 -0
- openocr/openrec/losses/srn_loss.py +40 -0
- openocr/openrec/losses/visionlan_loss.py +58 -0
- openocr/openrec/metrics/__init__.py +19 -0
- openocr/openrec/metrics/rec_metric.py +270 -0
- openocr/openrec/metrics/rec_metric_gtc.py +58 -0
- openocr/openrec/metrics/rec_metric_long.py +142 -0
- openocr/openrec/metrics/rec_metric_mgp.py +93 -0
- openocr/openrec/modeling/__init__.py +11 -0
- openocr/openrec/modeling/base_recognizer.py +69 -0
- openocr/openrec/modeling/common.py +238 -0
- openocr/openrec/modeling/decoders/__init__.py +109 -0
- openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
- openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
- openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
- openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
- openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
- openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
- openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
- openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
- openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
- openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
- openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
- openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
- openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
- openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
- openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
- openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
- openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
- openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
- openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
- openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
- openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
- openocr/openrec/modeling/encoders/__init__.py +39 -0
- openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
- openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
- openocr/openrec/modeling/encoders/convnextv2.py +213 -0
- openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
- openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
- openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
- openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
- openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
- openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
- openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
- openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
- openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
- openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
- openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
- openocr/openrec/modeling/encoders/repvit.py +338 -0
- openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
- openocr/openrec/modeling/encoders/svtrnet.py +574 -0
- openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
- openocr/openrec/modeling/encoders/svtrv2.py +470 -0
- openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
- openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
- openocr/openrec/modeling/encoders/vit.py +120 -0
- openocr/openrec/modeling/transforms/__init__.py +15 -0
- openocr/openrec/modeling/transforms/aster_tps.py +262 -0
- openocr/openrec/modeling/transforms/moran.py +136 -0
- openocr/openrec/modeling/transforms/tps.py +246 -0
- openocr/openrec/optimizer/__init__.py +73 -0
- openocr/openrec/optimizer/lr.py +227 -0
- openocr/openrec/postprocess/__init__.py +72 -0
- openocr/openrec/postprocess/abinet_postprocess.py +37 -0
- openocr/openrec/postprocess/ar_postprocess.py +63 -0
- openocr/openrec/postprocess/ce_postprocess.py +43 -0
- openocr/openrec/postprocess/char_postprocess.py +108 -0
- openocr/openrec/postprocess/cppd_postprocess.py +42 -0
- openocr/openrec/postprocess/ctc_postprocess.py +119 -0
- openocr/openrec/postprocess/igtr_postprocess.py +100 -0
- openocr/openrec/postprocess/lister_postprocess.py +59 -0
- openocr/openrec/postprocess/mgp_postprocess.py +143 -0
- openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
- openocr/openrec/postprocess/smtr_postprocess.py +73 -0
- openocr/openrec/postprocess/srn_postprocess.py +80 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
- openocr/openrec/preprocess/__init__.py +173 -0
- openocr/openrec/preprocess/abinet_aug.py +473 -0
- openocr/openrec/preprocess/abinet_label_encode.py +36 -0
- openocr/openrec/preprocess/ar_label_encode.py +36 -0
- openocr/openrec/preprocess/auto_augment.py +1012 -0
- openocr/openrec/preprocess/cam_label_encode.py +141 -0
- openocr/openrec/preprocess/ce_label_encode.py +116 -0
- openocr/openrec/preprocess/char_label_encode.py +36 -0
- openocr/openrec/preprocess/cppd_label_encode.py +173 -0
- openocr/openrec/preprocess/ctc_label_encode.py +124 -0
- openocr/openrec/preprocess/ep_label_encode.py +38 -0
- openocr/openrec/preprocess/igtr_label_encode.py +360 -0
- openocr/openrec/preprocess/mgp_label_encode.py +95 -0
- openocr/openrec/preprocess/parseq_aug.py +150 -0
- openocr/openrec/preprocess/rec_aug.py +211 -0
- openocr/openrec/preprocess/resize.py +534 -0
- openocr/openrec/preprocess/smtr_label_encode.py +125 -0
- openocr/openrec/preprocess/srn_label_encode.py +37 -0
- openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
- openocr/tools/create_lmdb_dataset.py +118 -0
- openocr/tools/data/__init__.py +94 -0
- openocr/tools/data/collate_fn.py +100 -0
- openocr/tools/data/lmdb_dataset.py +142 -0
- openocr/tools/data/lmdb_dataset_test.py +166 -0
- openocr/tools/data/multi_scale_sampler.py +177 -0
- openocr/tools/data/ratio_dataset.py +217 -0
- openocr/tools/data/ratio_dataset_test.py +273 -0
- openocr/tools/data/ratio_dataset_tvresize.py +213 -0
- openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
- openocr/tools/data/ratio_sampler.py +190 -0
- openocr/tools/data/simple_dataset.py +263 -0
- openocr/tools/data/strlmdb_dataset.py +143 -0
- openocr/tools/engine/__init__.py +5 -0
- openocr/tools/engine/config.py +158 -0
- openocr/tools/engine/trainer.py +621 -0
- openocr/tools/eval_rec.py +41 -0
- openocr/tools/eval_rec_all_ch.py +184 -0
- openocr/tools/eval_rec_all_en.py +206 -0
- openocr/tools/eval_rec_all_long.py +119 -0
- openocr/tools/eval_rec_all_long_simple.py +122 -0
- openocr/tools/export_rec.py +118 -0
- openocr/tools/infer/onnx_engine.py +65 -0
- openocr/tools/infer/predict_rec.py +140 -0
- openocr/tools/infer/utility.py +234 -0
- openocr/tools/infer_det.py +449 -0
- openocr/tools/infer_e2e.py +462 -0
- openocr/tools/infer_e2e_parallel.py +184 -0
- openocr/tools/infer_rec.py +371 -0
- openocr/tools/train_rec.py +37 -0
- openocr/tools/utility.py +45 -0
- openocr/tools/utils/EN_symbol_dict.txt +94 -0
- openocr/tools/utils/__init__.py +0 -0
- openocr/tools/utils/ckpt.py +87 -0
- openocr/tools/utils/dict/ar_dict.txt +117 -0
- openocr/tools/utils/dict/arabic_dict.txt +161 -0
- openocr/tools/utils/dict/be_dict.txt +145 -0
- openocr/tools/utils/dict/bg_dict.txt +140 -0
- openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
- openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
- openocr/tools/utils/dict/devanagari_dict.txt +167 -0
- openocr/tools/utils/dict/en_dict.txt +63 -0
- openocr/tools/utils/dict/fa_dict.txt +136 -0
- openocr/tools/utils/dict/french_dict.txt +136 -0
- openocr/tools/utils/dict/german_dict.txt +143 -0
- openocr/tools/utils/dict/hi_dict.txt +162 -0
- openocr/tools/utils/dict/it_dict.txt +118 -0
- openocr/tools/utils/dict/japan_dict.txt +4399 -0
- openocr/tools/utils/dict/ka_dict.txt +153 -0
- openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
- openocr/tools/utils/dict/korean_dict.txt +3688 -0
- openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
- openocr/tools/utils/dict/latin_dict.txt +185 -0
- openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
- openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
- openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
- openocr/tools/utils/dict/mr_dict.txt +153 -0
- openocr/tools/utils/dict/ne_dict.txt +153 -0
- openocr/tools/utils/dict/oc_dict.txt +96 -0
- openocr/tools/utils/dict/pu_dict.txt +130 -0
- openocr/tools/utils/dict/rs_dict.txt +91 -0
- openocr/tools/utils/dict/rsc_dict.txt +134 -0
- openocr/tools/utils/dict/ru_dict.txt +125 -0
- openocr/tools/utils/dict/spin_dict.txt +68 -0
- openocr/tools/utils/dict/ta_dict.txt +128 -0
- openocr/tools/utils/dict/table_dict.txt +277 -0
- openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
- openocr/tools/utils/dict/table_structure_dict.txt +28 -0
- openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
- openocr/tools/utils/dict/te_dict.txt +151 -0
- openocr/tools/utils/dict/ug_dict.txt +114 -0
- openocr/tools/utils/dict/uk_dict.txt +142 -0
- openocr/tools/utils/dict/ur_dict.txt +137 -0
- openocr/tools/utils/dict/xi_dict.txt +110 -0
- openocr/tools/utils/dict90.txt +90 -0
- openocr/tools/utils/e2e_metric/Deteval.py +802 -0
- openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
- openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
- openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
- openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
- openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
- openocr/tools/utils/e2e_utils/visual.py +152 -0
- openocr/tools/utils/en_dict.txt +95 -0
- openocr/tools/utils/gen_label.py +68 -0
- openocr/tools/utils/ic15_dict.txt +36 -0
- openocr/tools/utils/logging.py +56 -0
- openocr/tools/utils/poly_nms.py +132 -0
- openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
- openocr/tools/utils/stats.py +58 -0
- openocr/tools/utils/utility.py +165 -0
- openocr/tools/utils/visual.py +117 -0
- openocr_python-0.0.2.dist-info/LICENCE +201 -0
- openocr_python-0.0.2.dist-info/METADATA +98 -0
- openocr_python-0.0.2.dist-info/RECORD +323 -0
- openocr_python-0.0.2.dist-info/WHEEL +5 -0
- openocr_python-0.0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,449 @@
|
|
|
1
|
+
from __future__ import absolute_import
|
|
2
|
+
from __future__ import division
|
|
3
|
+
from __future__ import print_function
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import time
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
|
|
12
|
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
|
13
|
+
sys.path.append(__dir__)
|
|
14
|
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
|
15
|
+
|
|
16
|
+
os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
|
|
17
|
+
|
|
18
|
+
import cv2
|
|
19
|
+
import json
|
|
20
|
+
import torch
|
|
21
|
+
from tools.engine import Config
|
|
22
|
+
from tools.utility import ArgsParser
|
|
23
|
+
from tools.utils.ckpt import load_ckpt
|
|
24
|
+
from tools.utils.logging import get_logger
|
|
25
|
+
from tools.utils.utility import get_image_file_list
|
|
26
|
+
|
|
27
|
+
logger = get_logger()

# Directory holding this script; the bundled configs live one level above it.
root_dir = Path(__file__).resolve().parent
DEFAULT_CFG_PATH_DET = str(
    root_dir.joinpath('..', 'configs', 'det', 'dbnet', 'repvit_db.yml'))

# File name of the pretrained detection checkpoint.
MODEL_NAME_DET = './openocr_det_repvit_ch.pth'
# Release URL the checkpoint is fetched from when it is not found locally.
DOWNLOAD_URL_DET = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth'
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def check_and_download_model(model_name: str, url: str):
    """Return a local path to a pretrained model, downloading it if needed.

    Lookup order:
      1. ``model_name`` as given (absolute, or relative to the current
         working directory) -- returned unchanged if it exists.
      2. ``~/.cache/openocr/<model_name>`` -- returned if it exists.
      3. Otherwise the file is downloaded from ``url`` into that cache path.

    The download is streamed into a temporary file next to the target and
    only renamed to the final name after it completed, so a failed or
    interrupted download never leaves a truncated file that a later call
    would mistake for a valid model.

    Args:
        model_name (str): Name of the model file, e.g. "model.pt".
        url (str): Download URL of the model file.

    Returns:
        str: Path to the model file.

    Raises:
        Exception: Any error raised while downloading or writing the file;
            it is logged and re-raised.
    """
    import shutil
    import tempfile
    import urllib.request

    if os.path.exists(model_name):
        return model_name

    # Fixed cache location under the user's home directory.
    cache_dir = Path.home() / '.cache' / 'openocr'
    model_path = cache_dir / model_name

    if model_path.exists():
        logger.info(f'Model already exists at: {model_path}')
        return str(model_path)

    logger.info(f'Model not found. Downloading from {url}...')
    model_path.parent.mkdir(parents=True, exist_ok=True)

    # Temp file in the same directory so os.replace is an atomic rename.
    fd, tmp_path = tempfile.mkstemp(dir=str(model_path.parent),
                                    suffix='.part')
    try:
        with os.fdopen(fd, 'wb') as out_file, \
                urllib.request.urlopen(url) as response:
            # Stream in chunks instead of holding the whole file in memory.
            shutil.copyfileobj(response, out_file)
        os.replace(tmp_path, model_path)
    except Exception as e:
        logger.error(f'Error downloading the model: {e}')
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    logger.info(f'Model downloaded and saved at: {model_path}')
    return str(model_path)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def replace_batchnorm(net):
    """Recursively fuse submodules of ``net`` in place for inference.

    Any child exposing a ``fuse()`` method is replaced by the module that
    ``fuse()`` returns, and that result is processed recursively. A child that
    is a bare ``torch.nn.BatchNorm2d`` is replaced by ``torch.nn.Identity``.
    All other children are descended into.

    NOTE(review): a BatchNorm2d that was not absorbed by a parent's ``fuse()``
    is dropped rather than folded; this presumes every such BN is handled by
    a fusable parent -- confirm when adding new backbones.

    Args:
        net (torch.nn.Module): Module to rewrite in place.
    """
    for name, module in net.named_children():
        if not hasattr(module, 'fuse'):
            if isinstance(module, torch.nn.BatchNorm2d):
                setattr(net, name, torch.nn.Identity())
            else:
                replace_batchnorm(module)
            continue
        merged = module.fuse()
        setattr(net, name, merged)
        replace_batchnorm(merged)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def padding_image(img, size=(640, 640)):
    """Zero-pad an image on the bottom and right so it reaches ``size``.

    The original pixels stay anchored at the top-left corner, so coordinates
    computed on the padded image map back to the original without an offset.
    A dimension that already meets or exceeds the target is left unpadded
    (the image is never cropped).

    Args:
        img (numpy.ndarray): Image whose first two axes are (height, width).
        size (tuple[int, int]): Target ``(width, height)``; default 640x640.

    Returns:
        numpy.ndarray: The padded image.
    """
    img_height, img_width = img.shape[:2]
    target_width, target_height = size

    # Clamp at zero: cv2.copyMakeBorder rejects negative border sizes, so an
    # axis that is already large enough is simply not padded.
    pad_bottom = max(0, target_height - img_height)
    pad_right = max(0, target_width - img_width)

    # Constant black (zero) border, added only at the bottom and right.
    return cv2.copyMakeBorder(img,
                              0,
                              pad_bottom,
                              0,
                              pad_right,
                              cv2.BORDER_CONSTANT,
                              value=[0, 0, 0])
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def resize_image(img, size=(640, 640), over_lap=64):
    """Split an image into fixed-size tiles, padding where necessary.

    - If the image fits within ``size`` it is returned as a single tile,
      zero-padded on the bottom/right when it is smaller than ``size``.
    - Otherwise it is cut into a grid of ``size`` tiles in which neighbours
      overlap by ``over_lap`` pixels. The last row/column is shifted back so
      it ends exactly on the image border; tiles that are still smaller than
      ``size`` (the image is shorter/narrower than the target on that axis)
      are zero-padded.

    Args:
        img (numpy.ndarray): Image whose first two axes are (height, width).
        size (tuple[int, int]): Tile ``(width, height)``; default 640x640.
        over_lap (int): Overlap in pixels between neighbouring tiles. Must be
            smaller than both tile dimensions.

    Returns:
        tuple[list, list]: ``(tiles, positions)`` where each tile has the
        spatial size ``size`` and ``positions[i]`` is
        ``[left, top, right, bottom]`` (right/bottom exclusive) of the region
        of ``img`` covered by ``tiles[i]``.

    Raises:
        ValueError: If the image needs tiling and ``over_lap`` is not smaller
            than the tile size.
    """
    img_height, img_width = img.shape[:2]
    target_width, target_height = size

    # Image fits in one tile: return it as-is or padded up to the tile size.
    if img_width <= target_width and img_height <= target_height:
        if img_width == target_width and img_height == target_height:
            return [img], [[0, 0, img_width, img_height]]
        return [padding_image(img, size)], [[0, 0, img_width, img_height]]

    stride_y = target_height - over_lap
    stride_x = target_width - over_lap
    if stride_x <= 0 or stride_y <= 0:
        raise ValueError(
            f'over_lap ({over_lap}) must be smaller than the tile size {size}')

    # max(..., 1) guarantees at least one tile start per axis: without it an
    # axis no longer than ``over_lap`` yields an empty range and the image
    # produces no tiles at all.
    tops = range(0, max(img_height - over_lap, 1), stride_y)
    lefts = range(0, max(img_width - over_lap, 1), stride_x)

    cropped_img_list = []
    crop_positions = []
    for tile_top in tops:
        bottom = min(tile_top + target_height, img_height)
        # Snap the last row so it ends on the image border.
        top = max(0, bottom - target_height) if bottom >= img_height \
            else tile_top
        for tile_left in lefts:
            right = min(tile_left + target_width, img_width)
            # Snap the last column so it ends on the image border.
            left = max(0, right - target_width) if right >= img_width \
                else tile_left

            cropped_img = img[top:bottom, left:right]
            if bottom - top < target_height or right - left < target_width:
                cropped_img = padding_image(cropped_img, size)
            cropped_img_list.append(cropped_img)
            crop_positions.append([left, top, right, bottom])

    return cropped_img_list, crop_positions
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def restore_preds(preds, crop_positions, original_size, thresh=0.3):
    """Stitch per-tile prediction maps back into one full-size map.

    Each tile's map is binarized with ``> thresh`` and the binary masks are
    summed into a zero canvas of ``original_size``. Where tiles overlap, the
    canvas therefore holds the number of tiles that fired at that pixel, so
    values can exceed 1; zero means no tile fired.

    Args:
        preds (torch.Tensor): Tile predictions, indexed as
            ``preds[i][:, y, x]``; presumably (num_tiles, 1, tile_h, tile_w)
            -- TODO confirm the channel count against the detection head.
        crop_positions (list): ``[left, top, right, bottom]`` for each tile.
            Only the top-left ``(bottom - top, right - left)`` area of each
            tile map is used, which drops any zero padding on the tile.
        original_size (tuple[int, int]): ``(height, width)`` of the canvas.
        thresh (float): Binarization threshold for each tile map
            (default 0.3).

    Returns:
        torch.Tensor: Canvas of shape ``(1, 1, height, width)`` with the same
        dtype and device as ``preds``.
    """
    height, width = original_size
    restored_pred = torch.zeros((1, 1, height, width),
                                dtype=preds.dtype,
                                device=preds.device)
    for cropped_pred, (left, top, right, bottom) in zip(preds, crop_positions):
        # Keep only the part of the tile that lies inside the image.
        visible = cropped_pred[:, :bottom - top, :right - left]
        restored_pred[:, :, top:bottom, left:right] += (
            visible > thresh).to(preds.dtype)
    return restored_pred
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def draw_det_res(dt_boxes, img, img_name, save_path):
    """Draw detected polygons on ``img`` and save it under ``save_path``.

    The polygons are drawn directly into ``img``, so the caller's array is
    modified in place.

    Args:
        dt_boxes: Iterable of polygons; each is converted to an int32 array
            reshaped to (-1, 1, 2) points for ``cv2.polylines``.
        img (numpy.ndarray): Image to draw on (modified in place).
        img_name (str): Source image path; only its basename is used as the
            output file name.
        save_path (str): Output directory; created if it does not exist.
    """
    for box in dt_boxes:
        points = np.array(box).astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(img, [points], True, color=(255, 255, 0), thickness=2)
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(save_path, exist_ok=True)
    out_path = os.path.join(save_path, os.path.basename(img_name))
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(out_path, img):
        logger.warning(f'Failed to write detection visualization to {out_path}')
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def set_device(device, numId=0):
    """Translate a device name into a ``torch.device``.

    Args:
        device (str): Requested device; only ``'gpu'`` selects CUDA.
        numId (int): CUDA device index used when a GPU is selected.

    Returns:
        torch.device: ``cuda:<numId>`` when ``device == 'gpu'`` and CUDA is
        available, otherwise ``cpu``.
    """
    use_cuda = device == 'gpu' and torch.cuda.is_available()
    return torch.device(f'cuda:{numId}' if use_cuda else 'cpu')
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class OpenDetector(object):
    """Text detector that builds a detection model plus its pre- and
    post-processing from a config and runs it on image files or arrays."""

    def __init__(self, config=None, numId=0):
        """Build the model, preprocessing operators and post-processor.

        Args:
            config (dict, optional): Parsed configuration. If None, the
                bundled default detection config is loaded and its
                ``Global.pretrained_model`` is set to the checkpoint located
                or downloaded by ``check_and_download_model``.
            numId (int, optional): CUDA device index used when
                ``Global.device`` is ``'gpu'``. Defaults to 0.

        Note:
            The ``KeepKeys`` entry in ``config['Eval']['dataset']['transforms']``
            is modified in place to keep only ``image`` and ``shape``.
        """
        if config is None:
            config = Config(DEFAULT_CFG_PATH_DET).cfg
            config['Global']['pretrained_model'] = check_and_download_model(
                MODEL_NAME_DET, DOWNLOAD_URL_DET)

        from opendet.modeling import build_model as build_det_model
        from opendet.postprocess import build_post_process
        from opendet.preprocess import create_operators, transform
        self.transform = transform
        global_config = config['Global']

        # Build the model, load weights, then fuse BN layers in the backbone.
        self.model = build_det_model(config['Architecture'])
        self.model.eval()
        load_ckpt(self.model, config)
        replace_batchnorm(self.model.backbone)
        self.device = set_device(global_config['device'], numId=numId)
        self.model.to(device=self.device)

        # Preprocessing ops: label-related ops are skipped (no ground truth at
        # inference) and KeepKeys is narrowed to the inputs used below.
        transforms = []
        for op in config['Eval']['dataset']['transforms']:
            op_name = list(op)[0]
            if 'Label' in op_name:
                continue
            elif op_name == 'KeepKeys':
                op[op_name]['keep_keys'] = ['image', 'shape']
            transforms.append(op)

        self.ops = create_operators(transforms, global_config)

        # Post-processing that turns prediction maps into boxes.
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

    @staticmethod
    def _resolve_inputs(img_path, img_numpy_list, img_numpy):
        """Turn the supported input styles into a single list of sources.

        Precedence is ``img_numpy``, then ``img_path``, then
        ``img_numpy_list``; the same choice is used both to count images and
        to fetch them, so the two always agree.

        Returns:
            tuple: ``(sources, from_path)``; ``sources`` are file paths when
            ``from_path`` is True and image arrays otherwise.

        Raises:
            ValueError: If none of the inputs is given.
        """
        if img_numpy is not None:
            return [img_numpy], False
        if img_path is not None:
            # A single path is expanded via get_image_file_list; an explicit
            # sequence of paths is used as given.
            if isinstance(img_path, (str, os.PathLike)):
                img_path = get_image_file_list(img_path)
            return img_path, True
        if img_numpy_list is not None:
            return img_numpy_list, False
        raise ValueError('No input image path or numpy array.')

    def _load_image(self, source, from_path):
        """Wrap one input in a ``data`` dict and apply the first preprocessing
        op (presumably image decoding -- confirm against the config).

        File inputs are passed on as raw bytes; arrays are passed unchanged.
        """
        if from_path:
            with open(source, 'rb') as f:
                data = {'image': f.read()}
        else:
            data = {'image': source}
        return self.transform(data, self.ops[:1])

    def crop_infer(
        self,
        img_path=None,
        img_numpy_list=None,
        img_numpy=None,
    ):
        """Detect text by tiling each image into overlapping 640x640 crops.

        Each image is resized so its longer side is 2 * 640 - 64 = 1216
        pixels (keeping the aspect ratio), cut into overlapping tiles with
        ``resize_image``, and all tiles are run through the model as one
        batch. The per-tile maps are stitched back with ``restore_preds``
        before post-processing.

        Args:
            img_path (str | list, optional): A path handed to
                ``get_image_file_list``, or a list of image file paths.
            img_numpy_list (list, optional): Image arrays.
            img_numpy (numpy.ndarray, optional): A single image array.

        Returns:
            list[dict]: One ``{'boxes': ..., 'elapse': ...}`` per image;
            ``elapse`` is the model forward time in seconds.

        Raises:
            ValueError: If no input is given.
        """
        sources, from_path = self._resolve_inputs(img_path, img_numpy_list,
                                                  img_numpy)
        target_size = 640
        over_lap = 64
        long_side = target_size * 2 - over_lap

        results = []
        for source in sources:
            data = self._load_image(source, from_path)
            src_img_ori = data['image']
            img_height, img_width = src_img_ori.shape[:2]

            # Scale the longer side to long_side; max(1, ...) keeps extreme
            # aspect ratios from producing a zero-sized resize target.
            if img_height > img_width:
                r_h = long_side
                r_w = max(1, img_width * long_side // img_height)
            else:
                r_w = long_side
                r_h = max(1, img_height * long_side // img_width)
            src_img = cv2.resize(src_img_ori, (r_w, r_h))
            # [src_h, src_w, ratio_h, ratio_w] for the post-processor,
            # presumably used to map boxes back to the original resolution.
            shape_list_ori = np.array([[
                img_height, img_width,
                float(r_h) / img_height,
                float(r_w) / img_width
            ]])

            resized_h, resized_w = src_img.shape[:2]
            cropped_img_list, crop_positions = resize_image(
                src_img, size=(target_size, target_size), over_lap=over_lap)

            # ops[-3:-1]: the ops between decoding and KeepKeys (presumably
            # normalization and CHW conversion -- confirm against config).
            image_list = [
                self.transform({'image': img}, self.ops[-3:-1])['image']
                for img in cropped_img_list
            ]
            images = torch.from_numpy(np.array(image_list)).to(
                device=self.device)

            with torch.no_grad():
                t_start = time.time()
                preds = self.model(images)
                # CUDA runs asynchronously; wait so 'elapse' covers the whole
                # forward pass. Calling this on CPU-only setups would raise.
                if self.device.type == 'cuda':
                    torch.cuda.synchronize(self.device)
                t_cost = time.time() - t_start

                preds['maps'] = restore_preds(preds['maps'], crop_positions,
                                              (resized_h, resized_w))
                post_result = self.post_process_class(preds, shape_list_ori)
            results.append({
                'boxes': post_result[0]['points'],
                'elapse': t_cost
            })
        return results

    def __call__(self,
                 img_path=None,
                 img_numpy_list=None,
                 img_numpy=None,
                 return_mask=False):
        """Run detection on the input image(s) and return the results.

        Args:
            img_path (str | list, optional): A path handed to
                ``get_image_file_list``, or a list of image file paths.
                Defaults to None.
            img_numpy_list (list, optional): Image arrays. Defaults to None.
            img_numpy (numpy.ndarray, optional): A single image array.
                Defaults to None.
            return_mask (bool, optional): If True, each result also contains
                the raw ``preds['maps']`` under ``'mask'`` (converted to a
                numpy array when it is a tensor).

        Returns:
            list[dict]: One dict per image with ``'boxes'`` (the detected
            box point sets) and ``'elapse'`` (model forward time in seconds),
            plus ``'mask'`` when ``return_mask`` is True.

        Raises:
            ValueError: If neither an image path nor a numpy array is given.
        """
        sources, from_path = self._resolve_inputs(img_path, img_numpy_list,
                                                  img_numpy)
        results = []
        for source in sources:
            data = self._load_image(source, from_path)
            batch = self.transform(data, self.ops[1:])

            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = torch.from_numpy(images).to(device=self.device)
            with torch.no_grad():
                t_start = time.time()
                preds = self.model(images)
                # Same timing semantics as crop_infer on CUDA devices.
                if self.device.type == 'cuda':
                    torch.cuda.synchronize(self.device)
                t_cost = time.time() - t_start
                post_result = self.post_process_class(preds, shape_list)

            info = {'boxes': post_result[0]['points'], 'elapse': t_cost}
            if return_mask:
                maps = preds['maps']
                if isinstance(maps, torch.Tensor):
                    maps = maps.detach().cpu().numpy()
                info['mask'] = maps
            results.append(info)
        return results
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
@torch.no_grad()
def main(cfg):
    """Run text detection on ``Global.infer_img`` and save the results.

    One ``<image path>\\t<json boxes>`` line per image is written to
    ``<output_dir>/det_results.txt``. When ``Global.is_visualize`` is set,
    images with the boxes drawn are also saved under
    ``<output_dir>/det_results/``.

    Args:
        cfg (dict): Parsed config; reads ``Global.output_dir``,
            ``Global.infer_img`` and the optional ``Global.is_visualize``.
    """
    is_visualize = cfg['Global'].get('is_visualize', False)
    model = OpenDetector(cfg)

    save_res_path = cfg['Global']['output_dir']
    os.makedirs(save_res_path, exist_ok=True)
    save_det_path = os.path.join(save_res_path, 'det_results')

    with open(os.path.join(save_res_path, 'det_results.txt'), 'wb') as fout:
        image_files = get_image_file_list(cfg['Global']['infer_img'])
        for sample_num, file in enumerate(image_files):
            preds_result = model(img_path=file)[0]
            logger.info('{} infer_img: {}, time cost: {}'.format(
                sample_num, file, preds_result['elapse']))

            boxes = preds_result['boxes']
            dt_boxes_json = [{
                'points': np.array(box).tolist()
            } for box in boxes]

            if is_visualize:
                src_img = cv2.imread(file)
                # cv2.imread returns None instead of raising on unreadable
                # files; skip drawing rather than crash the whole run.
                if src_img is None:
                    logger.warning(
                        f'Could not read {file} for visualization; skipped.')
                else:
                    draw_det_res(boxes, src_img, file, save_det_path)
                    logger.info('The detected Image saved in {}'.format(
                        os.path.join(save_det_path, os.path.basename(file))))

            otstr = file + '\t' + json.dumps(dt_boxes_json) + '\n'
            logger.info('results: {}'.format(json.dumps(dt_boxes_json)))
            fout.write(otstr.encode())

    logger.info('success!')
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
if __name__ == '__main__':
    # Parse CLI flags, load the YAML config, then layer the remaining flags
    # and the -o overrides on top before running inference.
    cli_args = ArgsParser().parse_args()
    det_config = Config(cli_args.config)
    cli_dict = vars(cli_args)
    overrides = cli_dict.pop('opt')
    det_config.merge_dict(cli_dict)
    det_config.merge_dict(overrides)
    main(det_config.cfg)
|