openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,534 @@
1
+ import math
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from torchvision import transforms as T
9
+ from torchvision.transforms import functional as F
10
+
11
+
12
class CDistNetResize(object):
    """Stretch an image to a fixed size with Lanczos resampling (CDistNet).

    The aspect ratio is not preserved: the image goes straight to the
    ``(H, W)`` taken from ``image_shape`` and is then scaled with
    ``x / 128 - 1``.
    """

    def __init__(self, image_shape, **kwargs):
        # Unpacked as (C, H, W) in __call__; only H and W are used.
        self.image_shape = image_shape

    def __call__(self, data):
        """Resize ``data['image']`` and store it back as a CHW float32 array.

        ``data['valid_ratio']`` is set to 1 because no padding is added.
        """
        source = data['image']
        _, target_h, target_w = self.image_shape
        # Round-trip through PIL to get Lanczos resampling.
        pil_image = Image.fromarray(np.uint8(source))
        resized = np.array(pil_image.resize((target_w, target_h), Image.LANCZOS))
        # HWC -> CHW, then map the uint8 range [0, 255] to [-1, 0.992].
        chw = resized.transpose((2, 0, 1))
        data['image'] = chw.astype(np.float32) / 128.0 - 1.0
        data['valid_ratio'] = 1
        return data
30
+
31
+
32
class ABINetResize(object):
    """Fixed-size resize with ImageNet-style normalization (ABINet).

    Delegates to ``resize_norm_img_abinet`` and additionally records the
    source image's rounded aspect ratio.
    """

    def __init__(self, image_shape, **kwargs):
        # (C, H, W) target shape forwarded to resize_norm_img_abinet.
        self.image_shape = image_shape

    def __call__(self, data):
        """Normalize ``data['image']`` and set valid/real aspect ratios."""
        source = data['image']
        src_h, src_w = source.shape[:2]
        normalized, valid = resize_norm_img_abinet(source, self.image_shape)
        data['image'] = normalized
        data['valid_ratio'] = valid
        # Source width/height ratio rounded to the nearest integer, at least 1.
        data['real_ratio'] = max(1, round(float(src_w) / float(src_h)))
        return data
46
+
47
+
48
def resize_norm_img_abinet(img, image_shape):
    """Stretch ``img`` to ``image_shape`` and apply per-channel mean/std.

    Args:
        img: image array; presumably (H, W, 3) with 8-bit values — TODO
            confirm. The mean/std below broadcast over a last axis of size 3.
        image_shape: (C, H, W) target shape; C is not used.

    Returns:
        ``(chw_float32_image, valid_ratio)``. The image is always stretched to
        the full target width, so ``valid_ratio`` is always 1.0.
    """
    _, target_h, target_w = image_shape

    stretched = cv2.resize(img, (target_w, target_h),
                           interpolation=cv2.INTER_LINEAR)
    stretched = stretched.astype('float32') / 255.0

    # Standard ImageNet channel statistics.
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    stretched = (stretched - channel_mean[None, None, ...]) / \
        channel_std[None, None, ...]

    chw = stretched.transpose((2, 0, 1)).astype('float32')
    # Full-width stretch: the whole canvas holds image content.
    valid_ratio = 1.0
    return chw, valid_ratio
66
+
67
+
68
class SVTRResize(object):
    """Resize/normalize via ``resize_norm_img`` and record aspect ratios (SVTR)."""

    def __init__(self, image_shape, padding=True, **kwargs):
        # (C, H, W) target shape forwarded to resize_norm_img.
        self.image_shape = image_shape
        # Forwarded as resize_norm_img's ``padding`` flag.
        self.padding = padding

    def __call__(self, data):
        """Normalize ``data['image']`` and set valid/real aspect ratios."""
        source = data['image']
        src_h, src_w = source.shape[:2]
        normalized, valid = resize_norm_img(source, self.image_shape,
                                            self.padding)
        data['image'] = normalized
        data['valid_ratio'] = valid
        # Source width/height ratio rounded to the nearest integer, at least 1.
        data['real_ratio'] = max(1, round(float(src_w) / float(src_h)))
        return data
