openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,177 @@
1
+ import random
2
+
3
+ import numpy as np
4
+ import torch.distributed as dist
5
+ from torch.utils.data import Sampler
6
+
7
+
8
class MultiScaleSampler(Sampler):
    """Batch sampler that attaches a target image size to every batch.

    Iterating the sampler yields one list per batch; each element of that
    list is a ``(width, height, sample_index, ratio)`` tuple.  All batches of
    an epoch are materialised once in ``__init__`` (through :meth:`iter`);
    ``__iter__`` only reshuffles the order in which they are yielded.
    """

    def __init__(
        self,
        data_source,
        scales,
        first_bs=128,
        fix_bs=True,
        divided_factor=(8, 16),
        is_training=True,
        ratio_wh=0.8,
        max_w=480.0,
        seed=None,
    ):
        """
        multi scale sampler
        Args:
            data_source(dataset): must expose ``data_idx_order_list``,
                ``ds_width``, ``seed`` and ``__len__``; when ``ds_width`` is
                truthy it must also expose ``wh_ratio`` and ``wh_ratio_sort``.
            scales(list): image resolutions, either integers (used for both
                width and height) or ``[width, height]`` pairs.
            first_bs(int): batch size for the first scale in scales
            fix_bs(boolean): if True every scale uses ``first_bs``; otherwise
                the batch size is scaled so that ``w * h * batch_size`` stays
                close to the value of the first scale.
            divided_factor(sequence[w, h]): in training mode every width /
                height is floored to a multiple of the matching factor.
            is_training(boolean): training uses all scales and shuffles;
                otherwise only the first scale is used and nothing is shuffled.
            ratio_wh(float): stored on the instance; not read by this class.
            max_w(float): upper bound for ``ratio * height`` when
                ``data_source.ds_width`` is truthy.
            seed: accepted but not read; the random seed is taken from
                ``data_source.seed``.
        Raises:
            ValueError: if ``first_bs`` is smaller than 1.
            TypeError: if the entries of ``scales`` are neither integers nor
                ``[width, height]`` pairs.
        """
        # A non-positive batch size would never advance the planning loop
        # below, so reject it up front instead of hanging.
        if first_bs < 1:
            raise ValueError('first_bs must be >= 1, got %r' % (first_bs, ))

        self.data_source = data_source
        self.data_idx_order_list = np.array(data_source.data_idx_order_list)
        self.ds_width = data_source.ds_width
        self.seed = data_source.seed
        if self.ds_width:
            self.wh_ratio = data_source.wh_ratio
            self.wh_ratio_sort = data_source.wh_ratio_sort
        self.n_data_samples = len(self.data_source)
        self.ratio_wh = ratio_wh
        self.max_w = max_w

        first_scale = scales[0]
        if isinstance(first_scale, (list, tuple)):
            width_dims = [pair[0] for pair in scales]
            height_dims = [pair[1] for pair in scales]
        elif isinstance(first_scale, (int, np.integer)):
            width_dims = scales
            height_dims = scales
        else:
            raise TypeError(
                'scales must contain integers or [width, height] pairs, '
                'got %r' % (type(first_scale), ))
        base_im_w = width_dims[0]
        base_im_h = height_dims[0]
        base_batch_size = first_bs

        # Get the GPU and node related information
        if dist.is_initialized():
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
        else:
            num_replicas = 1
            rank = 0
        # adjust the total samples to avoid batch dropping
        num_samples_per_replica = int(self.n_data_samples * 1.0 / num_replicas)

        img_indices = list(range(self.n_data_samples))

        self.shuffle = False
        if is_training:
            # Floor every spatial dimension to a multiple of the matching
            # divided_factor entry (index 0 for widths, index 1 for heights).
            width_dims = [
                int((w // divided_factor[0]) * divided_factor[0])
                for w in width_dims
            ]
            height_dims = [
                int((h // divided_factor[1]) * divided_factor[1])
                for h in height_dims
            ]

            img_batch_pairs = []
            # Pixel budget of one batch at the first (unrounded) scale.
            base_elements = base_im_w * base_im_h * base_batch_size
            for h, w in zip(height_dims, width_dims):
                if fix_bs:
                    batch_size = base_batch_size
                else:
                    batch_size = int(max(1, (base_elements / (h * w))))
                img_batch_pairs.append((w, h, batch_size))
            self.img_batch_pairs = img_batch_pairs
            self.shuffle = True
        else:
            self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]

        self.img_indices = img_indices
        self.n_samples_per_replica = num_samples_per_replica
        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas

        # Plan the (width, height, batch_size) triple of every batch of an
        # epoch by cycling through the scales until the per-replica samples
        # are used up.
        self.batch_list = []
        self.current = 0
        last_index = num_samples_per_replica * num_replicas
        indices_rank_i = self.img_indices[self.rank:last_index:self.
                                          num_replicas]
        while self.current < self.n_samples_per_replica:
            for curr_w, curr_h, curr_bsz in self.img_batch_pairs:
                end_index = min(self.current + curr_bsz,
                                self.n_samples_per_replica)
                batch_ids = indices_rank_i[self.current:end_index]
                n_batch_samples = len(batch_ids)
                if n_batch_samples != curr_bsz:
                    # Short batch: top it up from the start of the index list.
                    batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
                self.current += curr_bsz

                if len(batch_ids) > 0:
                    batch = [curr_w, curr_h, len(batch_ids)]
                    self.batch_list.append(batch)
        random.shuffle(self.batch_list)
        self.length = len(self.batch_list)
        self.batchs_in_one_epoch = self.iter()
        self.batchs_in_one_epoch_id = list(
            range(len(self.batchs_in_one_epoch)))

    def __iter__(self):
        """Yield the prepared batches in a freshly shuffled order.

        The shuffle is seeded with ``self.seed`` when it is set; otherwise
        with the current epoch counter, which is then incremented.
        """
        if self.seed is None:
            random.seed(self.epoch)
            self.epoch += 1
        else:
            random.seed(self.seed)
        random.shuffle(self.batchs_in_one_epoch_id)
        for batch_tuple_id in self.batchs_in_one_epoch_id:
            yield self.batchs_in_one_epoch[batch_tuple_id]

    def iter(self):
        """Build the list of batches for one epoch.

        Returns:
            list: one entry per element of ``self.batch_list``; each entry is
            a list of ``(width, height, sample_index, ratio)`` tuples.
            ``ratio`` is ``None`` unless ``self.ds_width`` is truthy.
        """
        if self.shuffle:
            if self.seed is not None:
                random.seed(self.seed)
            else:
                random.seed(self.epoch)
            if not self.ds_width:
                random.shuffle(self.img_indices)
            random.shuffle(self.img_batch_pairs)
        indices_rank_i = self.img_indices[
            self.rank:len(self.img_indices):self.num_replicas]

        start_index = 0
        batchs_in_one_epoch = []
        for curr_w, curr_h, curr_bsz in self.batch_list:
            end_index = min(start_index + curr_bsz, self.n_samples_per_replica)
            batch_ids = indices_rank_i[start_index:end_index]
            n_batch_samples = len(batch_ids)
            if n_batch_samples != curr_bsz:
                # Short batch: top it up from the start of the index list.
                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
            start_index += curr_bsz

            if len(batch_ids) > 0:
                if self.ds_width:
                    # Positions are mapped through wh_ratio_sort before the
                    # ratios are looked up; the batch shares their mean.
                    wh_ratio_current = self.wh_ratio[
                        self.wh_ratio_sort[batch_ids]]
                    ratio_current = wh_ratio_current.mean()
                    # Cap the ratio so that ratio * height stays below max_w.
                    if not ratio_current * curr_h < self.max_w:
                        ratio_current = self.max_w / curr_h
                else:
                    ratio_current = None
                batch = [(curr_w, curr_h, b_id, ratio_current)
                         for b_id in batch_ids]
                batchs_in_one_epoch.append(batch)
        return batchs_in_one_epoch

    def set_epoch(self, epoch: int):
        """Set the epoch counter used as shuffle seed when no seed is set."""
        self.epoch = epoch

    def __len__(self):
        """Return the number of batches planned for one epoch."""
        return self.length
@@ -0,0 +1,217 @@
1
+ import io
2
+ import math
3
+ import random
4
+ import os
5
+ import cv2
6
+ import lmdb
7
+ import numpy as np
8
+ from PIL import Image
9
+ from torch.utils.data import Dataset
10
+
11
+ from openrec.preprocess import create_operators, transform
12
+
13
+
14
class RatioDataSet(Dataset):
    """LMDB-backed dataset whose samples are grouped by width/height ratio.

    ``__getitem__`` is not indexed by an integer: it expects a
    ``(img_width, img_height, sample_index, ratio)`` sequence.
    """

    def __init__(self, config, mode, logger, seed=None, epoch=1):
        """Open the LMDB sources and pre-compute the per-sample ratios.

        Args:
            config (dict): full configuration; reads ``config['Global']``,
                ``config[mode]['dataset']`` and ``config[mode]['loader']``.
            mode (str): key of the section of ``config`` to use.
            logger: object with an ``info`` method.
            seed: accepted but not read; ``self.seed`` is set from ``epoch``.
            epoch (int): stored as ``self.seed``; also selects the
                ``ep<epoch>`` directory when the ``syn`` option is enabled.
        """
        super(RatioDataSet, self).__init__()
        self.ds_width = config[mode]['dataset'].get('ds_width', True)
        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        max_ratio = loader_config.get('max_ratio', 10)
        min_ratio = loader_config.get('min_ratio', 1)
        syn = dataset_config.get('syn', False)
        if syn:
            # Every entry of the per-epoch directory is used as an LMDB source.
            data_dir = '../training_aug_lmdb_noerror/ep' + str(epoch)
            data_dir_list = [
                data_dir + '/' + dir_syn for dir_syn in os.listdir(data_dir)
            ]
        else:
            data_dir_list = dataset_config['data_dir_list']
        self.padding = dataset_config.get('padding', True)
        self.padding_rand = dataset_config.get('padding_rand', False)
        self.padding_doub = dataset_config.get('padding_doub', False)
        self.do_shuffle = loader_config['shuffle']
        self.seed = epoch
        data_source_num = len(data_dir_list)
        ratio_list = dataset_config.get('ratio_list', 1.0)
        if isinstance(ratio_list, (float, int)):
            # A scalar applies the same sampling fraction to every source.
            ratio_list = [float(ratio_list)] * int(data_source_num)
        assert (
            len(ratio_list) == data_source_num
        ), 'The length of ratio_list should be the same as the file_list.'
        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
            data_dir_list, ratio_list)
        for data_dir in data_dir_list:
            logger.info('Initialize indexs of datasets:%s' % data_dir)
        self.logger = logger
        self.data_idx_order_list = self.dataset_traversal()
        # Ratios are rounded to whole numbers and clipped to the configured
        # [min_ratio, max_ratio] range.
        wh_ratio = np.around(np.array(self.get_wh_ratio()))
        self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
        # Log how many samples fall into each ratio value.
        for i in range(max_ratio + 1):
            logger.info((1 * (self.wh_ratio == i)).sum())
        self.wh_ratio_sort = np.argsort(self.wh_ratio)
        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)

        self.need_reset = any(x < 1 for x in ratio_list)
        self.error = 0
        self.base_shape = dataset_config.get(
            'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
        self.base_h = 32

    def get_wh_ratio(self):
        """Return the width/height ratio of every selected sample.

        The size is read from the ``wh-<index>`` LMDB entry (``"<w>_<h>"``)
        when present; otherwise the stored image is opened to get its size.

        Returns:
            list[float]: one ratio per row of ``self.data_idx_order_list``.
        """
        wh_ratio = []
        for idx in range(self.data_idx_order_list.shape[0]):
            lmdb_idx, file_idx = self.data_idx_order_list[idx]
            lmdb_idx = int(lmdb_idx)
            file_idx = int(file_idx)
            txn = self.lmdb_sets[lmdb_idx]['txn']
            wh = txn.get('wh-%09d'.encode() % file_idx)
            if wh is None:
                img = txn.get(f'image-{file_idx:09d}'.encode())
                w, h = Image.open(io.BytesIO(img)).size
            else:
                w, h = wh.decode('utf-8').split('_')
            wh_ratio.append(float(w) / float(h))
        return wh_ratio

    def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
        """Open every LMDB directory read-only.

        Args:
            data_dir_list (list[str]): LMDB directories.
            ratio_list (list[float]): fraction of each source to use.

        Returns:
            dict: position in ``data_dir_list`` -> dict with the keys
            ``dirpath``, ``env``, ``txn``, ``num_samples`` and
            ``ratio_num_samples``.
        """
        lmdb_sets = {}
        for dataset_idx, (dirpath, ratio) in enumerate(
                zip(data_dir_list, ratio_list)):
            env = lmdb.open(dirpath,
                            max_readers=32,
                            readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
            txn = env.begin(write=False)
            num_samples = int(txn.get('num-samples'.encode()))
            lmdb_sets[dataset_idx] = {
                'dirpath': dirpath,
                'env': env,
                'txn': txn,
                'num_samples': num_samples,
                'ratio_num_samples': int(ratio * num_samples)
            }
        return lmdb_sets

    def dataset_traversal(self):
        """Draw the sample keys to use from every LMDB source.

        Returns:
            numpy.ndarray: float array of shape ``(total, 2)``; column 0 is
            the source index, column 1 a 1-based sample key drawn without
            replacement from that source.
        """
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = sum(self.lmdb_sets[lno]['ratio_num_samples']
                               for lno in range(lmdb_num))
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            tmp_sample_num = self.lmdb_sets[lno]['ratio_num_samples']
            end_idx = beg_idx + tmp_sample_num
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            # random.sample raises ValueError when more samples are requested
            # than the source holds.
            data_idx_order_list[beg_idx:end_idx, 1] = list(
                random.sample(range(1, self.lmdb_sets[lno]['num_samples'] + 1),
                              tmp_sample_num))
            beg_idx = end_idx
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode an encoded image buffer with OpenCV.

        Returns:
            The decoded image, or ``None`` when ``value`` is empty or cannot
            be decoded.
        """
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        # cv2.imdecode returns None for undecodable data, which is passed on.
        return cv2.imdecode(imgdata, 1)

    def resize_norm_img(self, data, gen_ratio, padding=True):
        """Resize, normalise and pad ``data['image']`` for a ratio bucket.

        Args:
            data (dict): must contain ``'image'`` as an array indexed
                ``[height, width, channel]``.
            gen_ratio (int): ratio bucket; 1-4 select an entry of
                ``self.base_shape``, larger values give
                ``(self.base_h * gen_ratio, self.base_h)``.
            padding (bool): keep the aspect ratio and zero-pad when True,
                stretch to the full target size when False.  May be flipped
                at random when ``self.padding_rand`` is set.

        Returns:
            dict or None: ``data`` with ``'image'``, ``'valid_ratio'`` and
            ``'real_ratio'`` set, or ``None`` (and ``self.error``
            incremented) when the target is much wider than the image.
        """
        img = data['image']
        h = img.shape[0]
        w = img.shape[1]
        if self.padding_rand and random.random() < 0.5:
            padding = not padding
        if gen_ratio <= 4:
            imgW, imgH = self.base_shape[gen_ratio - 1]
        else:
            imgW, imgH = self.base_h * gen_ratio, self.base_h
        use_ratio = imgW // imgH
        # Reject samples whose own ratio is at least two buckets narrower
        # than the target shape.
        if use_ratio >= (w // h) + 2:
            self.error += 1
            return None
        if not padding:
            resized_image = cv2.resize(img, (imgW, imgH),
                                       interpolation=cv2.INTER_LINEAR)
            resized_w = imgW
        else:
            ratio = w / float(h)
            if math.ceil(imgH * ratio) > imgW:
                resized_w = imgW
            else:
                # Random horizontal scale in [0.5, 1.5), capped at imgW.
                resized_w = int(
                    math.ceil(imgH * ratio * (random.random() + 0.5)))
                resized_w = min(imgW, resized_w)

            resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        # HWC -> CHW, then (x / 255 - 0.5) / 0.5.
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
        if self.padding_doub and random.random() < 0.5:
            # Right-align the image, padding on the left.
            padding_im[:, :, -resized_w:] = resized_image
        else:
            padding_im[:, :, :resized_w] = resized_image
        valid_ratio = min(1.0, float(resized_w / imgW))
        data['image'] = padding_im
        data['valid_ratio'] = valid_ratio
        data['real_ratio'] = round(w / h)
        return data

    def get_lmdb_sample_info(self, txn, index):
        """Fetch the raw image bytes and label of sample ``index``.

        Returns:
            tuple or None: ``(image_bytes, label_str)``, or ``None`` when no
            ``label-<index>`` entry exists.
        """
        label = txn.get('label-%09d'.encode() % index)
        if label is None:
            return None
        label = label.decode('utf-8')
        imgbuf = txn.get('image-%09d'.encode() % index)
        return imgbuf, label

    def _resample(self, img_width, img_height, ratio):
        """Return a randomly chosen sample whose ratio equals ``ratio``."""
        ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
        ids = random.sample(ratio_ids, 1)
        return self.__getitem__([img_width, img_height, ids[0], ratio])

    def __getitem__(self, properties):
        """Load and transform one sample.

        Args:
            properties (sequence): ``(img_width, img_height, idx, ratio)``.

        Returns:
            The output of the last operator in ``self.ops``.  When loading or
            any transform step yields ``None``, another randomly chosen
            sample with the same ratio is returned instead.
        """
        img_width = properties[0]
        img_height = properties[1]
        idx = properties[2]
        ratio = properties[3]
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[lmdb_idx]['txn'], file_idx)
        if sample_info is None:
            return self._resample(img_width, img_height, ratio)
        img, label = sample_info
        data = {'image': img, 'label': label}
        # All operators but the last run before the resize; the last one
        # runs on the resized and padded result.
        outs = transform(data, self.ops[:-1])
        if outs is not None:
            outs = self.resize_norm_img(outs, ratio, padding=self.padding)
        if outs is None:
            return self._resample(img_width, img_height, ratio)
        outs = transform(outs, self.ops[-1:])
        if outs is None:
            return self._resample(img_width, img_height, ratio)
        return outs

    def __len__(self):
        """Return the number of selected samples."""
        return self.data_idx_order_list.shape[0]
@@ -0,0 +1,273 @@
1
+ import io
2
+ import math
3
+ import random
4
+ import re
5
+ import unicodedata
6
+
7
+ import cv2
8
+ import lmdb
9
+ import numpy as np
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset
12
+
13
+ from openrec.preprocess import create_operators, transform
14
+
15
+
16
class CharsetAdapter:
    """Normalise labels so they only contain characters of a target charset.

    If the charset has no uppercase characters, labels are lowercased first;
    otherwise, if it has no lowercase characters, labels are uppercased.
    Every character that is not in the charset is then removed.
    """

    def __init__(self, target_charset) -> None:
        super().__init__()
        # A charset equal to its own lower()/upper() form contains no
        # characters of the opposite case.
        self.lowercase_only = target_charset.lower() == target_charset
        self.uppercase_only = target_charset.upper() == target_charset
        escaped = re.escape(target_charset)
        # Negated character class: matches anything outside the charset.
        self.unsupported = re.compile(f'[^{escaped}]')

    def __call__(self, label):
        """Return ``label`` case-folded as needed with unsupported chars removed."""
        if self.lowercase_only:
            folded = label.lower()
        elif self.uppercase_only:
            folded = label.upper()
        else:
            folded = label
        return self.unsupported.sub('', folded)
33
+
34
+
35
class RatioDataSetTest(Dataset):
    """LMDB-backed dataset whose samples are indexed by rounded width/height ratio.

    Items are requested with a four-entry ``properties`` sequence rather
    than a plain integer index (see ``__getitem__``).
    """

    def __init__(self, config, mode, logger, seed=None, epoch=1):
        """Open the LMDB sources, filter the samples and build the ratio index.

        Args:
            config: Nested mapping. Reads ``config['Global']``,
                ``config[mode]['dataset']`` and ``config[mode]['loader']``.
            mode: Key selecting which section of ``config`` to use.
            logger: Object with an ``info`` method.
            seed: Accepted but not used by this constructor.
            epoch: Stored as ``self.seed``.
        """
        super(RatioDataSetTest, self).__init__()
        self.ds_width = config[mode]['dataset'].get('ds_width', True)
        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        max_ratio = loader_config.get('max_ratio', 10)
        min_ratio = loader_config.get('min_ratio', 1)
        data_dir_list = dataset_config['data_dir_list']
        self.do_shuffle = loader_config['shuffle']
        self.seed = epoch
        self.max_text_length = global_config['max_text_length']
        data_source_num = len(data_dir_list)

        # A scalar ratio applies to every data source.
        ratio_list = dataset_config.get('ratio_list', 1.0)
        if isinstance(ratio_list, (float, int)):
            ratio_list = [float(ratio_list)] * int(data_source_num)
        assert len(
            ratio_list
        ) == data_source_num, 'The length of ratio_list should be the same as the file_list.'

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
            data_dir_list, ratio_list)
        for data_dir in data_dir_list:
            logger.info('Initialize indexs of datasets:%s' % data_dir)
        self.logger = logger
        data_idx_order_list = self.dataset_traversal()

        # Build the charset used to filter labels in get_wh_ratio().
        character_dict_path = global_config.get('character_dict_path', None)
        use_space_char = global_config.get('use_space_char', False)
        if character_dict_path is None:
            char_test = '0123456789abcdefghijklmnopqrstuvwxyz'
        else:
            with open(character_dict_path, 'rb') as fin:
                char_test = ''.join(
                    line.decode('utf-8').strip('\n').strip('\r\n')
                    for line in fin.readlines())
            if use_space_char:
                char_test += ' '

        wh_ratio, data_idx_order_list = self.get_wh_ratio(
            data_idx_order_list, char_test)
        self.data_idx_order_list = np.array(data_idx_order_list)
        # Ratios are rounded to whole numbers and clamped to the loader range.
        wh_ratio = np.around(np.array(wh_ratio))
        self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
        # Log how many samples fall on each integer ratio value.
        for i in range(max_ratio + 1):
            logger.info((1 * (self.wh_ratio == i)).sum())
        self.wh_ratio_sort = np.argsort(self.wh_ratio)
        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)

        self.need_reset = any(x < 1 for x in ratio_list)
        self.error = 0
        # [width, height] targets used by resize_norm_img for ratios 1..4.
        self.base_shape = dataset_config.get(
            'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
        self.base_h = 32

    def get_wh_ratio(self, data_idx_order_list, char_test):
        """Compute width/height ratios and drop samples with unusable labels.

        A sample is dropped when it has no label entry, when its
        whitespace-stripped, ASCII-normalised label is longer than
        ``self.max_text_length``, or when nothing is left after removing
        characters outside ``char_test``.

        Args:
            data_idx_order_list: Rows of ``(lmdb_idx, file_idx)`` pairs.
            char_test: String of characters accepted in labels.

        Returns:
            ``(wh_ratio, kept)``: the float ratio of every kept sample and the
            matching ``[lmdb_idx, file_idx]`` pairs, in input order.
        """
        wh_ratio = []
        # wh_ratio_len[r][n] counts kept samples with int(w / h) == r
        # (capped at 10) and label length n (capped at 25); only logged.
        wh_ratio_len = [[0] * 26 for _ in range(11)]
        kept = []
        charset_adapter = CharsetAdapter(char_test)

        for lmdb_idx, file_idx in data_idx_order_list:
            lmdb_idx = int(lmdb_idx)
            file_idx = int(file_idx)
            txn = self.lmdb_sets[lmdb_idx]['txn']

            # Use the stored "<w>_<h>" entry when there is one; otherwise
            # read the size from the image itself.
            wh = txn.get(b'wh-%09d' % file_idx)
            if wh is None:
                img = txn.get(b'image-%09d' % file_idx)
                w, h = Image.open(io.BytesIO(img)).size
            else:
                w, h = wh.decode('utf-8').split('_')

            label = txn.get(b'label-%09d' % file_idx)
            if label is None:
                continue
            # Remove all whitespace.
            label = ''.join(label.decode('utf-8').split())
            # Normalize unicode composites (if any) and convert to
            # compatible ASCII characters.
            label = unicodedata.normalize('NFKD',
                                          label).encode('ascii',
                                                        'ignore').decode()
            # Filter by length before removing unsupported characters. The
            # original label might be too long.
            if len(label) > self.max_text_length:
                continue
            label = charset_adapter(label)
            if not label:
                continue

            ratio = float(w) / float(h)
            wh_ratio.append(ratio)
            wh_ratio_len[min(int(ratio), 10)][min(len(label), 25)] += 1
            kept.append([lmdb_idx, file_idx])
        self.logger.info(wh_ratio_len)
        return wh_ratio, kept

    def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
        """Open every LMDB directory read-only and collect its metadata.

        Args:
            data_dir_list: LMDB directory paths.
            ratio_list: One sampling ratio per directory.

        Returns:
            Dict keyed by position in ``data_dir_list``; each value holds
            ``dirpath``, ``env``, ``txn``, ``num_samples`` and
            ``ratio_num_samples`` (``int(ratio * num_samples)``).
        """
        lmdb_sets = {}
        for dataset_idx, (dirpath, ratio) in enumerate(
                zip(data_dir_list, ratio_list)):
            env = lmdb.open(dirpath,
                            max_readers=32,
                            readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
            txn = env.begin(write=False)
            num_samples = int(txn.get('num-samples'.encode()))
            lmdb_sets[dataset_idx] = {
                'dirpath': dirpath,
                'env': env,
                'txn': txn,
                'num_samples': num_samples,
                'ratio_num_samples': int(ratio * num_samples),
            }
        return lmdb_sets

    def dataset_traversal(self):
        """Build the ``(lmdb_idx, file_idx)`` table for all opened sources.

        For each source, ``ratio_num_samples`` distinct file indices are drawn
        at random from ``1..num_samples``.

        Returns:
            Float array of shape ``(total, 2)``; column 0 is the source index,
            column 1 the file index.
        """
        sources = [self.lmdb_sets[lno] for lno in range(len(self.lmdb_sets))]
        total_sample_num = sum(src['ratio_num_samples'] for src in sources)
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno, src in enumerate(sources):
            end_idx = beg_idx + src['ratio_num_samples']
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            data_idx_order_list[beg_idx:end_idx, 1] = random.sample(
                range(1, src['num_samples'] + 1), src['ratio_num_samples'])
            beg_idx = end_idx
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode an encoded image buffer with OpenCV.

        Returns:
            The decoded image, or ``None`` when ``value`` is empty/None or
            ``cv2.imdecode`` cannot decode it.
        """
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        # cv2.imdecode already yields None for undecodable buffers.
        return cv2.imdecode(imgdata, 1)

    def resize_norm_img(self, data, gen_ratio, padding=True):
        """Resize ``data['image']`` to the target shape for ``gen_ratio`` and normalise it.

        The target ``[width, height]`` is ``self.base_shape[gen_ratio - 1]``
        for ``gen_ratio <= 4`` and ``[self.base_h * gen_ratio, self.base_h]``
        otherwise. Pixel values are mapped to ``(x / 255 - 0.5) / 0.5`` and
        written into a zero-filled ``(3, height, width)`` float32 array.

        Args:
            data: Dict with an ``'image'`` array; the transpose and the
                3-channel destination imply an H x W x 3 layout.
            gen_ratio: Integer ratio used to pick the target shape.
            padding: When false the image is stretched to the full target
                width. When true the aspect ratio is kept up to a random
                factor in ``[0.5, 1.5)`` and the rest stays zero.

        Returns:
            ``data`` updated in place with ``image``, ``valid_ratio``,
            ``gen_ratio`` and ``real_ratio``; or ``None`` (after incrementing
            ``self.error``) when the target ratio exceeds the image's integer
            ratio by two or more.
        """
        img = data['image']
        h, w = img.shape[0], img.shape[1]

        if gen_ratio <= 4:
            imgW, imgH = self.base_shape[gen_ratio - 1]
        else:
            imgW, imgH = self.base_h * gen_ratio, self.base_h

        use_ratio = imgW // imgH
        if use_ratio >= (w // h) + 2:
            self.error += 1
            return None

        if not padding:
            resized_w = imgW
            resized_image = cv2.resize(img, (imgW, imgH),
                                       interpolation=cv2.INTER_LINEAR)
        else:
            ratio = w / float(h)
            if math.ceil(imgH * ratio) > imgW:
                resized_w = imgW
            else:
                resized_w = int(
                    math.ceil(imgH * ratio * (random.random() + 0.5)))
                resized_w = min(imgW, resized_w)
            resized_image = cv2.resize(img, (resized_w, imgH))

        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
        padding_im[:, :, :resized_w] = resized_image

        data['image'] = padding_im
        data['valid_ratio'] = min(1.0, float(resized_w / imgW))
        data['gen_ratio'] = imgW // imgH
        data['real_ratio'] = max(1, round(w / h))
        return data

    def get_lmdb_sample_info(self, txn, index):
        """Look up the label and image entries stored under ``index``.

        Returns:
            ``None`` when there is no label entry; otherwise
            ``(imgbuf, label)`` with the label decoded as UTF-8.
        """
        label = txn.get(b'label-%09d' % index)
        if label is None:
            return None
        label = label.decode('utf-8')
        imgbuf = txn.get(b'image-%09d' % index)
        return imgbuf, label

    def __getitem__(self, properties):
        """Load, transform and resize one sample, retrying on failure.

        Args:
            properties: Indexable with four entries. ``properties[0]`` and
                ``properties[1]`` are only forwarded unchanged to retries,
                ``properties[2]`` is the row of ``self.data_idx_order_list``
                to load, and ``properties[3]`` is the ratio passed to
                ``resize_norm_img`` and matched against ``self.wh_ratio``
                when a replacement sample is needed.

        Returns:
            The output of the last transform op for the first sample that
            passes every stage.
        """
        img_width = properties[0]
        img_height = properties[1]
        idx = properties[2]
        ratio = properties[3]

        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)

        outs = None
        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[lmdb_idx]['txn'], file_idx)
        if sample_info is not None:
            img, label = sample_info
            # All ops but the last run before resizing; the last runs after.
            outs = transform({'image': img, 'label': label}, self.ops[:-1])
            if outs is not None:
                outs = self.resize_norm_img(outs, ratio, padding=False)
            if outs is not None:
                outs = transform(outs, self.ops[-1:])
        if outs is not None:
            return outs

        # Any stage returning None lands here: draw a random replacement from
        # the samples whose ratio equals ``ratio`` and try again.
        ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
        ids = random.sample(ratio_ids, 1)
        return self.__getitem__([img_width, img_height, ids[0], ratio])

    def __len__(self):
        """Return the number of samples kept after filtering."""
        return self.data_idx_order_list.shape[0]