openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,213 @@
1
+ import io
2
+ import math
3
+ import random
4
+
5
+ import cv2
6
+ import lmdb
7
+ import numpy as np
8
+ from PIL import Image
9
+ from torch.utils.data import Dataset
10
+ from torchvision import transforms as T
11
+ from torchvision.transforms import functional as F
12
+
13
+ from openrec.preprocess import create_operators, transform
14
+
15
+
16
class RatioDataSetTVResize(Dataset):
    """LMDB-backed recognition dataset bucketed by width/height ratio.

    Samples come from one or more LMDB directories. For every selected
    sample the width/height ratio is rounded and clipped to
    ``[min_ratio, max_ratio]``; the result is exposed as ``wh_ratio`` and
    its argsort as ``wh_ratio_sort``.

    ``__getitem__`` does not take a plain integer: it expects a sequence
    ``[img_width, img_height, idx, ratio]`` (presumably produced by a
    ratio-aware batch sampler -- the sampler is not part of this file).
    Images are resized with torchvision, converted to tensors and
    normalised with mean 0.5 / std 0.5.
    """

    def __init__(self, config, mode, logger, seed=None, epoch=1):
        """Open the LMDB sources and build the per-sample ratio index.

        Args:
            config: Nested mapping with a ``'Global'`` section and, under
                ``config[mode]``, ``'dataset'`` and ``'loader'`` sections.
            mode: Key selecting the section of ``config`` to use.
            logger: Object with an ``info`` method; also kept as
                ``self.logger``.
            seed: Accepted for interface compatibility; not used here.
            epoch: Stored as ``self.seed``.
        """
        super(RatioDataSetTVResize, self).__init__()
        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']

        self.ds_width = dataset_config.get('ds_width', True)
        max_ratio = loader_config.get('max_ratio', 10)
        min_ratio = loader_config.get('min_ratio', 1)
        data_dir_list = dataset_config['data_dir_list']
        self.padding = dataset_config.get('padding', True)
        self.padding_rand = dataset_config.get('padding_rand', False)
        self.padding_doub = dataset_config.get('padding_doub', False)
        self.do_shuffle = loader_config['shuffle']
        self.seed = epoch

        # A scalar ratio applies to every data source.
        data_source_num = len(data_dir_list)
        ratio_list = dataset_config.get('ratio_list', 1.0)
        if isinstance(ratio_list, (float, int)):
            ratio_list = [float(ratio_list)] * int(data_source_num)
        assert (
            len(ratio_list) == data_source_num
        ), 'The length of ratio_list should be the same as the file_list.'

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
            data_dir_list, ratio_list)
        for data_dir in data_dir_list:
            logger.info('Initialize indexs of datasets:%s' % data_dir)
        self.logger = logger

        self.data_idx_order_list = self.dataset_traversal()
        wh_ratio = np.around(np.array(self.get_wh_ratio()))
        self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
        # Log how many samples fall into each integer ratio bucket.
        for i in range(max_ratio + 1):
            logger.info((1 * (self.wh_ratio == i)).sum())
        self.wh_ratio_sort = np.argsort(self.wh_ratio)
        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)

        # True when at least one source is sub-sampled (ratio below 1).
        self.need_reset = any(x < 1 for x in ratio_list)
        # Counts samples rejected by resize_norm_img.
        self.error = 0
        # (width, height) targets for ratio buckets 1..4; larger buckets
        # use (base_h * ratio, base_h).
        self.base_shape = dataset_config.get(
            'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
        self.base_h = dataset_config.get('base_h', 32)
        self.interpolation = T.InterpolationMode.BICUBIC
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])

    def get_wh_ratio(self):
        """Return the raw width/height ratio of every indexed sample.

        The size is read from the ``wh-<index>`` record (``"<w>_<h>"``)
        when present; otherwise the stored image is opened with PIL to
        obtain its size.

        Returns:
            list[float]: one ratio per row of ``data_idx_order_list``.
        """
        wh_ratio = []
        for lmdb_idx, file_idx in self.data_idx_order_list:
            lmdb_idx = int(lmdb_idx)
            file_idx = int(file_idx)
            txn = self.lmdb_sets[lmdb_idx]['txn']
            wh = txn.get(b'wh-%09d' % file_idx)
            if wh is None:
                img = txn.get(b'image-%09d' % file_idx)
                w, h = Image.open(io.BytesIO(img)).size
            else:
                w, h = wh.decode('utf-8').split('_')
            wh_ratio.append(float(w) / float(h))
        return wh_ratio

    def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
        """Open each LMDB directory read-only and record its sample counts.

        Args:
            data_dir_list: LMDB directory paths.
            ratio_list: Fraction of each source to use, aligned with
                ``data_dir_list``.

        Returns:
            dict: ``{source_index: {'dirpath', 'env', 'txn', 'num_samples',
            'ratio_num_samples'}}``.
        """
        lmdb_sets = {}
        for dataset_idx, (dirpath, ratio) in enumerate(
                zip(data_dir_list, ratio_list)):
            env = lmdb.open(dirpath,
                            max_readers=32,
                            readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
            txn = env.begin(write=False)
            num_samples = int(txn.get('num-samples'.encode()))
            lmdb_sets[dataset_idx] = {
                'dirpath': dirpath,
                'env': env,
                'txn': txn,
                'num_samples': num_samples,
                'ratio_num_samples': int(ratio * num_samples)
            }
        return lmdb_sets

    def dataset_traversal(self):
        """Pick the sample indices to use from every LMDB source.

        For each source, ``ratio_num_samples`` distinct 1-based record
        indices are drawn without replacement from
        ``1..num_samples``.

        Returns:
            numpy.ndarray: float array of shape ``(total, 2)`` whose
            columns are ``(source_index, record_index)``.
        """
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = sum(self.lmdb_sets[lno]['ratio_num_samples']
                               for lno in range(lmdb_num))
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            tmp_sample_num = self.lmdb_sets[lno]['ratio_num_samples']
            end_idx = beg_idx + tmp_sample_num
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            data_idx_order_list[beg_idx:end_idx, 1] = random.sample(
                range(1, self.lmdb_sets[lno]['num_samples'] + 1),
                tmp_sample_num)
            beg_idx = end_idx
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode an encoded image buffer with OpenCV.

        Returns:
            The decoded 3-channel image, or ``None`` when ``value`` is
            empty or cannot be decoded.
        """
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        # cv2.imdecode itself yields None for undecodable data.
        return cv2.imdecode(imgdata, 1)

    def resize_norm_img(self, data, gen_ratio, padding=True):
        """Resize, normalise and optionally pad ``data['image']``.

        Args:
            data: Mapping whose ``'image'`` entry exposes ``.size`` as
                ``(width, height)`` and is accepted by torchvision's
                ``resize``.
            gen_ratio: Integer ratio bucket. Buckets up to 4 take their
                target size from ``self.base_shape``; larger ones use
                ``(base_h * gen_ratio, base_h)``.
            padding: Keep the aspect ratio and pad to the target width
                when True; stretch to the target width when False. May
                be flipped at random when ``self.padding_rand`` is set.

        Returns:
            ``data`` with ``'image'`` replaced by the tensor and
            ``'valid_ratio'`` added, or ``None`` (and ``self.error``
            incremented) when the image is much narrower than its bucket.
        """
        img = data['image']
        w, h = img.size
        if self.padding_rand and random.random() < 0.5:
            padding = not padding
        if gen_ratio <= 4:
            imgW, imgH = self.base_shape[gen_ratio - 1]
        else:
            imgW, imgH = self.base_h * gen_ratio, self.base_h

        # Reject samples whose own ratio is at least two buckets smaller
        # than the target bucket.
        use_ratio = imgW // imgH
        if use_ratio >= (w // h) + 2:
            self.error += 1
            return None

        if not padding:
            resized_w = imgW
        else:
            ratio = w / float(h)
            if math.ceil(imgH * ratio) > imgW:
                resized_w = imgW
            else:
                # Random horizontal scale in [0.5, 1.5), capped at imgW.
                resized_w = int(
                    math.ceil(imgH * ratio * (random.random() + 0.5)))
                resized_w = min(imgW, resized_w)

        resized_image = F.resize(img, (imgH, resized_w),
                                 interpolation=self.interpolation)
        img = self.transforms(resized_image)
        if resized_w < imgW and padding:
            # torchvision padding order is [left, top, right, bottom].
            # NOTE(review): the default branch pads on the LEFT; only the
            # padding_doub branch pads on the right. Confirm this is the
            # intended layout for consumers of valid_ratio.
            if self.padding_doub and random.random() < 0.5:
                img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.)
            else:
                img = F.pad(img, [imgW - resized_w, 0, 0, 0], fill=0.)
        valid_ratio = min(1.0, float(resized_w / imgW))
        data['image'] = img
        data['valid_ratio'] = valid_ratio
        return data

    def get_lmdb_sample_info(self, txn, index):
        """Read the label and raw image bytes of record ``index``.

        Returns:
            ``(image_bytes, label_str)``, or ``None`` when the record has
            no label. ``image_bytes`` is whatever ``txn.get`` returns for
            the image key.
        """
        label = txn.get(b'label-%09d' % index)
        if label is None:
            return None
        label = label.decode('utf-8')
        imgbuf = txn.get(b'image-%09d' % index)
        return imgbuf, label

    def _resample_same_ratio(self, img_width, img_height, ratio):
        """Return a randomly chosen sample from the same ratio bucket."""
        ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
        ids = random.sample(ratio_ids, 1)
        return self.__getitem__([img_width, img_height, ids[0], ratio])

    def __getitem__(self, properties):
        """Load and preprocess one sample.

        Args:
            properties: Sequence ``[img_width, img_height, idx, ratio]``.
                ``idx`` indexes ``data_idx_order_list``; ``ratio`` is the
                bucket used for resizing.

        Returns:
            The output of the last configured operator. When the sample
            has no label, or any preprocessing stage returns ``None``,
            another random sample of the same bucket is returned instead.
        """
        img_width, img_height, idx, ratio = (properties[0], properties[1],
                                             properties[2], properties[3])
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[int(lmdb_idx)]['txn'], int(file_idx))
        if sample_info is None:
            return self._resample_same_ratio(img_width, img_height, ratio)

        img, label = sample_info
        data = {'image': img, 'label': label}
        # All operators but the last run before resizing; the last one
        # runs on the resized tensor.
        outs = transform(data, self.ops[:-1])
        if outs is not None:
            outs = self.resize_norm_img(outs, ratio, padding=self.padding)
        if outs is not None:
            outs = transform(outs, self.ops[-1:])
        if outs is None:
            return self._resample_same_ratio(img_width, img_height, ratio)
        return outs

    def __len__(self):
        """Return the number of indexed samples."""
        return self.data_idx_order_list.shape[0]
@@ -0,0 +1,276 @@
1
+ import io
2
+ import math
3
+ import random
4
+ import re
5
+ import unicodedata
6
+
7
+ import cv2
8
+ import lmdb
9
+ import numpy as np
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset
12
+ from torchvision import transforms as T
13
+ from torchvision.transforms import functional as F
14
+
15
+ from openrec.preprocess import create_operators, transform
16
+
17
+
18
class CharsetAdapter:
    """Rewrite labels so they only use characters of a target charset.

    If the charset contains no uppercase letters the label is lowercased;
    otherwise, if it contains no lowercase letters, the label is
    uppercased. Every character that is not in the charset is then
    dropped.
    """

    def __init__(self, target_charset) -> None:
        super().__init__()
        lowered = target_charset.lower()
        self.lowercase_only = target_charset == lowered
        uppered = target_charset.upper()
        self.uppercase_only = target_charset == uppered
        # Negated character class: matches anything outside the charset.
        escaped = re.escape(target_charset)
        self.unsupported = re.compile('[^' + escaped + ']')

    def __call__(self, label):
        """Return ``label`` case-folded as needed and stripped of
        unsupported characters."""
        if self.lowercase_only:
            folded = label.lower()
        elif self.uppercase_only:
            folded = label.upper()
        else:
            folded = label
        return re.sub(self.unsupported, '', folded)
35
+
36
+
37
+ class RatioDataSetTVResizeTest(Dataset):
38
+
39
# Build the evaluation dataset: open the LMDB sources, pick the sample
# indices, filter them by label/charset via get_wh_ratio, and prepare the
# ratio buckets and tensor transforms.
# `seed` is accepted but not used in this method; `epoch` is stored as
# self.seed.
+ def __init__(self, config, mode, logger, seed=None, epoch=1):
40
+ super(RatioDataSetTVResizeTest, self).__init__()
41
+ self.ds_width = config[mode]['dataset'].get('ds_width', True)
42
+ global_config = config['Global']
43
+ dataset_config = config[mode]['dataset']
44
+ loader_config = config[mode]['loader']
45
+ max_ratio = loader_config.get('max_ratio', 10)
46
+ min_ratio = loader_config.get('min_ratio', 1)
47
+ data_dir_list = dataset_config['data_dir_list']
48
+ self.do_shuffle = loader_config['shuffle']
49
+ self.seed = epoch
50
+ self.max_text_length = global_config['max_text_length']
51
+ data_source_num = len(data_dir_list)
52
# A scalar ratio_list is expanded to one entry per data source.
+ ratio_list = dataset_config.get('ratio_list', 1.0)
53
+ if isinstance(ratio_list, (float, int)):
54
+ ratio_list = [float(ratio_list)] * int(data_source_num)
55
+ assert len(
56
+ ratio_list
57
+ ) == data_source_num, 'The length of ratio_list should be the same as the file_list.'
58
+ self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
59
+ data_dir_list, ratio_list)
60
+ for data_dir in data_dir_list:
61
+ logger.info('Initialize indexs of datasets:%s' % data_dir)
62
+ self.logger = logger
63
+ data_idx_order_list = self.dataset_traversal()
64
# Build the evaluation charset: digits and lowercase letters by default,
# otherwise the concatenation of every line of the dictionary file.
+ character_dict_path = global_config.get('character_dict_path', None)
65
+ use_space_char = global_config.get('use_space_char', False)
66
+ if character_dict_path is None:
67
+ char_test = '0123456789abcdefghijklmnopqrstuvwxyz'
68
+ else:
69
+ char_test = ''
70
+ with open(character_dict_path, 'rb') as fin:
71
+ lines = fin.readlines()
72
+ for line in lines:
73
+ line = line.decode('utf-8').strip('\n').strip('\r\n')
74
+ char_test += line
75
# NOTE(review): indentation is lost in this view, so it cannot be told
# whether the space is appended only on the dictionary-file branch or
# also for the default charset -- confirm against the original file.
+ if use_space_char:
76
+ char_test += ' '
77
# get_wh_ratio returns the ratios together with the filtered index list,
# so both stay aligned.
+ wh_ratio, data_idx_order_list = self.get_wh_ratio(
78
+ data_idx_order_list, char_test)
79
+ self.data_idx_order_list = np.array(data_idx_order_list)
80
# Ratios are rounded, then clipped to [min_ratio, max_ratio].
+ wh_ratio = np.around(np.array(wh_ratio))
81
+ self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
82
# Log the number of samples in each integer ratio bucket.
+ for i in range(max_ratio + 1):
83
+ logger.info((1 * (self.wh_ratio == i)).sum())
84
+ self.wh_ratio_sort = np.argsort(self.wh_ratio)
85
+ self.ops = create_operators(dataset_config['transforms'],
86
+ global_config)
87
+
88
# True when at least one entry of ratio_list is below 1.
+ self.need_reset = True in [x < 1 for x in ratio_list]
89
+ self.error = 0
90
+ self.base_shape = dataset_config.get(
91
+ 'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
92
+ self.base_h = dataset_config.get('base_h', 32)
93
+ self.interpolation = T.InterpolationMode.BICUBIC
94
# Tensor conversion followed by normalisation with mean 0.5 / std 0.5.
+ transforms = []
95
+ transforms.extend([
96
+ T.ToTensor(),
97
+ T.Normalize(0.5, 0.5),
98
+ ])
99
+ self.transforms = T.Compose(transforms)
100
+
101
def get_wh_ratio(self, data_idx_order_list, char_test):
    """Compute width/height ratios and drop samples with unusable labels.

    For every ``(source_index, record_index)`` row the image size is read
    from the ``wh-<index>`` record (``"<w>_<h>"``) or, when that record
    is absent, from the stored image via PIL. The label is then
    normalised: whitespace removed, NFKD-normalised and reduced to ASCII.
    A sample is skipped when it has no label, when the normalised label
    is longer than ``self.max_text_length``, or when nothing is left
    after ``CharsetAdapter(char_test)`` has been applied.

    A histogram of (truncated ratio capped at 10) x (label length capped
    at 25) over the kept samples is written to ``self.logger``.

    Args:
        data_idx_order_list: 2-D array of ``(source_index, record_index)``
            rows.
        char_test: String of characters allowed in labels.

    Returns:
        tuple: ``(wh_ratio, kept_rows)`` where ``wh_ratio`` is a list of
        float ratios and ``kept_rows`` the matching
        ``[source_index, record_index]`` integer pairs.
    """
    max_ratio_bin = 10
    max_len_bin = 25
    wh_ratio = []
    wh_ratio_len = [[0 for _ in range(max_len_bin + 1)]
                    for _ in range(max_ratio_bin + 1)]
    data_idx_order_list_filter = []
    charset_adapter = CharsetAdapter(char_test)

    for lmdb_idx, file_idx in data_idx_order_list:
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
        txn = self.lmdb_sets[lmdb_idx]['txn']

        wh = txn.get(b'wh-%09d' % file_idx)
        if wh is None:
            img = txn.get(b'image-%09d' % file_idx)
            w, h = Image.open(io.BytesIO(img)).size
        else:
            w, h = wh.decode('utf-8').split('_')

        label = txn.get(b'label-%09d' % file_idx)
        if label is None:
            # Records without a label cannot be evaluated.
            continue
        # Remove all whitespace.
        label = ''.join(label.decode('utf-8').split())
        # Decompose unicode composites and keep only the ASCII part.
        label = unicodedata.normalize('NFKD', label).encode(
            'ascii', 'ignore').decode()
        # The length filter runs before unsupported characters are
        # removed: the original label might be too long.
        if len(label) > self.max_text_length:
            continue
        label = charset_adapter(label)
        if not label:
            continue

        ratio = float(w) / float(h)
        wh_ratio.append(ratio)
        ratio_bin = min(int(ratio), max_ratio_bin)
        length_bin = min(len(label), max_len_bin)
        wh_ratio_len[ratio_bin][length_bin] += 1
        data_idx_order_list_filter.append([lmdb_idx, file_idx])

    self.logger.info(wh_ratio_len)
    return wh_ratio, data_idx_order_list_filter
149
+
150
def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
    """Open every LMDB directory read-only and describe it.

    Args:
        data_dir_list: LMDB directory paths.
        ratio_list: per-directory fraction of samples to use, paired
            positionally with ``data_dir_list``.

    Returns:
        Dict keyed by the 0-based position of the directory. Each value
        holds ``dirpath``, the open ``env`` and read ``txn``, the stored
        ``num_samples`` and ``ratio_num_samples`` (``int(ratio * num_samples)``).
    """
    opened = {}
    pairs = zip(data_dir_list, ratio_list)
    for set_id, (path, keep_ratio) in enumerate(pairs):
        env = lmdb.open(path,
                        max_readers=32,
                        readonly=True,
                        lock=False,
                        readahead=False,
                        meminit=False)
        txn = env.begin(write=False)
        # The sample count is stored in the database under 'num-samples'.
        total = int(txn.get(b'num-samples'))
        opened[set_id] = {
            'dirpath': path,
            'env': env,
            'txn': txn,
            'num_samples': total,
            'ratio_num_samples': int(keep_ratio * total),
        }
    return opened
171
+
172
def dataset_traversal(self):
    """Draw the sample indices to use from every LMDB set.

    For each set, ``ratio_num_samples`` distinct 1-based file indices are
    drawn with ``random.sample`` from ``1..num_samples``.

    Returns:
        Float array of shape ``(total, 2)``; column 0 is the set index and
        column 1 the drawn file index. Rows are grouped by set, in set
        order.
    """
    set_ids = range(len(self.lmdb_sets))
    picks = [self.lmdb_sets[i]['ratio_num_samples'] for i in set_ids]
    order = np.zeros((sum(picks), 2))

    start = 0
    for set_id, n_pick in zip(set_ids, picks):
        stop = start + n_pick
        population = range(1, self.lmdb_sets[set_id]['num_samples'] + 1)
        order[start:stop, 0] = set_id
        order[start:stop, 1] = random.sample(population, n_pick)
        start = stop
    return order
188
+
189
def get_img_data(self, value):
    """Decode an encoded image buffer.

    Args:
        value: encoded image bytes; may be empty or ``None``.

    Returns:
        The decoded image from ``cv2.imdecode`` (flag ``1``), or ``None``
        when ``value`` is empty or cannot be decoded.
    """
    if not value:
        return None
    # np.frombuffer never returns None, so no check is needed on its result.
    encoded = np.frombuffer(value, dtype='uint8')
    # cv2.imdecode already yields None for undecodable data, so its result
    # is returned as is.
    return cv2.imdecode(encoded, 1)
200
+
201
def resize_norm_img(self, data, gen_ratio, padding=True):
    """Resize ``data['image']`` to the shape selected by ``gen_ratio``.

    The target ``(width, height)`` is ``self.base_shape[gen_ratio - 1]``
    for ``gen_ratio <= 4`` and ``(self.base_h * gen_ratio, self.base_h)``
    otherwise. Samples whose own integer ratio is at least 2 below the
    target ratio are rejected and counted in ``self.error``.

    Args:
        data: dict with an ``'image'`` exposing ``.size`` as ``(w, h)``
            (presumably a PIL image - confirm against the caller).
        gen_ratio: ratio bucket of the batch being built.
        padding: when true, keep the aspect ratio (with a random width
            jitter of x0.5..x1.5) and pad up to the target width; when
            false, stretch to the target width.

    Returns:
        ``data`` updated with ``image``, ``valid_ratio``, ``gen_ratio`` and
        ``real_ratio``, or ``None`` when the sample is rejected.
    """
    img = data['image']
    w, h = img.size
    if gen_ratio <= 4:
        imgW, imgH = self.base_shape[gen_ratio - 1]
    else:
        imgW, imgH = self.base_h * gen_ratio, self.base_h

    target_ratio = imgW // imgH
    if target_ratio >= (w // h) + 2:
        # The image is far narrower than this bucket: reject it.
        self.error += 1
        return None

    resized_w = imgW
    if padding:
        scaled_w = imgH * (w / float(h))
        if math.ceil(scaled_w) <= imgW:
            # The image fits: jitter its width, capped at the target width.
            jittered = int(math.ceil(scaled_w * (random.random() + 0.5)))
            resized_w = min(imgW, jittered)

    resized = F.resize(img, (imgH, resized_w),
                       interpolation=self.interpolation)
    tensor = self.transforms(resized)
    if padding and resized_w < imgW:
        tensor = F.pad(tensor, [0, 0, imgW - resized_w, 0], fill=0.)

    data['image'] = tensor
    data['valid_ratio'] = min(1.0, float(resized_w / imgW))
    data['gen_ratio'] = target_ratio
    data['real_ratio'] = max(1, round(float(w) / float(h)))
    return data
233
+
234
def get_lmdb_sample_info(self, txn, index):
    """Fetch the raw image bytes and decoded label of one sample.

    Args:
        txn: object with a ``get(key)`` method returning bytes or ``None``.
        index: sample number, zero-padded to nine digits in the keys.

    Returns:
        ``(image_bytes, label)`` where ``label`` is UTF-8 decoded, or
        ``None`` when the label key is absent. ``image_bytes`` is returned
        as stored and may itself be ``None``.
    """
    raw_label = txn.get(b'label-%09d' % index)
    if raw_label is None:
        return None
    text = raw_label.decode('utf-8')
    image_bytes = txn.get(b'image-%09d' % index)
    return image_bytes, text
243
+
244
def __getitem__(self, properties):
    """Load and transform the sample described by ``properties``.

    Args:
        properties: sequence ``(img_width, img_height, idx, ratio)``, where
            ``idx`` indexes ``self.data_idx_order_list`` and ``ratio`` is
            the ratio bucket of the batch.

    Returns:
        The transformed sample. If the sample is missing or any stage
        rejects it, another sample with the same ratio is drawn at random
        and loaded instead (recursively).
    """
    img_width = properties[0]
    img_height = properties[1]
    idx = properties[2]
    ratio = properties[3]

    set_id, sample_id = self.data_idx_order_list[idx]
    set_id = int(set_id)
    sample_id = int(sample_id)
    sample_info = self.get_lmdb_sample_info(self.lmdb_sets[set_id]['txn'],
                                            sample_id)

    outs = None
    if sample_info is not None:
        img, label = sample_info
        # All operators but the last run before resizing, the last one after.
        outs = transform({'image': img, 'label': label}, self.ops[:-1])
        if outs is not None:
            outs = self.resize_norm_img(outs, ratio, padding=False)
        if outs is not None:
            outs = transform(outs, self.ops[-1:])

    if outs is not None:
        return outs

    # Fallback shared by every failure above: retry with a random sample
    # from the same ratio bucket.
    candidates = np.where(self.wh_ratio == ratio)[0].tolist()
    picked = random.sample(candidates, 1)
    return self.__getitem__([img_width, img_height, picked[0], ratio])
274
+
275
def __len__(self):
    """Return the number of rows in ``self.data_idx_order_list``."""
    n_rows = self.data_idx_order_list.shape[0]
    return n_rows
@@ -0,0 +1,190 @@
1
+ import math
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import Sampler
8
+
9
+
10
class RatioSampler(Sampler):
    """Batch sampler that groups samples by rounded width/height ratio.

    Every yielded batch is a list of ``[width, height, index, ratio]``
    rows that all share the same ratio, so each batch can be resized to a
    single shape.
    """

    def __init__(self,
                 data_source,
                 scales,
                 first_bs=512,
                 fix_bs=True,
                 divided_factor=(8, 16),
                 is_training=True,
                 max_ratio=10,
                 max_bs=1024,
                 seed=None):
        """Multi-scale sampler.

        Args:
            data_source (dataset): must expose ``ds_width``, ``seed``,
                ``wh_ratio``, ``wh_ratio_sort`` and ``__len__``.
            scales (list): image resolutions, either ``[w, h]`` pairs
                (lists or tuples) or plain ints used for both dimensions.
            first_bs (int): batch size for the first scale in ``scales``.
            fix_bs (bool): use ``first_bs`` for every scale instead of
                scaling the batch size with the image area.
            divided_factor (sequence[w, h]): width and height are rounded
                down to multiples of these factors in training mode.
            is_training (bool): training mode shards the indices across
                replicas, shuffles, and pads the last batch of each ratio.
            max_ratio (int): stored on the instance.
            max_bs (int): upper bound on the batch size for ratios >= 5.
            seed: seed for ``numpy.random`` and ``random`` in training
                mode.

        Raises:
            TypeError: if the entries of ``scales`` are neither sequences
                nor ints.
        """
        self.data_source = data_source
        self.ds_width = data_source.ds_width
        self.seed = data_source.seed
        if self.ds_width:
            self.wh_ratio = data_source.wh_ratio
            self.wh_ratio_sort = data_source.wh_ratio_sort
        self.n_data_samples = len(self.data_source)
        self.max_ratio = max_ratio
        self.max_bs = max_bs

        if isinstance(scales[0], (list, tuple)):
            width_dims = [i[0] for i in scales]
            height_dims = [i[1] for i in scales]
        elif isinstance(scales[0], int):
            width_dims = scales
            height_dims = scales
        else:
            raise TypeError(
                'scales must contain [w, h] pairs or ints, got '
                f'{type(scales[0]).__name__}')
        base_im_w = width_dims[0]
        base_im_h = height_dims[0]
        base_batch_size = first_bs
        base_elements = base_im_w * base_im_h * base_batch_size
        self.base_elements = base_elements
        self.base_batch_size = base_batch_size
        self.base_im_h = base_im_h
        self.base_im_w = base_im_w

        # Get the GPU and node related information. device_count() is 0 on
        # CPU-only hosts; treat that as a single replica so the division
        # and the strided slice below stay valid.
        num_replicas = max(1, torch.cuda.device_count())
        rank = (int(os.environ['LOCAL_RANK'])
                if 'LOCAL_RANK' in os.environ else 0)
        # adjust the total samples to avoid batch dropping
        num_samples_per_replica = int(
            math.ceil(self.n_data_samples * 1.0 / num_replicas))

        img_indices = list(range(self.n_data_samples))
        self.shuffle = False
        if is_training:
            # Round the spatial dimensions down to multiples of
            # divided_factor, then pair each with its batch size.
            width_dims = [
                int((w // divided_factor[0]) * divided_factor[0])
                for w in width_dims
            ]
            height_dims = [
                int((h // divided_factor[1]) * divided_factor[1])
                for h in height_dims
            ]

            img_batch_pairs = []
            for (h, w) in zip(height_dims, width_dims):
                if fix_bs:
                    batch_size = base_batch_size
                else:
                    batch_size = int(max(1, (base_elements / (h * w))))
                img_batch_pairs.append((w, h, batch_size))
            self.img_batch_pairs = img_batch_pairs
            self.shuffle = True
            np.random.seed(seed)
            random.seed(seed)
        else:
            self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]

        self.img_indices = img_indices
        self.n_samples_per_replica = num_samples_per_replica
        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas

        self.current = 0
        self.is_training = is_training
        if is_training:
            # Each replica takes every num_replicas-th position.
            indices_rank_i = self.img_indices[
                self.rank:len(self.img_indices):self.num_replicas]
        else:
            indices_rank_i = self.img_indices
        # Positions are taken in ratio-sorted order and mapped back to
        # dataset indices.
        self.indices_rank_i_ori = np.array(self.wh_ratio_sort[indices_rank_i])
        self.indices_rank_i_ratio = self.wh_ratio[self.indices_rank_i_ori]
        indices_rank_i_ratio_unique = np.unique(self.indices_rank_i_ratio)
        self.indices_rank_i_ratio_unique = indices_rank_i_ratio_unique.tolist()
        self.batch_list = self.create_batch()
        self.length = len(self.batch_list)
        self.batchs_in_one_epoch_id = list(range(self.length))

    def _attach_shape(self, ids, ratio):
        """Turn index column(s) ``ids`` into ``[w, h, index, ratio]`` rows.

        ``ids`` must have a trailing axis of length 1; the shape columns
        are cast to the dtype of ``ids``.
        """
        w = np.full_like(ids, ratio * self.base_im_h)
        h = np.full_like(ids, self.base_im_h)
        ra_wh = np.full_like(ids, ratio)
        return np.concatenate([w, h, ids, ra_wh], axis=-1)

    def create_batch(self):
        """Build the list of batches for this replica.

        Returns:
            List of batches; each batch is a list of
            ``[width, height, index, ratio]`` rows sharing one ratio.
        """
        batch_list = []
        for ratio in self.indices_rank_i_ratio_unique:
            ratio_ids = np.where(self.indices_rank_i_ratio == ratio)[0]
            ratio_ids = self.indices_rank_i_ori[ratio_ids]
            if self.shuffle:
                random.shuffle(ratio_ids)
            num_ratio = ratio_ids.shape[0]
            if ratio < 5:
                batch_size_ratio = self.base_batch_size
            else:
                # Wider images get a smaller batch so the number of
                # elements per batch stays near base_elements.
                batch_size_ratio = min(
                    self.max_bs,
                    int(
                        max(1, (self.base_elements /
                                (self.base_im_h * ratio * self.base_im_h)))))
            if num_ratio > batch_size_ratio:
                batch_num_ratio = num_ratio // batch_size_ratio
                print(self.rank, num_ratio, ratio * self.base_im_h,
                      batch_num_ratio, batch_size_ratio)
                n_full = batch_num_ratio * batch_size_ratio
                ratio_ids_full = ratio_ids[:n_full].reshape(
                    batch_num_ratio, batch_size_ratio, 1)
                batch_ratio = self._attach_shape(ratio_ids_full,
                                                 ratio).tolist()

                if n_full < num_ratio:
                    drop = ratio_ids[n_full:]
                    if self.is_training:
                        # Top the last batch up to full size by reusing
                        # ids from the start of this ratio group.
                        drop_full = ratio_ids[:batch_size_ratio -
                                              (num_ratio - n_full)]
                        drop = np.append(drop_full, drop)
                    drop = self._attach_shape(drop.reshape(-1, 1), ratio)
                    batch_ratio.append(drop.tolist())
                # Added outside the remainder branch so the full batches
                # are kept even when num_ratio is an exact multiple of the
                # batch size.
                batch_list += batch_ratio
            else:
                print(self.rank, num_ratio, ratio * self.base_im_h,
                      batch_size_ratio)
                ratio_ids = self._attach_shape(ratio_ids.reshape(-1, 1),
                                               ratio)
                batch_list.append(ratio_ids.tolist())
        return batch_list

    def __iter__(self):
        """Yield batches; in training mode they are rebuilt and reshuffled.

        The ``random`` module is re-seeded with the current epoch first and
        the epoch counter is then incremented.
        """
        if self.shuffle or self.is_training:
            random.seed(self.epoch)
            self.epoch += 1
            self.batch_list = self.create_batch()
            random.shuffle(self.batchs_in_one_epoch_id)
        for batch_tuple_id in self.batchs_in_one_epoch_id:
            yield self.batch_list[batch_tuple_id]

    def set_epoch(self, epoch: int):
        """Set the epoch used to seed the next iteration."""
        self.epoch = epoch

    def __len__(self):
        """Return the number of batches computed at construction time."""
        return self.length