84
+
85
+
86
class RecTVResize(object):
    """Torchvision-based bicubic resize with optional right-side zero padding.

    With ``padding`` the aspect ratio is kept (width capped at the target
    width) and the remainder is zero-filled on the right; without it the image
    is stretched to the target size. Values go through ``ToTensor`` followed
    by ``Normalize(0.5, 0.5)``.
    """

    def __init__(self, image_shape=[32, 128], padding=True, **kwargs):
        self.padding = padding
        # (H, W) target shape.
        self.image_shape = image_shape
        self.interpolation = T.InterpolationMode.BICUBIC
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])

    def __call__(self, data):
        """Resize ``data['image']`` and set image / valid / real ratios.

        ``data['image']`` is presumably a PIL image: ``.size`` is read as
        ``(w, h)``.
        """
        source = data['image']
        target_h, target_w = self.image_shape
        src_w, src_h = source.size

        if self.padding:
            aspect = src_w / float(src_h)
            fitted_w = math.ceil(target_h * aspect)
            new_w = target_w if fitted_w > target_w else int(fitted_w)
        else:
            new_w = target_w

        tensor = self.transforms(
            F.resize(source, (target_h, new_w),
                     interpolation=self.interpolation))
        if new_w < target_w:
            # Pad order is [left, top, right, bottom]: right side only.
            tensor = F.pad(tensor, [0, 0, target_w - new_w, 0], fill=0.)

        valid = min(1.0, float(new_w / target_w))
        data['image'] = tensor
        data['valid_ratio'] = valid
        data['real_ratio'] = max(1, round(float(src_w) / float(src_h)))
        return data
122
+
123
+
124
class LongResize(object):
    """Ratio-bucketed resize for long text images.

    Holds configuration only; the work is done by ``resize_norm_img_long``.
    """

    def __init__(self,
                 base_shape=[[64, 64], [96, 48], [112, 40], [128, 32]],
                 max_ratio=12,
                 base_h=32,
                 padding_rand=False,
                 padding_bi=False,
                 padding=True,
                 **kwargs):
        # Each option is forwarded unchanged to resize_norm_img_long.
        self.base_shape = base_shape
        self.max_ratio = max_ratio
        self.base_h = base_h
        self.padding = padding
        self.padding_rand = padding_rand
        self.padding_bi = padding_bi

    def __call__(self, data):
        """Return ``data`` after ``resize_norm_img_long`` has processed it."""
        return resize_norm_img_long(
            data,
            base_shape=self.base_shape,
            max_ratio=self.max_ratio,
            base_h=self.base_h,
            padding=self.padding,
            padding_rand=self.padding_rand,
            padding_bi=self.padding_bi,
        )
152
+
153
+
154
class SliceResize(object):
    """Cut an image into three half-width crops and resize each one.

    The crops are the left half, the right half and a centered half-width
    window. Each is processed by ``resize_norm_img_slice`` and the results are
    stacked along a new leading axis.
    """

    def __init__(self, image_shape, padding=True, max_ratio=12, **kwargs):
        # (C, H, W) shape forwarded to resize_norm_img_slice.
        self.image_shape = image_shape
        # Forwarded as resize_norm_img_slice's ``padding`` flag.
        self.padding = padding
        self.max_ratio = max_ratio

    def __call__(self, data):
        """Replace ``data['image']`` with the stacked crops and set ``valid_ratio``.

        Assumes ``data['image']`` is an (H, W, C) array, given the 3-axis
        slicing below.
        """
        img = data['image']
        w = img.shape[1]
        half_w = w // 2
        quarter_w = half_w // 2
        crops = [
            img[:, :half_w, :],
            img[:, half_w:2 * half_w, :],
            img[:, quarter_w:quarter_w + half_w, :],
        ]
        stacked = []
        for crop in crops:
            # Forward the configured padding mode. It used to be stored but
            # never passed on, so padding=False had no effect.
            norm_img, valid_ratio = resize_norm_img_slice(
                crop,
                self.image_shape,
                max_ratio=self.max_ratio,
                padding=self.padding)
            stacked.append(norm_img[None, :, :, :])
        data['image'] = np.concatenate(stacked, 0)
        # Reports the valid ratio of the last (centered) crop.
        data['valid_ratio'] = valid_ratio
        return data
