openocr-python 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +11 -0
- openocr/configs/det/dbnet/repvit_db.yml +173 -0
- openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
- openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
- openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
- openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
- openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
- openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
- openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
- openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
- openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
- openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
- openocr/configs/rec/busnet/vit_busnet.yml +104 -0
- openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
- openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
- openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
- openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
- openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
- openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
- openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
- openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
- openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
- openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
- openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
- openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
- openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
- openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
- openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
- openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
- openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
- openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
- openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
- openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
- openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
- openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
- openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
- openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
- openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
- openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
- openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
- openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
- openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
- openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
- openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
- openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
- openocr/configs/rec/nrtr/nrtr.yml +107 -0
- openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
- openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
- openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
- openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
- openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
- openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
- openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
- openocr/configs/rec/parseq/vit_parseq.yml +100 -0
- openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
- openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
- openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
- openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
- openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
- openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
- openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
- openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
- openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
- openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
- openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
- openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
- openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
- openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
- openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
- openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
- openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
- openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
- openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
- openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
- openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
- openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
- openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
- openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
- openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
- openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
- openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
- openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
- openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
- openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
- openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
- openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
- openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
- openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
- openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
- openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
- openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
- openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
- openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
- openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
- openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
- openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
- openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
- openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
- openocr/demo_gradio.py +128 -0
- openocr/opendet/modeling/__init__.py +11 -0
- openocr/opendet/modeling/backbones/__init__.py +14 -0
- openocr/opendet/modeling/backbones/repvit.py +340 -0
- openocr/opendet/modeling/base_detector.py +69 -0
- openocr/opendet/modeling/heads/__init__.py +14 -0
- openocr/opendet/modeling/heads/db_head.py +73 -0
- openocr/opendet/modeling/necks/__init__.py +14 -0
- openocr/opendet/modeling/necks/db_fpn.py +609 -0
- openocr/opendet/postprocess/__init__.py +18 -0
- openocr/opendet/postprocess/db_postprocess.py +273 -0
- openocr/opendet/preprocess/__init__.py +154 -0
- openocr/opendet/preprocess/crop_resize.py +121 -0
- openocr/opendet/preprocess/db_resize_for_test.py +135 -0
- openocr/openrec/losses/__init__.py +62 -0
- openocr/openrec/losses/abinet_loss.py +42 -0
- openocr/openrec/losses/ar_loss.py +23 -0
- openocr/openrec/losses/cam_loss.py +48 -0
- openocr/openrec/losses/cdistnet_loss.py +34 -0
- openocr/openrec/losses/ce_loss.py +68 -0
- openocr/openrec/losses/cppd_loss.py +77 -0
- openocr/openrec/losses/ctc_loss.py +33 -0
- openocr/openrec/losses/igtr_loss.py +12 -0
- openocr/openrec/losses/lister_loss.py +14 -0
- openocr/openrec/losses/lpv_loss.py +30 -0
- openocr/openrec/losses/mgp_loss.py +34 -0
- openocr/openrec/losses/parseq_loss.py +12 -0
- openocr/openrec/losses/robustscanner_loss.py +20 -0
- openocr/openrec/losses/seed_loss.py +46 -0
- openocr/openrec/losses/smtr_loss.py +12 -0
- openocr/openrec/losses/srn_loss.py +40 -0
- openocr/openrec/losses/visionlan_loss.py +58 -0
- openocr/openrec/metrics/__init__.py +19 -0
- openocr/openrec/metrics/rec_metric.py +270 -0
- openocr/openrec/metrics/rec_metric_gtc.py +58 -0
- openocr/openrec/metrics/rec_metric_long.py +142 -0
- openocr/openrec/metrics/rec_metric_mgp.py +93 -0
- openocr/openrec/modeling/__init__.py +11 -0
- openocr/openrec/modeling/base_recognizer.py +69 -0
- openocr/openrec/modeling/common.py +238 -0
- openocr/openrec/modeling/decoders/__init__.py +109 -0
- openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
- openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
- openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
- openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
- openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
- openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
- openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
- openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
- openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
- openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
- openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
- openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
- openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
- openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
- openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
- openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
- openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
- openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
- openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
- openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
- openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
- openocr/openrec/modeling/encoders/__init__.py +39 -0
- openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
- openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
- openocr/openrec/modeling/encoders/convnextv2.py +213 -0
- openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
- openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
- openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
- openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
- openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
- openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
- openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
- openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
- openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
- openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
- openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
- openocr/openrec/modeling/encoders/repvit.py +338 -0
- openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
- openocr/openrec/modeling/encoders/svtrnet.py +574 -0
- openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
- openocr/openrec/modeling/encoders/svtrv2.py +470 -0
- openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
- openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
- openocr/openrec/modeling/encoders/vit.py +120 -0
- openocr/openrec/modeling/transforms/__init__.py +15 -0
- openocr/openrec/modeling/transforms/aster_tps.py +262 -0
- openocr/openrec/modeling/transforms/moran.py +136 -0
- openocr/openrec/modeling/transforms/tps.py +246 -0
- openocr/openrec/optimizer/__init__.py +73 -0
- openocr/openrec/optimizer/lr.py +227 -0
- openocr/openrec/postprocess/__init__.py +72 -0
- openocr/openrec/postprocess/abinet_postprocess.py +37 -0
- openocr/openrec/postprocess/ar_postprocess.py +63 -0
- openocr/openrec/postprocess/ce_postprocess.py +43 -0
- openocr/openrec/postprocess/char_postprocess.py +108 -0
- openocr/openrec/postprocess/cppd_postprocess.py +42 -0
- openocr/openrec/postprocess/ctc_postprocess.py +119 -0
- openocr/openrec/postprocess/igtr_postprocess.py +100 -0
- openocr/openrec/postprocess/lister_postprocess.py +59 -0
- openocr/openrec/postprocess/mgp_postprocess.py +143 -0
- openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
- openocr/openrec/postprocess/smtr_postprocess.py +73 -0
- openocr/openrec/postprocess/srn_postprocess.py +80 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
- openocr/openrec/preprocess/__init__.py +173 -0
- openocr/openrec/preprocess/abinet_aug.py +473 -0
- openocr/openrec/preprocess/abinet_label_encode.py +36 -0
- openocr/openrec/preprocess/ar_label_encode.py +36 -0
- openocr/openrec/preprocess/auto_augment.py +1012 -0
- openocr/openrec/preprocess/cam_label_encode.py +141 -0
- openocr/openrec/preprocess/ce_label_encode.py +116 -0
- openocr/openrec/preprocess/char_label_encode.py +36 -0
- openocr/openrec/preprocess/cppd_label_encode.py +173 -0
- openocr/openrec/preprocess/ctc_label_encode.py +124 -0
- openocr/openrec/preprocess/ep_label_encode.py +38 -0
- openocr/openrec/preprocess/igtr_label_encode.py +360 -0
- openocr/openrec/preprocess/mgp_label_encode.py +95 -0
- openocr/openrec/preprocess/parseq_aug.py +150 -0
- openocr/openrec/preprocess/rec_aug.py +211 -0
- openocr/openrec/preprocess/resize.py +534 -0
- openocr/openrec/preprocess/smtr_label_encode.py +125 -0
- openocr/openrec/preprocess/srn_label_encode.py +37 -0
- openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
- openocr/tools/create_lmdb_dataset.py +118 -0
- openocr/tools/data/__init__.py +94 -0
- openocr/tools/data/collate_fn.py +100 -0
- openocr/tools/data/lmdb_dataset.py +142 -0
- openocr/tools/data/lmdb_dataset_test.py +166 -0
- openocr/tools/data/multi_scale_sampler.py +177 -0
- openocr/tools/data/ratio_dataset.py +217 -0
- openocr/tools/data/ratio_dataset_test.py +273 -0
- openocr/tools/data/ratio_dataset_tvresize.py +213 -0
- openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
- openocr/tools/data/ratio_sampler.py +190 -0
- openocr/tools/data/simple_dataset.py +263 -0
- openocr/tools/data/strlmdb_dataset.py +143 -0
- openocr/tools/engine/__init__.py +5 -0
- openocr/tools/engine/config.py +158 -0
- openocr/tools/engine/trainer.py +621 -0
- openocr/tools/eval_rec.py +41 -0
- openocr/tools/eval_rec_all_ch.py +184 -0
- openocr/tools/eval_rec_all_en.py +206 -0
- openocr/tools/eval_rec_all_long.py +119 -0
- openocr/tools/eval_rec_all_long_simple.py +122 -0
- openocr/tools/export_rec.py +118 -0
- openocr/tools/infer/onnx_engine.py +65 -0
- openocr/tools/infer/predict_rec.py +140 -0
- openocr/tools/infer/utility.py +234 -0
- openocr/tools/infer_det.py +449 -0
- openocr/tools/infer_e2e.py +462 -0
- openocr/tools/infer_e2e_parallel.py +184 -0
- openocr/tools/infer_rec.py +371 -0
- openocr/tools/train_rec.py +37 -0
- openocr/tools/utility.py +45 -0
- openocr/tools/utils/EN_symbol_dict.txt +94 -0
- openocr/tools/utils/__init__.py +0 -0
- openocr/tools/utils/ckpt.py +87 -0
- openocr/tools/utils/dict/ar_dict.txt +117 -0
- openocr/tools/utils/dict/arabic_dict.txt +161 -0
- openocr/tools/utils/dict/be_dict.txt +145 -0
- openocr/tools/utils/dict/bg_dict.txt +140 -0
- openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
- openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
- openocr/tools/utils/dict/devanagari_dict.txt +167 -0
- openocr/tools/utils/dict/en_dict.txt +63 -0
- openocr/tools/utils/dict/fa_dict.txt +136 -0
- openocr/tools/utils/dict/french_dict.txt +136 -0
- openocr/tools/utils/dict/german_dict.txt +143 -0
- openocr/tools/utils/dict/hi_dict.txt +162 -0
- openocr/tools/utils/dict/it_dict.txt +118 -0
- openocr/tools/utils/dict/japan_dict.txt +4399 -0
- openocr/tools/utils/dict/ka_dict.txt +153 -0
- openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
- openocr/tools/utils/dict/korean_dict.txt +3688 -0
- openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
- openocr/tools/utils/dict/latin_dict.txt +185 -0
- openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
- openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
- openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
- openocr/tools/utils/dict/mr_dict.txt +153 -0
- openocr/tools/utils/dict/ne_dict.txt +153 -0
- openocr/tools/utils/dict/oc_dict.txt +96 -0
- openocr/tools/utils/dict/pu_dict.txt +130 -0
- openocr/tools/utils/dict/rs_dict.txt +91 -0
- openocr/tools/utils/dict/rsc_dict.txt +134 -0
- openocr/tools/utils/dict/ru_dict.txt +125 -0
- openocr/tools/utils/dict/spin_dict.txt +68 -0
- openocr/tools/utils/dict/ta_dict.txt +128 -0
- openocr/tools/utils/dict/table_dict.txt +277 -0
- openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
- openocr/tools/utils/dict/table_structure_dict.txt +28 -0
- openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
- openocr/tools/utils/dict/te_dict.txt +151 -0
- openocr/tools/utils/dict/ug_dict.txt +114 -0
- openocr/tools/utils/dict/uk_dict.txt +142 -0
- openocr/tools/utils/dict/ur_dict.txt +137 -0
- openocr/tools/utils/dict/xi_dict.txt +110 -0
- openocr/tools/utils/dict90.txt +90 -0
- openocr/tools/utils/e2e_metric/Deteval.py +802 -0
- openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
- openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
- openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
- openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
- openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
- openocr/tools/utils/e2e_utils/visual.py +152 -0
- openocr/tools/utils/en_dict.txt +95 -0
- openocr/tools/utils/gen_label.py +68 -0
- openocr/tools/utils/ic15_dict.txt +36 -0
- openocr/tools/utils/logging.py +56 -0
- openocr/tools/utils/poly_nms.py +132 -0
- openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
- openocr/tools/utils/stats.py +58 -0
- openocr/tools/utils/utility.py +165 -0
- openocr/tools/utils/visual.py +117 -0
- openocr_python-0.0.2.dist-info/LICENCE +201 -0
- openocr_python-0.0.2.dist-info/METADATA +98 -0
- openocr_python-0.0.2.dist-info/RECORD +323 -0
- openocr_python-0.0.2.dist-info/WHEEL +5 -0
- openocr_python-0.0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,462 @@
|
|
|
1
|
+
from __future__ import absolute_import
|
|
2
|
+
from __future__ import division
|
|
3
|
+
from __future__ import print_function
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
import sys
|
|
8
|
+
|
|
9
|
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
|
10
|
+
sys.path.append(__dir__)
|
|
11
|
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
|
12
|
+
|
|
13
|
+
os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
|
|
14
|
+
import argparse
|
|
15
|
+
import numpy as np
|
|
16
|
+
import copy
|
|
17
|
+
import time
|
|
18
|
+
import cv2
|
|
19
|
+
import json
|
|
20
|
+
from PIL import Image
|
|
21
|
+
import torch
|
|
22
|
+
from tools.utils.utility import get_image_file_list, check_and_read
|
|
23
|
+
from tools.infer_rec import OpenRecognizer
|
|
24
|
+
from tools.infer_det import OpenDetector
|
|
25
|
+
from tools.engine import Config
|
|
26
|
+
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt
|
|
27
|
+
from tools.utils.logging import get_logger
|
|
28
|
+
|
|
29
|
+
# Default config paths are resolved relative to this file so the package
# works regardless of the current working directory.
root_dir = Path(__file__).resolve().parent
DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml')
# FIX: the server recognition config lives under configs/rec/svtrv2/, not
# configs/det/svtrv2/ — the packaged config tree has no det/svtrv2 directory,
# so the original path could never resolve.
DEFAULT_CFG_PATH_REC_SERVER = str(root_dir /
                                  '../configs/rec/svtrv2/svtrv2_ch.yml')
DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml')

logger = get_logger()

MODEL_NAME_DET = './openocr_det_repvit_ch.pth'  # detection model file name
DOWNLOAD_URL_DET = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth'  # detection model URL
MODEL_NAME_REC = './openocr_repsvtr_ch.pth'  # mobile recognition model file name
DOWNLOAD_URL_REC = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth'  # mobile recognition model URL
MODEL_NAME_REC_SERVER = './openocr_svtrv2_ch.pth'  # server recognition model file name
DOWNLOAD_URL_REC_SERVER = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth'  # server recognition model URL
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def check_and_download_model(model_name: str, url: str):
    """Return a local path to a pretrained model, downloading it on first use.

    Resolution order: an existing file at ``model_name`` itself, then the
    fixed cache directory ``~/.cache/openocr``, and finally a fresh download
    from ``url`` into that cache.

    Args:
        model_name (str): model file name, e.g. "model.pt".
        url (str): download address of the model file.

    Returns:
        str: full path to the model file.

    Raises:
        Exception: re-raises any download error after logging it.
    """
    if os.path.exists(model_name):
        return model_name

    # Fixed cache path under the user's home directory.
    cache_dir = Path.home() / '.cache' / 'openocr'
    model_path = cache_dir / model_name

    # If the model file already exists in the cache, return it directly.
    if model_path.exists():
        logger.info(f'Model already exists at: {model_path}')
        return str(model_path)

    logger.info(f'Model not found. Downloading from {url}...')

    # Create the cache directory if it does not exist yet.
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Download into a temporary sibling first: the original wrote directly to
    # model_path, so an interrupted download left a partial file that a later
    # call would mistake for a valid cached model.
    tmp_path = model_path.with_suffix(model_path.suffix + '.part')
    try:
        import urllib.request
        import shutil
        with urllib.request.urlopen(url) as response, open(tmp_path,
                                                           'wb') as out_file:
            # Stream in chunks instead of response.read() to avoid holding
            # the whole model in memory.
            shutil.copyfileobj(response, out_file)
        tmp_path.replace(model_path)  # atomic move into place
        logger.info(f'Model downloaded and saved at: {model_path}')
        return str(model_path)

    except Exception as e:
        logger.info(f'Error downloading the model: {e}')
        # Remove the partial file so the next attempt starts clean.
        if tmp_path.exists():
            tmp_path.unlink()
        raise
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def set_device(device):
    """Map a device string to a ``torch.device``.

    'gpu' selects ``cuda:0`` when CUDA is available; any other value, or an
    unavailable GPU, falls back to the CPU.
    """
    use_cuda = device == 'gpu' and torch.cuda.is_available()
    return torch.device('cuda:0' if use_cuda else 'cpu')
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def check_and_download_font(font_path):
    """Return a usable path to the visualization font, downloading if needed.

    Checks ``font_path`` itself, then the cache at ``~/.cache/openocr``; only
    downloads when neither exists. On download failure the error is logged
    (best-effort, matching the original behavior) and the cache path is
    returned even though the file may be missing.

    Args:
        font_path (str): requested font file path, e.g. './simfang.ttf'.

    Returns:
        str: path to the font file (local path or cache path).
    """
    if os.path.exists(font_path):
        return font_path

    cache_dir = Path.home() / '.cache' / 'openocr'
    cached_path = cache_dir / os.path.basename(font_path)
    # FIX: reuse a previously downloaded copy. The original only checked the
    # caller-supplied relative path, so the cached font was re-downloaded on
    # every run.
    if cached_path.exists():
        return str(cached_path)

    logger.info(f"Downloading '{font_path}' ...")
    try:
        import urllib.request
        font_url = 'https://shuiche-shop.oss-cn-chengdu.aliyuncs.com/fonts/simfang.ttf'
        # FIX: create the cache directory; urlretrieve fails if it is absent.
        cache_dir.mkdir(parents=True, exist_ok=True)
        font_path = str(cached_path)
        urllib.request.urlretrieve(font_url, font_path)
        logger.info(f'Downloading font success: {font_path}')
    except Exception as e:
        logger.info(f'Downloading font error: {e}')
    return font_path
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def sorted_boxes(dt_boxes):
    """Sort detected text boxes top-to-bottom, then left-to-right.

    Args:
        dt_boxes (ndarray): detected text boxes, shape [n, 4, 2]; element
            ``box[0]`` is treated as the box's top-left corner.

    Returns:
        list: boxes sorted by reading order. Boxes whose top-left y
        coordinates differ by less than 10 px are considered the same text
        line and ordered left-to-right within it.
    """
    num_boxes = dt_boxes.shape[0]
    # Primary sort by top-left y, then x. (The local previously shadowed the
    # function's own name; renamed for clarity. sorted() already returns a
    # fresh list, so the extra list() copy was dropped.)
    _boxes = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))

    # Insertion-style backward pass: bubble a box leftward while it sits on
    # the same line (|Δy| < 10) as its predecessor but starts further left.
    for i in range(num_boxes - 1):
        for j in range(i, -1, -1):
            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
                    _boxes[j + 1][0][0] < _boxes[j][0][0]):
                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
            else:
                break
    return _boxes
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class OpenOCR(object):
    """End-to-end OCR pipeline: text detection followed by text recognition.

    Wraps an ``OpenDetector`` and an ``OpenRecognizer`` configured from the
    module-level default config paths, downloading pretrained weights on
    first use.
    """

    def __init__(self, mode='mobile', drop_score=0.5, det_box_type='quad'):
        """Initialize the OCR engine's configuration and components.

        Args:
            mode (str, optional): run mode, 'mobile' or 'server'; selects
                which recognition model/config is loaded. Default 'mobile'.
            drop_score (float, optional): confidence threshold; recognition
                results below it are dropped. Default 0.5.
            det_box_type (str, optional): detection box type, 'quad' or
                'poly'. Default 'quad'.

        Returns:
            None.
        """
        # Detection always uses the mobile detector config.
        cfg_det = Config(DEFAULT_CFG_PATH_DET).cfg  # mobile model
        model_dir = check_and_download_model(MODEL_NAME_DET, DOWNLOAD_URL_DET)
        cfg_det['Global']['pretrained_model'] = model_dir
        if mode == 'server':
            cfg_rec = Config(DEFAULT_CFG_PATH_REC_SERVER).cfg  # server model
            model_dir = check_and_download_model(MODEL_NAME_REC_SERVER,
                                                 DOWNLOAD_URL_REC_SERVER)
        else:
            cfg_rec = Config(DEFAULT_CFG_PATH_REC).cfg  # mobile model
            model_dir = check_and_download_model(MODEL_NAME_REC,
                                                 DOWNLOAD_URL_REC)
        cfg_rec['Global']['pretrained_model'] = model_dir
        self.text_detector = OpenDetector(cfg_det)
        self.text_recognizer = OpenRecognizer(cfg_rec)
        self.det_box_type = det_box_type
        self.drop_score = drop_score

        # Running counter so crops saved across multiple calls to
        # draw_crop_rec_res get unique file names.
        self.crop_image_res_index = 0

    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
        """Save each cropped text-region image to ``output_dir`` as JPEG.

        NOTE(review): the filename prefix 'mg_crop' looks like a typo for
        'img_crop' — kept as-is for compatibility with existing outputs.
        """
        os.makedirs(output_dir, exist_ok=True)
        bbox_num = len(img_crop_list)
        for bno in range(bbox_num):
            cv2.imwrite(
                os.path.join(output_dir,
                             f'mg_crop_{bno+self.crop_image_res_index}.jpg'),
                img_crop_list[bno],
            )
        self.crop_image_res_index += bbox_num

    def infer_single_image(self,
                           img_numpy,
                           ori_img,
                           crop_infer=False,
                           rec_batch_num=6,
                           return_mask=False):
        """Detect and recognize text in one image.

        Args:
            img_numpy: image array fed to the detector (BGR, per cv2.imread
                usage in __call__).
            ori_img: original image the crops are taken from.
            crop_infer (bool): use the detector's crop_infer path instead of
                a plain forward call.
            rec_batch_num (int): recognition batch size.
            return_mask (bool): also return the detector's segmentation mask.

        Returns:
            (filter_boxes, filter_rec_res, time_dict[, mask]); all three are
            None when the detector returns no boxes. Each recognition result
            dict provides 'text', 'score' and 'elapse' keys.
        """
        start = time.time()
        if crop_infer:
            dt_boxes = self.text_detector.crop_infer(
                img_numpy=img_numpy)[0]['boxes']
        else:
            det_res = self.text_detector(img_numpy=img_numpy,
                                         return_mask=return_mask)[0]
            dt_boxes = det_res['boxes']
        # logger.info(dt_boxes)
        det_time_cost = time.time() - start

        if dt_boxes is None:
            return None, None, None

        img_crop_list = []

        # Reading order: top-to-bottom, left-to-right within a line.
        dt_boxes = sorted_boxes(dt_boxes)

        for bno in range(len(dt_boxes)):
            tmp_box = np.array(copy.deepcopy(dt_boxes[bno])).astype(np.float32)
            # 'quad' boxes are perspective-rectified; anything else falls
            # back to a minimum-area rectangle crop.
            if self.det_box_type == 'quad':
                img_crop = get_rotate_crop_image(ori_img, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_img, tmp_box)
            img_crop_list.append(img_crop)

        start = time.time()
        rec_res = self.text_recognizer(img_numpy_list=img_crop_list,
                                       batch_num=rec_batch_num)
        rec_time_cost = time.time() - start

        # Keep only recognitions at or above the confidence threshold.
        filter_boxes, filter_rec_res = [], []
        rec_time_cost_sig = 0.0
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result['text'], rec_result['score']
            rec_time_cost_sig += rec_result['elapse']
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append([text, score])

        # Average per-box recognition time as reported by the recognizer.
        avg_rec_time_cost = rec_time_cost_sig / len(dt_boxes) if len(
            dt_boxes) > 0 else 0.0
        if return_mask:
            return filter_boxes, filter_rec_res, {
                'time_cost': det_time_cost + rec_time_cost,
                'detection_time': det_time_cost,
                'recognition_time': rec_time_cost,
                'avg_rec_time_cost': avg_rec_time_cost
            }, det_res['mask']

        return filter_boxes, filter_rec_res, {
            'time_cost': det_time_cost + rec_time_cost,
            'detection_time': det_time_cost,
            'recognition_time': rec_time_cost,
            'avg_rec_time_cost': avg_rec_time_cost
        }

    def __call__(self,
                 img_path=None,
                 save_dir='e2e_results/',
                 is_visualize=False,
                 img_numpy=None,
                 rec_batch_num=6,
                 crop_infer=False,
                 return_mask=False):
        """Run end-to-end OCR on in-memory images or files on disk.

        When ``img_numpy`` is given it takes precedence over ``img_path`` and
        the results are returned without being saved.

        img_path: str, optional, default=None
            Path to the directory containing images or the image filename.
        save_dir: str, optional, default='e2e_results/'
            Directory to save prediction and visualization results.
        is_visualize: bool, optional, default=False
            Visualize the results (file-based path only).
        img_numpy: numpy or list[numpy], optional, default=None
            numpy of an image or List of numpy arrays representing images.
        rec_batch_num: int, optional, default=6
            Batch size for text recognition.
        crop_infer: bool, optional, default=False
            Whether to use crop inference.
        return_mask: bool, optional, default=False
            Also return the detector mask (img_numpy path only).
        """

        if img_numpy is None and img_path is None:
            raise ValueError('img_path and img_numpy cannot be both None.')
        if img_numpy is not None:
            # --- In-memory path: per-image result lists, nothing saved. ---
            if not isinstance(img_numpy, list):
                img_numpy = [img_numpy]
            results = []
            time_dicts = []
            for index, img in enumerate(img_numpy):
                ori_img = img.copy()
                if return_mask:
                    dt_boxes, rec_res, time_dict, mask = self.infer_single_image(
                        img_numpy=img,
                        ori_img=ori_img,
                        crop_infer=crop_infer,
                        rec_batch_num=rec_batch_num,
                        return_mask=return_mask)
                else:
                    dt_boxes, rec_res, time_dict = self.infer_single_image(
                        img_numpy=img,
                        ori_img=ori_img,
                        crop_infer=crop_infer,
                        rec_batch_num=rec_batch_num)
                if dt_boxes is None:
                    # No detections: keep list positions aligned with inputs.
                    results.append([])
                    time_dicts.append({})
                    continue
                res = [{
                    'transcription': rec_res[i][0],
                    'points': np.array(dt_boxes[i]).tolist(),
                    'score': rec_res[i][1],
                } for i in range(len(dt_boxes))]
                results.append(res)
                time_dicts.append(time_dict)
            # NOTE(review): only the LAST image's mask is returned, and
            # ``mask`` is unbound if img_numpy is an empty list — confirm
            # callers only pass a single image with return_mask=True.
            if return_mask:
                return results, time_dicts, mask
            return results, time_dicts

        # --- File path: iterate images/PDF pages, log and save results. ---
        image_file_list = get_image_file_list(img_path)
        save_results = []
        time_dicts_return = []
        for idx, image_file in enumerate(image_file_list):
            # check_and_read handles GIF and PDF; everything else via imread.
            img, flag_gif, flag_pdf = check_and_read(image_file)
            if not flag_gif and not flag_pdf:
                img = cv2.imread(image_file)
            if not flag_pdf:
                if img is None:
                    # Unreadable image aborts the whole batch.
                    return None
                imgs = [img]
            else:
                # A PDF yields one image per page.
                imgs = img
            logger.info(
                f'Processing {idx+1}/{len(image_file_list)}: {image_file}')

            res_list = []
            time_dicts = []
            for index, img_numpy in enumerate(imgs):
                ori_img = img_numpy.copy()
                dt_boxes, rec_res, time_dict = self.infer_single_image(
                    img_numpy=img_numpy,
                    ori_img=ori_img,
                    crop_infer=crop_infer,
                    rec_batch_num=rec_batch_num)
                if dt_boxes is None:
                    res_list.append([])
                    time_dicts.append({})
                    continue
                res = [{
                    'transcription': rec_res[i][0],
                    'points': np.array(dt_boxes[i]).tolist(),
                    'score': rec_res[i][1],
                } for i in range(len(dt_boxes))]
                res_list.append(res)
                time_dicts.append(time_dict)

            for index, (res, time_dict) in enumerate(zip(res_list,
                                                         time_dicts)):

                if len(res) > 0:
                    logger.info(f'Results: {res}.')
                    logger.info(f'Time cost: {time_dict}.')
                else:
                    logger.info('No text detected.')

                # Multi-page inputs get a page suffix in the saved line;
                # single-page inputs with no detections are skipped entirely.
                if len(res_list) > 1:
                    save_pred = (os.path.basename(image_file) + '_' +
                                 str(index) + '\t' +
                                 json.dumps(res, ensure_ascii=False) + '\n')
                else:
                    if len(res) > 0:
                        save_pred = (os.path.basename(image_file) + '\t' +
                                     json.dumps(res, ensure_ascii=False) +
                                     '\n')
                    else:
                        continue
                save_results.append(save_pred)
                time_dicts_return.append(time_dict)

                if is_visualize and len(res) > 0:
                    # One-time setup: font + output directory.
                    # NOTE(review): font_path/draw_img_save_dir are bound only
                    # when idx == 0; if the FIRST file has no visualizable
                    # result but a later one does, this raises NameError —
                    # confirm intended.
                    if idx == 0:
                        font_path = './simfang.ttf'
                        font_path = check_and_download_font(font_path)
                        os.makedirs(save_dir, exist_ok=True)
                        draw_img_save_dir = os.path.join(
                            save_dir, 'vis_results/')
                        os.makedirs(draw_img_save_dir, exist_ok=True)
                        logger.info(
                            f'Visualized results will be saved to {draw_img_save_dir}.'
                        )
                    dt_boxes = [res[i]['points'] for i in range(len(res))]
                    rec_res = [
                        res[i]['transcription'] for i in range(len(res))
                    ]
                    rec_score = [res[i]['score'] for i in range(len(res))]
                    # PIL expects RGB; cv2 images are BGR.
                    image = Image.fromarray(
                        cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                    boxes = dt_boxes
                    txts = [rec_res[i] for i in range(len(rec_res))]
                    scores = [rec_score[i] for i in range(len(rec_res))]

                    draw_img = draw_ocr_box_txt(
                        image,
                        boxes,
                        txts,
                        scores,
                        drop_score=self.drop_score,
                        font_path=font_path,
                    )
                    if flag_gif:
                        save_file = image_file[:-3] + 'png'
                    elif flag_pdf:
                        save_file = image_file.replace(
                            '.pdf', '_' + str(index) + '.png')
                    else:
                        save_file = image_file
                    # draw_img is RGB; flip back to BGR for cv2.imwrite.
                    cv2.imwrite(
                        os.path.join(draw_img_save_dir,
                                     os.path.basename(save_file)),
                        draw_img[:, :, ::-1],
                    )

        if save_results:
            os.makedirs(save_dir, exist_ok=True)
            with open(os.path.join(save_dir, 'system_results.txt'),
                      'w',
                      encoding='utf-8') as f:
                f.writelines(save_results)
            logger.info(
                f"Results saved to {os.path.join(save_dir, 'system_results.txt')}."
            )
            if is_visualize:
                logger.info(
                    f'Visualized results saved to {draw_img_save_dir}.')
            return save_results, time_dicts_return
        else:
            logger.info('No text detected.')
            return None, None
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def main():
    """Command-line entry point for the OpenOCR end-to-end system.

    Parses CLI arguments, builds an ``OpenOCR`` pipeline (detector +
    recognizer) and runs it on ``--img_path``, writing prediction results
    (and visualizations when ``--is_vis`` is set) under ``--save_dir``.
    """
    parser = argparse.ArgumentParser(description='OpenOCR system')
    parser.add_argument(
        '--img_path',
        type=str,
        # Previously optional: a missing path flowed as None into the
        # pipeline and crashed far from the cause. Fail fast with a clear
        # argparse error instead.
        required=True,
        help='Path to the directory containing images or the image filename.')
    parser.add_argument(
        '--mode',
        type=str,
        default='mobile',
        help="Mode of the OCR system, e.g., 'mobile' or 'server'.")
    parser.add_argument(
        '--save_dir',
        type=str,
        default='e2e_results/',
        help='Directory to save prediction and visualization results. \
                Defaults to ./e2e_results/.')
    parser.add_argument('--is_vis',
                        action='store_true',
                        default=False,
                        help='Visualize the results.')
    parser.add_argument('--drop_score',
                        type=float,
                        default=0.5,
                        help='Score threshold for text recognition.')
    args = parser.parse_args()

    text_sys = OpenOCR(mode=args.mode,
                       drop_score=args.drop_score,
                       det_box_type='quad')  # det_box_type: 'quad' or 'poly'
    text_sys(img_path=args.img_path,
             save_dir=args.save_dir,
             is_visualize=args.is_vis)
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    main()
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
from __future__ import absolute_import
|
|
2
|
+
from __future__ import division
|
|
3
|
+
from __future__ import print_function
|
|
4
|
+
|
|
5
|
+
import threading
|
|
6
|
+
import queue
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
import time
|
|
10
|
+
|
|
11
|
+
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
|
12
|
+
sys.path.append(__dir__)
|
|
13
|
+
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import cv2
|
|
17
|
+
import json
|
|
18
|
+
from PIL import Image
|
|
19
|
+
from tools.utils.utility import get_image_file_list, check_and_read
|
|
20
|
+
from tools.infer_rec import OpenRecognizer
|
|
21
|
+
from tools.infer_det import OpenDetector
|
|
22
|
+
from tools.infer_e2e import check_and_download_font, sorted_boxes
|
|
23
|
+
from tools.engine import Config
|
|
24
|
+
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class OpenOCRParallel:
    """Two-stage OCR pipeline that overlaps detection and recognition.

    The main thread runs text detection over all images and pushes each
    image's cropped text regions into a queue; one or more background
    threads consume the queue and run recognition concurrently. Results
    are collected per image id in ``self.results``.

    NOTE(review): the threading protocol is order-sensitive —
    ``process_images`` initializes ``self.results`` before workers start,
    ``detect_text`` sets ``stop_signal`` only after the last ``queue.put``,
    and the signal is cleared only after all workers have joined.
    """

    def __init__(self, drop_score=0.5, det_box_type='quad', max_rec_threads=1):
        # drop_score: minimum recognition confidence; lower-scoring results
        #   are discarded in recognize_text().
        # det_box_type: 'quad' crops via perspective transform; any other
        #   value falls back to a min-area rectangle crop.
        # max_rec_threads: number of recognition worker threads.
        cfg_det = Config(
            './configs/det/dbnet/repvit_db.yml').cfg  # mobile model
        # cfg_rec = Config('./configs/rec/svtrv2/svtrv2_ch.yml').cfg  # server model
        cfg_rec = Config(
            './configs/rec/svtrv2/repsvtr_ch.yml').cfg  # mobile model
        self.text_detector = OpenDetector(cfg_det, numId=0)
        self.text_recognizer = OpenRecognizer(cfg_rec, numId=0)
        self.det_box_type = det_box_type
        self.drop_score = drop_score
        self.queue = queue.Queue(
        )  # Queue to hold detected boxes for recognition
        self.results = {}
        self.lock = threading.Lock()  # Lock for thread-safe access to results
        self.max_rec_threads = max_rec_threads
        self.stop_signal = threading.Event()  # Signal to stop threads

    def start_recognition_threads(self):
        """Start recognition threads."""
        self.rec_threads = []
        for _ in range(self.max_rec_threads):
            t = threading.Thread(target=self.recognize_text)
            t.start()
            self.rec_threads.append(t)

    def detect_text(self, image_list):
        """Single-threaded text detection for all images.

        image_list: sequence of (img_numpy, ori_img) pairs; crops are taken
        from ori_img. Enqueues (image_id, boxes, crops) per image and sets
        stop_signal when every image has been processed.
        """
        for image_id, (img_numpy, ori_img) in enumerate(image_list):
            # Detector output is indexed as [0]['boxes']; 'boxes' may be
            # None when nothing is detected.
            dt_boxes = self.text_detector(img_numpy=img_numpy)[0]['boxes']
            if dt_boxes is None:
                self.results[image_id] = []  # If no boxes, set empty results
                continue

            # Sort boxes into reading order before cropping.
            dt_boxes = sorted_boxes(dt_boxes)
            img_crop_list = []
            for box in dt_boxes:
                tmp_box = np.array(box).astype(np.float32)
                img_crop = (get_rotate_crop_image(ori_img, tmp_box)
                            if self.det_box_type == 'quad' else
                            get_minarea_rect_crop(ori_img, tmp_box))
                img_crop_list.append(img_crop)
            self.queue.put(
                (image_id, dt_boxes, img_crop_list
                 ))  # Put image ID, detected box, and cropped image in queue

        # Signal that no more items will be added to the queue
        self.stop_signal.set()

    def recognize_text(self):
        """Recognize text in each cropped image.

        Worker loop: runs until detection has finished (stop_signal set)
        AND the queue is drained. The 0.5s get() timeout keeps the loop
        re-checking the exit condition instead of blocking forever.
        """
        while not self.stop_signal.is_set() or not self.queue.empty():
            try:
                image_id, boxs, img_crop_list = self.queue.get(timeout=0.5)
                # Recognizer returns one dict per crop with 'text'/'score'.
                rec_results = self.text_recognizer(
                    img_numpy_list=img_crop_list, batch_num=6)
                for rec_result, box in zip(rec_results, boxs):
                    text, score = rec_result['text'], rec_result['score']
                    # Drop low-confidence recognitions.
                    if score >= self.drop_score:
                        with self.lock:
                            # Ensure results dictionary has a list for each image ID
                            if image_id not in self.results:
                                self.results[image_id] = []
                            self.results[image_id].append({
                                'transcription':
                                text,
                                'points':
                                box.tolist(),
                                'score':
                                score
                            })
                self.queue.task_done()
            except queue.Empty:
                # Queue momentarily empty; loop back and re-check stop_signal.
                continue

    def process_images(self, image_list):
        """Process a list of images.

        Returns a dict mapping image index -> list of result dicts
        ({'transcription', 'points', 'score'}).
        """
        # Initialize results dictionary (must happen before workers start).
        self.results = {i: [] for i in range(len(image_list))}

        # Start recognition threads
        t_start_1 = time.time()
        self.start_recognition_threads()

        # Start detection in the main thread
        t_start = time.time()
        self.detect_text(image_list)
        print('det time:', time.time() - t_start)

        # Wait for recognition threads to finish
        for t in self.rec_threads:
            t.join()
        # Reset the event so the instance can be reused for another batch.
        self.stop_signal.clear()
        print('all time:', time.time() - t_start_1)
        return self.results
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def main(cfg_det=None, cfg_rec=None):
    """Run the parallel OCR pipeline over a fixed test directory.

    Args:
        cfg_det: unused; kept (with a default) for signature compatibility —
            the detector config is loaded inside OpenOCRParallel. The default
            also fixes the TypeError from the ``main()`` call in the
            ``__main__`` guard.
        cfg_rec: unused; see ``cfg_det``.
    """
    img_path = './testA/'
    image_file_list = get_image_file_list(img_path)
    drop_score = 0.5
    text_sys = OpenOCRParallel(
        drop_score=drop_score,
        det_box_type='quad')  # det_box_type: 'quad' or 'poly'
    is_visualize = False
    font_path = './simfang.ttf'
    if is_visualize:
        check_and_download_font(font_path)
    # Output directory. NOTE(review): when img_path ends with '/', the
    # suffix is appended without a separator (e.g. './testAe2e_results/');
    # kept as-is to preserve the original output location.
    draw_img_save_dir = img_path + 'e2e_results/' if img_path[
        -1] != '/' else img_path[:-1] + 'e2e_results/'
    # Created unconditionally: system_results.txt is always written here.
    # (Previously this name was defined only under is_visualize, raising a
    # NameError on the default, non-visualizing path.)
    os.makedirs(draw_img_save_dir, exist_ok=True)
    save_results = []

    # Prepare images; keep a parallel list of the files actually loaded so
    # the pipeline's image ids map back to the right filename even when
    # some files fail to load.
    images = []
    valid_image_files = []
    t_start = time.time()
    for image_file in image_file_list:
        img, flag_gif, flag_pdf = check_and_read(image_file)
        if not flag_gif and not flag_pdf:
            img = cv2.imread(image_file)
        if img is not None:
            images.append((img, img.copy()))
            valid_image_files.append(image_file)

    results = text_sys.process_images(images)
    print(f'time cost: {time.time() - t_start}')

    # Save results and visualize
    for image_id, res in results.items():
        image_file = valid_image_files[image_id]
        save_pred = f'{os.path.basename(image_file)}\t{json.dumps(res, ensure_ascii=False)}\n'
        save_results.append(save_pred)

        if is_visualize:
            dt_boxes = [result['points'] for result in res]
            rec_res = [result['transcription'] for result in res]
            rec_score = [result['score'] for result in res]
            image = Image.fromarray(
                cv2.cvtColor(images[image_id][0], cv2.COLOR_BGR2RGB))
            draw_img = draw_ocr_box_txt(image,
                                        dt_boxes,
                                        rec_res,
                                        rec_score,
                                        drop_score=drop_score,
                                        font_path=font_path)

            save_file = os.path.join(draw_img_save_dir,
                                     os.path.basename(image_file))
            cv2.imwrite(save_file, draw_img[:, :, ::-1])

    with open(os.path.join(draw_img_save_dir, 'system_results.txt'),
              'w',
              encoding='utf-8') as f:
        f.writelines(save_results)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
# Script entry point: run the pipeline only when executed directly.
# NOTE(review): main() is defined as main(cfg_det, cfg_rec) above; this
# zero-argument call raises a TypeError unless those parameters get
# defaults — confirm against the definition.
if __name__ == '__main__':
    main()
|