177
+
178
+
179
class SliceTVResize(object):
    """Resize to an even multiple of ``base_h`` wide, then cut overlapping windows.

    The image is resized (bicubic) to height ``base_h`` and width
    ``base_h * k``. Here ``k`` is ``w // h`` rounded down to an even number,
    with a minimum of 6. Windows of width ``4 * base_h`` are then taken with
    a stride of ``2 * base_h`` and stacked along a new leading axis.
    """

    def __init__(self,
                 image_shape,
                 padding=True,
                 base_shape=[[64, 64], [96, 48], [112, 40], [128, 32]],
                 max_ratio=12,
                 base_h=32,
                 **kwargs):
        # image_shape, padding and max_ratio are stored but not read by
        # __call__; base_shape is accepted and ignored.
        self.image_shape = image_shape
        self.padding = padding
        self.max_ratio = max_ratio
        self.base_h = base_h
        self.interpolation = T.InterpolationMode.BICUBIC
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])

    def __call__(self, data):
        """Replace ``data['image']`` with the stacked windows.

        ``data['image']`` is presumably a PIL image: ``.size`` is read as
        ``(w, h)``.
        """
        source = data['image']
        w, h = source.size
        even_ratio = max(6, (w // h) // 2 * 2)
        unit = self.base_h
        tensor = self.transforms(
            F.resize(source, (unit, unit * even_ratio),
                     interpolation=self.interpolation))
        windows = [
            tensor[None, :, :, 2 * i * unit:(2 * i + 4) * unit]
            for i in range(0, even_ratio // 2 - 1)
        ]
        data['image'] = torch.cat(windows, 0)
        # NOTE(review): this divides an integer ratio by the source pixel
        # width, unlike the other resizers here (resized_w / target_w) —
        # confirm what downstream code expects.
        data['valid_ratio'] = float(even_ratio) / w
        return data
215
+
216
+
217
class RecTVResizeRatio(object):
    """Torchvision resize into a canvas chosen by the image's aspect ratio.

    The rounded width/height ratio, clamped to ``[1, max_ratio]``, picks the
    canvas. Ratios covered by ``base_shape`` use that ``(W, H)`` entry; larger
    ratios use ``(base_h * ratio, base_h)``. Resizing is bicubic, and values
    go through ``ToTensor`` followed by ``Normalize(0.5, 0.5)``.
    """

    def __init__(self,
                 image_shape=[32, 128],
                 padding=True,
                 base_shape=[[64, 64], [96, 48], [112, 40], [128, 32]],
                 max_ratio=12,
                 base_h=32,
                 **kwargs):
        self.padding = padding
        # Kept for config compatibility; the canvas size comes from base_shape.
        self.image_shape = image_shape
        self.max_ratio = max_ratio
        # (W, H) canvas sizes for rounded ratios 1..len(base_shape).
        self.base_shape = base_shape
        self.base_h = base_h
        self.interpolation = T.InterpolationMode.BICUBIC
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])

    def __call__(self, data):
        """Resize ``data['image']`` and set ``'image'`` and ``'valid_ratio'``.

        ``data['image']`` is presumably a PIL image: ``.size`` is read as
        ``(w, h)``.
        """
        img = data['image']
        w, h = img.size
        gen_ratio = round(float(w) / float(h))
        ratio_resize = 1 if gen_ratio == 0 else gen_ratio
        ratio_resize = min(ratio_resize, self.max_ratio)
        # The table length decides which ratios it covers. The cutoff used to
        # be hard-coded to 4, which raised IndexError for shorter tables and
        # silently ignored extra entries in longer ones.
        if ratio_resize <= len(self.base_shape):
            imgW, imgH = self.base_shape[ratio_resize - 1]
        else:
            imgW, imgH = self.base_h * ratio_resize, self.base_h
        if not self.padding:
            resized_w = imgW
        else:
            ratio = w / float(h)
            if math.ceil(imgH * ratio) > imgW:
                resized_w = imgW
            else:
                resized_w = int(math.ceil(imgH * ratio))
        resized_image = F.resize(img, (imgH, resized_w),
                                 interpolation=self.interpolation)
        img = self.transforms(resized_image)
        if resized_w < imgW:
            # Pad order is [left, top, right, bottom]: right side only.
            img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.)
        valid_ratio = min(1.0, float(resized_w / imgW))
        data['image'] = img
        data['valid_ratio'] = valid_ratio
        return data
268
+
269
+
270
class RecDynamicResize(object):
    """Resize to a fixed height keeping aspect ratio, widening the canvas if needed.

    The canvas width is ``int(H * max(w / h, W / H))``, so images wider than
    the configured ratio get a wider canvas instead of being squeezed. Unused
    columns on the right are zero-filled.
    """

    def __init__(self, image_shape=[32, 128], padding=True, **kwargs):
        # Stored only; __call__ always pads.
        self.padding = padding
        # (H, W) reference shape.
        self.image_shape = image_shape
        self.max_ratio = image_shape[1] * 1.0 / image_shape[0]

    def __call__(self, data):
        """Replace ``data['image']`` with a CHW float32 canvas.

        Assumes ``data['image']`` is an (H, W, C) array.
        """
        source = data['image']
        target_h, _ = self.image_shape
        src_h, src_w, channels = source.shape
        aspect = src_w / float(src_h)
        canvas_w = int(target_h * max(aspect, self.max_ratio))

        fitted_w = math.ceil(target_h * aspect)
        new_w = canvas_w if fitted_w > canvas_w else int(fitted_w)

        scaled = cv2.resize(source, (new_w, target_h)).astype('float32')
        # HWC -> CHW, then [0, 1] -> [-1, 1] (assuming 8-bit input).
        scaled = scaled.transpose((2, 0, 1)) / 255
        scaled -= 0.5
        scaled /= 0.5

        canvas = np.zeros((channels, target_h, canvas_w), dtype=np.float32)
        canvas[:, :, 0:new_w] = scaled
        data['image'] = canvas
        return data
297
+
298
+
299
def resize_norm_img_slice(
    img,
    image_shape,
    base_shape=[[64, 64], [96, 48], [112, 40], [128, 32]],
    max_ratio=12,
    base_h=32,
    padding=True,
):
    """Resize one crop into a ratio-bucketed canvas and normalize it.

    Args:
        img: (H, W, C) image array (transposed to CHW below); values assumed
            to be 8-bit.
        image_shape: (C, H, W); only C is used. The output height/width come
            from the ratio bucket.
        base_shape: (W, H) canvas sizes for rounded aspect ratios
            1..len(base_shape).
        max_ratio: upper clamp for the rounded aspect ratio.
        base_h: canvas height for ratios beyond ``base_shape``.
        padding: if True, keep the aspect ratio and zero-pad on the right.
            A random width jitter in [0.5, 1.5) is applied when the image
            fits. If False, stretch to the canvas size.

    Returns:
        ``(float32 (C, H, W) canvas, valid_ratio)``, where ``valid_ratio`` is
        the filled fraction of the canvas width.
    """
    imgC, _, _ = image_shape
    h = img.shape[0]
    w = img.shape[1]
    gen_ratio = round(float(w) / float(h))
    ratio_resize = 1 if gen_ratio == 0 else gen_ratio
    ratio_resize = min(ratio_resize, max_ratio)
    # The table length decides which ratios it covers. The cutoff used to be
    # hard-coded to 4, which raised IndexError for shorter tables and ignored
    # extra entries in longer ones.
    if ratio_resize <= len(base_shape):
        imgW, imgH = base_shape[ratio_resize - 1]
    else:
        imgW, imgH = base_h * ratio_resize, base_h
    if not padding:
        resized_w = imgW
    else:
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            # Random width jitter (x0.5 .. x1.5) as augmentation, clamped to
            # the canvas width.
            resized_w = int(math.ceil(imgH * ratio * (random.random() + 0.5)))
            resized_w = min(imgW, resized_w)
    resized_image = cv2.resize(img, (resized_w, imgH))
    # HWC -> CHW, then [0, 1] -> [-1, 1].
    resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
    padding_im[:, :, :resized_w] = resized_image
    valid_ratio = min(1.0, float(resized_w / imgW))
    return padding_im, valid_ratio
335
+
336
+
337
def resize_norm_img(img,
                    image_shape,
                    padding=True,
                    interpolation=cv2.INTER_LINEAR):
    """Resize to the (H, W) of ``image_shape`` and normalize to [-1, 1].

    Args:
        img: image array indexed as (H, W[, C]). It is presumably
            single-channel (H, W) when ``image_shape[0] == 1``, otherwise
            (H, W, C).
        image_shape: (C, H, W) target shape.
        padding: if True, keep the aspect ratio (width capped at W) and
            zero-pad on the right; if False, stretch to exactly (H, W).
        interpolation: OpenCV interpolation flag used for the resize.

    Returns:
        ``(float32 (C, H, W) array, valid_ratio)``, where ``valid_ratio`` is
        the fraction of the width holding image content.
    """
    imgC, imgH, imgW = image_shape
    h = img.shape[0]
    w = img.shape[1]
    if not padding:
        resized_w = imgW
    else:
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
    # Honor ``interpolation`` on both paths. Previously the padded path called
    # cv2.resize without it, silently falling back to INTER_LINEAR.
    resized_image = cv2.resize(img, (resized_w, imgH),
                               interpolation=interpolation)
    resized_image = resized_image.astype('float32')
    if image_shape[0] == 1:
        # Single channel: scale, then add a leading channel axis.
        resized_image = resized_image / 255
        resized_image = resized_image[np.newaxis, :]
    else:
        # HWC -> CHW.
        resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
    padding_im[:, :, 0:resized_w] = resized_image
    valid_ratio = min(1.0, float(resized_w / imgW))
    return padding_im, valid_ratio
367
+
368
+
369
def resize_norm_img_long(
    data,
    base_shape=[[64, 64], [96, 48], [112, 40], [128, 32]],
    max_ratio=12,
    base_h=32,
    padding=True,
    padding_rand=False,
    padding_bi=False,
):
    """Resize ``data['image']`` into a canvas picked by an aspect-ratio bucket.

    The bucket is ``data['gen_ratio']`` when present and non-zero. Otherwise
    it is the rounded width/height ratio (1 for ratios <= 0.5). Either way it
    is clamped to ``max_ratio``. Buckets covered by ``base_shape`` use that
    ``(W, H)`` entry; larger buckets use ``(base_h * bucket, base_h)``.

    Args:
        data: dict with an (H, W, 3) ``'image'`` array and an optional
            ``'gen_ratio'`` bucket.
        base_shape: (W, H) canvas sizes for buckets 1..len(base_shape).
        max_ratio: upper clamp for the bucket.
        base_h: canvas height for buckets beyond ``base_shape``.
        padding: keep the aspect ratio and zero-pad (True) or stretch (False).
        padding_rand: if True, flip ``padding`` with probability 0.5.
        padding_bi: if True, right-align the image on the canvas with
            probability 0.5.

    Returns:
        ``data`` with these keys set:
            ``'image'``: float32 (3, H, W); [-1, 1] for 8-bit input.
            ``'valid_ratio'``
            ``'gen_ratio'``: canvas W // H.
            ``'real_ratio'``: source w // h.
    """
    img = data['image']
    h = img.shape[0]
    w = img.shape[1]
    gen_ratio = data.get('gen_ratio', 0)
    if gen_ratio == 0:
        ratio = w / float(h)
        gen_ratio = round(ratio) if ratio > 0.5 else 1
    # Clamp the bucket chosen above. This used to read data['gen_ratio'],
    # which raised KeyError when the key was absent and discarded the
    # fallback (yielding bucket 0) when it was 0.
    gen_ratio = min(gen_ratio, max_ratio)
    if padding_rand and random.random() < 0.5:
        padding = not padding
    if gen_ratio <= len(base_shape):
        imgW, imgH = base_shape[gen_ratio - 1]
    else:
        imgW, imgH = base_h * gen_ratio, base_h
    if not padding:
        resized_image = cv2.resize(img, (imgW, imgH),
                                   interpolation=cv2.INTER_LINEAR)
        resized_w = imgW
    else:
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            # Random width jitter (x0.5 .. x1.5) as augmentation, clamped to
            # the canvas width.
            resized_w = int(math.ceil(imgH * ratio * (random.random() + 0.5)))
            resized_w = min(imgW, resized_w)
        resized_image = cv2.resize(img, (resized_w, imgH))
    resized_image = resized_image.astype('float32')
    # HWC -> CHW, then [0, 1] -> [-1, 1].
    resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
    if padding_bi and random.random() < 0.5:
        # Right-align the content so the zero padding ends up on the left.
        padding_im[:, :, -resized_w:] = resized_image
    else:
        padding_im[:, :, :resized_w] = resized_image
    valid_ratio = min(1.0, float(resized_w / imgW))
    data['image'] = padding_im
    data['valid_ratio'] = valid_ratio
    data['gen_ratio'] = imgW // imgH
    data['real_ratio'] = w // h
    return data
421
+
422
+
423
class VisionLANResize(object):
    """Stretch an image to a fixed size and scale pixels into [0, 1].

    The aspect ratio is not preserved, and no mean/std shift is applied.
    ``valid_ratio`` is always 1.0 because the whole canvas holds image
    content.
    """

    def __init__(self, image_shape, **kwargs):
        # image_shape is unpacked as (channels, height, width) in __call__.
        self.image_shape = image_shape

    def __call__(self, data):
        source = data['image']

        channels, height, width = self.image_shape
        scaled = cv2.resize(source, (width, height)).astype('float32')
        if channels != 1:
            # HWC -> CHW for multi-channel input.
            out = scaled.transpose((2, 0, 1)) / 255
        else:
            out = (scaled / 255)[np.newaxis, :]

        data['image'] = out
        data['valid_ratio'] = 1.0
        return data
443
+
444
+
445
class RobustScannerRecResizeImg(object):
    """Pipeline step that applies ``resize_norm_img_sar`` to ``data['image']``.

    Stores the normalized image and the three pieces of shape metadata
    returned by the helper back into ``data``.
    """

    def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
        self.image_shape = image_shape
        self.width_downsample_ratio = width_downsample_ratio

    def __call__(self, data):
        norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
            data['image'], self.image_shape, self.width_downsample_ratio)
        data.update(
            image=norm_img,
            resized_shape=resize_shape,
            pad_shape=pad_shape,
            valid_ratio=valid_ratio,
        )
        return data
460
+
461
+
462
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
    """Resize to a fixed height with a width snapped to the model stride.

    Args:
        img: H x W (x C) image array.
        image_shape: ``(imgC, imgH, imgW_min, imgW_max)``. Either width bound
            may be None. When ``imgW_max`` is None, the canvas is exactly as
            wide as the resized image.
        width_downsample_ratio: the resized width is kept a multiple of
            ``int(1 / width_downsample_ratio)`` (at least 1) before clamping.

    Returns:
        ``(padding_im, resize_shape, pad_shape, valid_ratio)``:
        - ``padding_im``: float32 C x imgH x W canvas. Pixels are in [-1, 1]
          and padding is filled with -1.0.
        - ``resize_shape``: shape of the resized image before padding.
        - ``pad_shape``: shape of the padded canvas.
        - ``valid_ratio``: resized width / ``imgW_max``, capped at 1.0. It is
          1.0 when ``imgW_max`` is None.
    """
    imgC, imgH, imgW_min, imgW_max = image_shape
    h = img.shape[0]
    w = img.shape[1]
    valid_ratio = 1.0
    # make sure new_width is an integral multiple of width_divisor; guard
    # against ratios > 1, where int(1 / r) would be 0 and break the modulo.
    width_divisor = max(1, int(1 / width_downsample_ratio))
    # resize
    ratio = w / float(h)
    resize_w = math.ceil(imgH * ratio)
    if resize_w % width_divisor != 0:
        # Snap to the nearest multiple, but never down to 0. For very narrow
        # crops (e.g. round(2 / 4) == 0) a zero width would fail in cv2.
        resize_w = max(width_divisor,
                       round(resize_w / width_divisor) * width_divisor)
    if imgW_min is not None:
        resize_w = max(imgW_min, resize_w)
    if imgW_max is not None:
        valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
        resize_w = min(imgW_max, resize_w)
    # Canvas width: the configured maximum, or just the image when unbounded.
    pad_w = imgW_max if imgW_max is not None else resize_w
    resized_image = cv2.resize(img, (resize_w, imgH))
    resized_image = resized_image.astype('float32')
    # norm: scale [0, 255] to [-1, 1]
    if image_shape[0] == 1:
        resized_image = resized_image / 255
        resized_image = resized_image[np.newaxis, :]
    else:
        resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    resize_shape = resized_image.shape
    padding_im = np.full((imgC, imgH, pad_w), -1.0, dtype=np.float32)
    padding_im[:, :, 0:resize_w] = resized_image
    pad_shape = padding_im.shape

    return padding_im, resize_shape, pad_shape, valid_ratio
495
+
496
+
497
class SRNRecResizeImg(object):
    """Pipeline step that replaces ``data['image']`` with its SRN-normalized form.

    The conversion itself is done by ``resize_norm_img_srn``.
    """

    def __init__(self, image_shape, **kwargs):
        self.image_shape = image_shape

    def __call__(self, data):
        data['image'] = resize_norm_img_srn(data['image'], self.image_shape)
        return data
508
+
509
+
510
def resize_norm_img_srn(img, image_shape):
    """Resize into an aspect bucket, convert to gray, and left-align on a canvas.

    Images whose width is at most 1x, 2x or 3x their height are resized to
    width ``imgH * 1``, ``imgH * 2`` or ``imgH * 3``. Wider images are
    stretched to ``imgW``. In every case the width is clamped to ``imgW`` so
    it fits the canvas.

    Args:
        img: BGR image array. ``cv2.COLOR_BGR2GRAY`` expects 3/4 channels.
        image_shape: ``(imgC, imgH, imgW)``. ``imgC`` is ignored because the
            output always has 1 channel.

    Returns:
        float32 array of shape ``(1, imgH, imgW)``. Gray values stay in the
        0-255 range (not normalized), and the right-hand padding is 0.
    """
    _, imgH, imgW = image_shape

    img_black = np.zeros((imgH, imgW))
    im_hei = img.shape[0]
    im_wid = img.shape[1]

    for factor in (1, 2, 3):
        if im_wid <= im_hei * factor:
            target_w = imgH * factor
            break
    else:
        target_w = imgW
    # Without this clamp, an image_shape with imgW < factor * imgH made the
    # slice assignment below fail with a broadcast error.
    target_w = min(target_w, imgW)
    img_new = cv2.resize(img, (target_w, imgH))

    img_np = np.asarray(img_new)
    img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
    img_black[:, 0:img_np.shape[1]] = img_np

    # (imgH, imgW) -> (1, imgH, imgW), same as the former reshape of the
    # (imgH, imgW, 1) array.
    return img_black[np.newaxis, :, :].astype(np.float32)
@@ -0,0 +1,125 @@
1
+ import copy
2
+ import random
3
+
4
+ import numpy as np
5
+
6
+ from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
7
+
8
+
9
class SMTRLabelEncode(BaseRecLabelEncode):
    """Convert between text-label and text-index, plus SMTR sub-string targets.

    Besides the BOS/EOS-wrapped label, ``__call__`` builds sub-string windows
    of length ``sub_str_len``:

    - forward windows (``label_subs``), each paired with the token that
      follows it (``label_next``);
    - backward windows (``label_subs_pre``), each paired with the token that
      precedes it (``label_next_pre``).

    Randomly perturbed copies of these windows are added as augmentation.
    """

    BOS = '<s>'
    EOS = '</s>'
    IN_F = '<INF>'  # ignore; left padding for forward windows
    IN_B = '<INB>'  # ignore; right padding for backward windows
    PAD = '<pad>'

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 sub_str_len=5,
                 **kwargs):
        """
        Args:
            max_text_length: passed to ``BaseRecLabelEncode``. Presumably
                stored there as ``self.max_text_len``, which ``__call__``
                reads (TODO confirm in the base class).
            character_dict_path: passed to ``BaseRecLabelEncode``.
            use_space_char: passed to ``BaseRecLabelEncode``.
            sub_str_len: length of every sub-string window.
            **kwargs: accepted and ignored.
        """

        # self.num_character is read below, so it must already be set by
        # add_special_char during the base-class initialisation (presumably
        # the base __init__ calls it while building the dictionary).
        super(SMTRLabelEncode,
              self).__init__(max_text_length, character_dict_path,
                             use_space_char)
        self.substr_len = sub_str_len
        # 1-based window positions, sampled to pick which slot to corrupt.
        self.rang_subs = [i for i in range(1, self.substr_len + 1)]
        # Indices 1 .. num_character - 6 (excludes the last regular
        # character); not referenced elsewhere in this class.
        self.idx_char = [i for i in range(1, self.num_character - 5)]

    def __call__(self, data):
        """Encode ``data['label']`` and add the SMTR target arrays to ``data``.

        Returns None when encoding fails or the label is longer than
        ``self.max_text_len``. Otherwise ``data`` gains ``length``,
        ``length_subs``, ``label_subs``, ``label_next``, ``length_subs_pre``,
        ``label_subs_pre`` and ``label_next_pre``, and ``label`` is replaced
        by the BOS/EOS-wrapped, PAD-padded index array.
        """
        text = data['label']
        text = self.encode(text)
        if text is None:
            return None
        if len(text) > self.max_text_len:
            return None

        data['length'] = np.array(len(text))
        # Pad with substr_len IN_F tokens on the left and IN_B on the right
        # so that every position has a full-length window on both sides.
        text_in = [self.dict[self.IN_F]] * (self.substr_len) + text + [
            self.dict[self.IN_B]
        ] * (self.substr_len)

        sub_string_list_pre = []
        next_label_pre = []
        sub_string_list = []
        next_label = []
        for i in range(self.substr_len, len(text_in) - self.substr_len):

            # Forward: the substr_len tokens before position i -> token i.
            sub_string_list.append(text_in[i - self.substr_len:i])
            next_label.append(text_in[i])

            # Backward: the substr_len tokens starting at -i, mirrored from
            # the end of text_in. An end index of 0 would give an empty
            # slice, so the first window uses an open-ended slice.
            if self.substr_len - i == 0:
                sub_string_list_pre.append(text_in[-i:])
            else:
                sub_string_list_pre.append(text_in[-i:self.substr_len - i])

            # ...and the token just before that backward window.
            next_label_pre.append(text_in[-(i + 1)])

        # Final forward window (last substr_len chars, IN_F-padded) -> EOS.
        sub_string_list.append(
            [self.dict[self.IN_F]] *
            (self.substr_len - len(text[-self.substr_len:])) +
            text[-self.substr_len:])
        next_label.append(self.dict[self.EOS])
        # Final backward window (first substr_len chars, IN_B-padded) -> EOS.
        sub_string_list_pre.append(
            text[:self.substr_len] + [self.dict[self.IN_B]] *
            (self.substr_len - len(text[:self.substr_len])))
        next_label_pre.append(self.dict[self.EOS])

        # Augmentation over all forward windows except the first substr_len.
        # The slices are snapshots, so copies appended here are not revisited.
        for sstr, l in zip(sub_string_list[self.substr_len:],
                           next_label[self.substr_len:]):

            # Two 1-based slots, sampled with replacement (np.random.choice
            # default), so they may coincide.
            id_shu = np.random.choice(self.rang_subs, 2)

            # Add a copy with one slot replaced by a random regular
            # character index, unless an identical window already exists.
            sstr1 = copy.deepcopy(sstr)
            sstr1[id_shu[0] - 1] = random.randint(1, self.num_character - 5)
            if sstr1 not in sub_string_list:
                sub_string_list.append(sstr1)
                next_label.append(l)

            # NOTE: also corrupts the original window in place (it is the same
            # list object stored in sub_string_list).
            sstr[id_shu[1] - 1] = random.randint(1, self.num_character - 5)

        # Same augmentation for the backward windows.
        for sstr, l in zip(sub_string_list_pre[self.substr_len:],
                           next_label_pre[self.substr_len:]):

            id_shu = np.random.choice(self.rang_subs, 2)

            sstr1 = copy.deepcopy(sstr)
            sstr1[id_shu[0] - 1] = random.randint(1, self.num_character - 5)
            if sstr1 not in sub_string_list_pre:
                sub_string_list_pre.append(sstr1)
                next_label_pre.append(l)
            sstr[id_shu[1] - 1] = random.randint(1, self.num_character - 5)

        # Pad windows / labels to a fixed count of max_text_len * 2 + 2.
        data['length_subs'] = np.array(len(sub_string_list))
        sub_string_list = sub_string_list + [
            [self.dict[self.PAD]] * self.substr_len
        ] * ((self.max_text_len * 2) + 2 - len(sub_string_list))
        next_label = next_label + [self.dict[self.PAD]] * (
            (self.max_text_len * 2) + 2 - len(next_label))
        data['label_subs'] = np.array(sub_string_list)
        data['label_next'] = np.array(next_label)

        data['length_subs_pre'] = np.array(len(sub_string_list_pre))
        sub_string_list_pre = sub_string_list_pre + [
            [self.dict[self.PAD]] * self.substr_len
        ] * ((self.max_text_len * 2) + 2 - len(sub_string_list_pre))
        next_label_pre = next_label_pre + [self.dict[self.PAD]] * (
            (self.max_text_len * 2) + 2 - len(next_label_pre))
        data['label_subs_pre'] = np.array(sub_string_list_pre)
        data['label_next_pre'] = np.array(next_label_pre)

        # Full label: BOS + text + EOS, PAD-padded to max_text_len + 2.
        text = [self.dict[self.BOS]] + text + [self.dict[self.EOS]]
        text = text + [self.dict[self.PAD]
                       ] * (self.max_text_len + 2 - len(text))
        data['label'] = np.array(text)
        return data

    def add_special_char(self, dict_character):
        """Return ``[EOS] + dict_character + [BOS, IN_F, IN_B, PAD]``.

        Also records the total vocabulary size in ``self.num_character``.
        Assuming indices follow list positions, EOS is 0, the regular
        characters are ``1 .. num_character - 5``, and the four trailing
        specials take the last four indices. This is the range used by the
        random corruption in ``__call__``.
        """
        dict_character = [self.EOS] + dict_character + [
            self.BOS, self.IN_F, self.IN_B, self.PAD
        ]
        self.num_character = len(dict_character)
        return dict_character
@@ -0,0 +1,37 @@
1
+ import numpy as np
2
+
3
+ from .ce_label_encode import BaseRecLabelEncode
4
+
5
+
6
class SRNLabelEncode(BaseRecLabelEncode):
    """Encode text labels for SRN as fixed-length, EOS-padded index arrays.

    The vocabulary is extended with ``'<BOS>'`` and ``'<EOS>'`` as its last
    two entries; both are reported as ignored tokens.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super().__init__(max_text_length, character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        extended = dict_character + ['<BOS>', '<EOS>']
        # The two specials are the final entries: BOS second-to-last, EOS last.
        self.end_idx = len(extended) - 1
        self.start_idx = self.end_idx - 1
        return extended

    def __call__(self, data):
        encoded = self.encode(data['label'])
        # Reject labels that cannot be encoded or do not fit max_text_len.
        if encoded is None or len(encoded) > self.max_text_len:
            return None
        data['length'] = np.array(len(encoded))
        filler = [self.end_idx] * (self.max_text_len - len(encoded))
        data['label'] = np.array(encoded + filler)
        return data

    def get_ignored_tokens(self):
        return [self.start_idx, self.end_idx]