openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,360 @@
1
+ import copy
2
+ import random
3
+
4
+ import numpy as np
5
+
6
+ from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
7
+
8
+
9
class IGTRLabelEncode(BaseRecLabelEncode):
    """Convert a text label into index form plus prompt/question task lists.

    For every sample, ``encode`` builds a fixed number of
    (prompt, question) pairs out of ``[position, char_id, occurrence]``
    triples, and ``__call__`` pads them into fixed-size numpy arrays that are
    stored on the ``data`` dict.

    The character table is laid out by ``add_special_char`` as
    ``['</s>'] + characters + ['<s>', '<pad>']``.

    NOTE(review): ``self.dict``, ``self.lower`` and ``self.max_text_len`` are
    not defined in this class; they are presumably set by
    ``BaseRecLabelEncode.__init__`` -- confirm against the base class.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 k=1,
                 ch=False,
                 prompt_error=False,
                 **kwargs):
        """Set up the encoder and optionally load the auxiliary char files.

        Args:
            max_text_length: Forwarded to the base class.
            character_dict_path: Forwarded to the base class.
            use_space_char: Forwarded to the base class.
            k: Controls how many (prompt, question) pairs ``encode`` emits;
                every branch of ``encode`` yields ``k + 2`` pairs only when
                ``k >= 6``. Smaller values make ``encode`` raise or return
                an inconsistent number of pairs for texts longer than one
                character.
            ch: When true, ``encode`` returns a sub-sampled list of
                ``[char_id, count]`` pairs instead of the dense count list.
            prompt_error: Stored on the instance; not used in this class.
            **kwargs: ``siml_file`` and ``rare_file`` (both optional paths to
                tab-separated text files) are read from here; everything
                else is ignored.
        """
        super(IGTRLabelEncode,
              self).__init__(max_text_length, character_dict_path,
                             use_space_char)
        self.ignore_index = self.dict['<pad>']
        self.k = k
        self.prompt_error = prompt_error
        self.ch = ch
        rare_file = kwargs.get('rare_file', None)
        siml_file = kwargs.get('siml_file', None)

        # siml_char_dict: char_id -> [candidate ids, normalised weights]
        # siml_char_list: per-char-id flag, 1 when the id has an entry above
        siml_char_dict = {}
        siml_char_list = [0 for _ in range(self.num_character)]
        if siml_file is not None:
            # Explicit UTF-8 so parsing does not depend on the platform's
            # locale encoding (the file holds raw characters).
            with open(siml_file, 'r', encoding='utf-8') as f:
                for lin in f:
                    # Line format: char \t cand_1 \t weight_1 \t cand_2 ...
                    lin_s = lin.strip().split('\t')
                    char_siml = lin_s[0]
                    if char_siml not in self.dict:
                        continue
                    siml_list = []
                    siml_prob = []
                    # Stop one field early so a dangling candidate without
                    # a weight is skipped instead of raising IndexError.
                    for i in range(1, len(lin_s) - 1, 2):
                        c = lin_s[i]
                        prob = int(lin_s[i + 1])
                        if c in self.dict and prob >= 1:
                            siml_list.append(self.dict[c])
                            siml_prob.append(prob)
                    siml_prob = np.array(siml_prob,
                                         dtype=np.float32) / sum(siml_prob)
                    siml_char_dict[self.dict[char_siml]] = [
                        siml_list, siml_prob.tolist()
                    ]
                    siml_char_list[self.dict[char_siml]] = 1
        self.siml_char_dict = siml_char_dict
        self.siml_char_list = siml_char_list

        # rare_char_list: per-char-id flag, 1 when the file lists the
        # character with a count below 1000.
        rare_char_list = [0 for _ in range(self.num_character)]
        if rare_file is not None:
            with open(rare_file, 'r', encoding='utf-8') as f:
                for lin in f:
                    # Line format: char \t number_of_appearances
                    lin_s = lin.strip().split('\t')
                    if len(lin_s) < 2:
                        # Blank or truncated line: nothing to record.
                        continue
                    char_rare = lin_s[0]
                    num_appear = int(lin_s[1])
                    if char_rare in self.dict and num_appear < 1000:
                        rare_char_list[self.dict[char_rare]] = 1

        self.rare_char_list = rare_char_list

    def __call__(self, data):
        """Encode ``data['label']`` and attach all padded target arrays.

        Args:
            data: Sample dict; ``data['label']`` is the text to encode.

        Returns:
            The same dict with ``label`` replaced by the padded index array
            and the keys ``length``, ``ques_len_list``, ``ques2_len_list``,
            ``prompt_len_list``, ``prompt_pos_idx_list``,
            ``prompt_char_idx_list``, ``ques_pos_idx_list``,
            ``ques1_answer_list``, ``ques2_char_idx_list``,
            ``ques2_answer_list``, ``ques3_answer`` and
            ``ques4_char_num_list`` added. Returns ``None`` when ``encode``
            rejects the text or the encoded text is longer than
            ``self.max_text_len``.
        """
        text = data['label']

        encoder_result = self.encode(text)
        if encoder_result is None:
            return None

        text, text_char_num, ques_list_s, prompt_list_s = encoder_result

        if len(text) > self.max_text_len:
            return None
        data['length'] = np.array(len(text))

        pad_id = self.dict['<pad>']
        # Position index used for every padding entry.
        pad_pos = self.max_text_len + 2

        # <s> + text + </s>, right-padded to max_text_len + 2 entries.
        text = [self.dict['<s>']] + text + [self.dict['</s>']]
        text = text + [pad_id] * (self.max_text_len + 2 - len(text))
        data['label'] = np.array(text)

        ques_len_list = []
        ques2_len_list = []
        prompt_len_list = []

        prompt_pos_idx_list = []
        prompt_char_idx_list = []
        ques_pos_idx_list = []
        ques1_answer_list = []
        ques2_char_idx_list = []
        ques2_answer_list = []
        ques4_char_num_list = []

        # Fixed row count of the padded ques2 arrays.
        # NOTE(review): one question can contribute up to 4 rows and there
        # can be max_text_len + 1 questions, so in theory this width can be
        # exceeded (the padding count then goes negative and the row is
        # left longer than the others). Kept as in the original.
        ques2_width = self.max_text_len * 4 + 1

        for train_step, (prompt_list, ques_list) in enumerate(
                zip(prompt_list_s, ques_list_s)):

            # Prompt: a leading [0, <s>, 0] entry, the prompt triples, then
            # padding up to max_text_len + 1 rows in total.
            prompt_len = len(prompt_list) + 1
            prompt_len_list.append(prompt_len)
            prompt_list = np.array(
                [[0, self.dict['<s>'], 0]] + prompt_list +
                [[pad_pos, pad_id, 0]] *
                (self.max_text_len - len(prompt_list)))
            prompt_pos_idx_list.append(prompt_list[:, 0])
            prompt_char_idx_list.append(prompt_list[:, 1])

            ques_len = len(ques_list)
            ques_len_list.append(ques_len)

            # Questions padded to max_text_len + 1 rows.
            ques_list = np.array(
                ques_list + [[pad_pos, pad_id, 0]] *
                (self.max_text_len + 1 - ques_len))
            ques_pos_idx_list.append(ques_list[:, 0])
            # Question 1: which character is at this position?
            # Question 2: is the character at this position 'x'?
            # Question 3: how many of each character are in the text?
            ques1_answer_list.append(ques_list[:, 1])

            # Build the yes/no questions from the unpadded (pos, char) pairs.
            ques2_char_idx = ques_list[:ques_len, :2].tolist()
            new_ques2_char_idx = []
            ques2_answer = []
            for q_2, ques2_idx in enumerate(ques2_char_idx):

                # For the 3rd and 4th task the final question is always
                # kept unchanged as a positive example.
                if train_step in (2, 3) and q_2 == ques_len - 1:
                    new_ques2_char_idx.append(ques2_idx)
                    ques2_answer.append(1)
                    continue
                if ques2_idx[1] != pad_id and random.random() > 0.5:
                    # Replace the true character by a random id drawn from
                    # 0 .. num_character - 3, i.e. excluding the last two
                    # table entries ('<s>' and '<pad>').
                    select_idx = random.randint(0, self.num_character - 3)
                    new_ques2_char_idx.append([ques2_idx[0], select_idx])
                    if select_idx == ques2_idx[1]:
                        ques2_answer.append(1)
                    else:
                        ques2_answer.append(0)

                    # Optionally add up to three candidates from the
                    # similar-character table for the same position.
                    if self.siml_char_list[
                            ques2_idx[1]] == 1 and random.random() > 0.5:
                        select_idx_sim_list = random.sample(
                            self.siml_char_dict[ques2_idx[1]][0],
                            min(3, len(self.siml_char_dict[ques2_idx[1]][0])),
                        )
                        for select_idx in select_idx_sim_list:
                            new_ques2_char_idx.append(
                                [ques2_idx[0], select_idx])
                            if select_idx == ques2_idx[1]:
                                ques2_answer.append(1)
                            else:
                                ques2_answer.append(0)
                else:
                    # Keep the true (position, character) pair.
                    new_ques2_char_idx.append(ques2_idx)
                    ques2_answer.append(1)
            ques2_len_list.append(len(new_ques2_char_idx))
            ques2_char_idx_new = np.array(
                new_ques2_char_idx + [[pad_pos, pad_id]] *
                (ques2_width - len(new_ques2_char_idx)))
            ques2_answer = np.array(ques2_answer + [0] *
                                    (ques2_width - len(ques2_answer)))
            ques2_char_idx_list.append(ques2_char_idx_new)
            ques2_answer_list.append(ques2_answer)

            # Third column of each question triple (occurrence counter).
            ques4_char_num_list.append(ques_list[:, 2])

        data['ques_len_list'] = np.array(ques_len_list, dtype=np.int64)
        data['ques2_len_list'] = np.array(ques2_len_list, dtype=np.int64)
        data['prompt_len_list'] = np.array(prompt_len_list, dtype=np.int64)

        data['prompt_pos_idx_list'] = np.array(prompt_pos_idx_list,
                                               dtype=np.int64)
        data['prompt_char_idx_list'] = np.array(prompt_char_idx_list,
                                                dtype=np.int64)
        data['ques_pos_idx_list'] = np.array(ques_pos_idx_list,
                                             dtype=np.int64)
        data['ques1_answer_list'] = np.array(ques1_answer_list,
                                             dtype=np.int64)
        data['ques2_char_idx_list'] = np.array(ques2_char_idx_list,
                                               dtype=np.int64)
        data['ques2_answer_list'] = np.array(ques2_answer_list,
                                             dtype=np.float32)

        # Per-character counts returned by encode().
        data['ques3_answer'] = np.array(text_char_num, dtype=np.int64)
        data['ques4_char_num_list'] = np.array(ques4_char_num_list)

        return data

    def add_special_char(self, dict_character):
        """Add the special tokens and record the resulting table size.

        Args:
            dict_character: List of ordinary characters.

        Returns:
            ``['</s>'] + dict_character + ['<s>', '<pad>']``. The length of
            that list is stored in ``self.num_character``.
        """
        dict_character = ['</s>'] + dict_character + ['<s>'] + ['<pad>']
        self.num_character = len(dict_character)

        return dict_character

    def encode(self, text):
        """Turn a text label into indices and (prompt, question) task lists.

        Characters that are not in ``self.dict`` are dropped silently.

        Args:
            text: The label string of one sample.

        Returns:
            ``None`` when ``text`` is empty or none of its characters is in
            the dictionary. Otherwise a tuple
            ``(text_list, char_num, ques_list, prompt_list)``:

            * ``text_list``: character ids of the kept characters.
            * ``char_num``: with ``self.ch`` false, a list of length
              ``self.num_character - 2`` holding the number of occurrences
              per id (index 0 is preset to 1); with ``self.ch`` true, a
              sorted list of sampled ``[char_id, count]`` pairs.
            * ``ques_list`` / ``prompt_list``: parallel lists of tasks, each
              task being a list of ``[position, char_id, occurrence]``
              triples. ``position`` is 1-based and ``occurrence`` is how
              often that character appeared earlier in the text.
        """
        if not text:
            return None
        if self.lower:
            text = text.lower()
        # One counter per table entry except the trailing '<s>'/'<pad>'.
        char_num = [0 for _ in range(self.num_character - 2)]
        char_num[0] = 1
        text_list = []
        qa_text = []
        pos_i = 0
        rare_char_qa = []
        unrare_char_qa = []
        for char in text:
            if char not in self.dict:
                continue

            char_id = self.dict[char]
            text_list.append(char_id)
            qa_text.append([pos_i + 1, char_id, char_num[char_id]])
            if self.rare_char_list[char_id] == 1:
                rare_char_qa.append([pos_i + 1, char_id, char_num[char_id]])
            else:
                unrare_char_qa.append([pos_i + 1, char_id, char_num[char_id]])
            char_num[char_id] += 1
            pos_i += 1

        if self.ch:
            # Replace the dense count vector by a sampled subset of
            # [char_id, count] pairs: all present characters, absent ones
            # to fill up to 37 entries, then rare-flagged ones up to 40.
            char_num_ch = []
            char_num_ch_none = []
            rare_char_num_ch_none = []
            for i, num in enumerate(char_num):
                if self.rare_char_list[i] == 1:
                    rare_char_num_ch_none.append([i, num])
                if num > 0:
                    char_num_ch.append([i, num])
                else:
                    char_num_ch_none.append([i, 0])
            # Clamp at 0: a label with more distinct characters than the
            # budget would otherwise make random.sample raise ValueError
            # here, before __call__ gets the chance to drop the sample for
            # being too long.
            none_char_index = random.sample(
                char_num_ch_none,
                max(0, min(37 - len(char_num_ch), len(char_num_ch_none))))
            if len(rare_char_num_ch_none) > 0:
                none_rare_char_index = random.sample(
                    rare_char_num_ch_none,
                    max(
                        0,
                        min(40 - len(char_num_ch) - len(none_char_index),
                            len(rare_char_num_ch_none))),
                )
                char_num_ch = (char_num_ch + none_char_index +
                               none_rare_char_index)
            else:
                char_num_ch = char_num_ch + none_char_index
            char_num_ch.sort(key=lambda x: x[0])
            char_num = char_num_ch

        len_ = len(text_list)
        if len_ == 0:
            return None

        eos_id = self.dict['</s>']
        pad_id = self.dict['<pad>']

        # Task 0: empty prompt, ask for the whole text plus </s>.
        # Task 1: whole text as prompt, ask only for </s>.
        ques_list = [
            qa_text + [[pos_i + 1, eos_id, 0]],
            [[pos_i + 1, eos_id, 0]],
        ]
        # NOTE: prompt_list[1] is the qa_text list object itself, so the
        # in-place random.shuffle(qa_text) further down reorders it too.
        prompt_list = [qa_text[len_:], qa_text]
        if len_ == 1:
            ques_list.append([[self.max_text_len + 1, eos_id, 0]])
            prompt_list.append([[self.max_text_len + 2, pad_id, 0]] * 4 +
                               qa_text)
            # Remaining tasks are padding-only questions with empty prompts.
            for _ in range(1, self.k):
                ques_list.append([[self.max_text_len + 2, pad_id, 0]])
                prompt_list.append(qa_text[1:])
        else:
            # Two tasks whose prompt is a window of up to five characters
            # ending at slice_id (left-padded); the question is the rest of
            # the text plus the next character at position
            # max_text_len + 1.
            next_id = random.sample(range(1, len_ + 1), 2)
            for slice_id in next_id:
                b_i = slice_id - 5 if slice_id - 5 > 0 else 0
                if slice_id == len_:
                    ques_list.append([[self.max_text_len + 1, eos_id, 0]])
                else:
                    ques_list.append(
                        qa_text[slice_id:] +
                        [[self.max_text_len + 1, qa_text[slice_id][1], 0]])
                prompt_list.append([[self.max_text_len + 2, pad_id, 0]] *
                                   (5 - slice_id + b_i) +
                                   qa_text[b_i:slice_id])

            # Two prefix/suffix splits (for two-character texts: one split
            # and one padding-only task).
            shuffle_id1 = random.sample(range(1, len_),
                                        2) if len_ > 2 else [1, 0]
            for slice_id in shuffle_id1:
                if slice_id == 0:
                    ques_list.append([[self.max_text_len + 2, pad_id, 0]])
                    prompt_list.append(qa_text[:0])
                else:
                    ques_list.append(qa_text[slice_id:] +
                                     [[pos_i + 1, eos_id, 0]])
                    prompt_list.append(qa_text[:slice_id])

            if len_ > 2:
                # k - 4 further tasks on randomly ordered characters.
                slice_pool = range(1, len_)
                num_shuffle = self.k - 4
                shuffle_id2 = random.sample(
                    slice_pool,
                    num_shuffle if len_ - 1 > num_shuffle else len_ - 1)
                missing = num_shuffle - len(shuffle_id2)
                if missing > 0:
                    if missing <= len(slice_pool):
                        shuffle_id2 += random.sample(slice_pool, missing)
                    else:
                        # Short text with a large k: there are not enough
                        # distinct split points, so draw with replacement
                        # instead of letting random.sample raise.
                        shuffle_id2 += random.choices(slice_pool, k=missing)
                rare_slice_id = len(rare_char_qa)
                unrare_slice_id = len(unrare_char_qa)
                for slice_id in shuffle_id2:
                    random.shuffle(qa_text)
                    if len(rare_char_qa) > 0 and random.random() < 0.5:
                        # Ask for the first rare_slice_id rare characters
                        # and the unrare characters after unrare_slice_id;
                        # prompt with random subsets of the complements.
                        ques_list.append(rare_char_qa[:rare_slice_id] +
                                         unrare_char_qa[unrare_slice_id:] +
                                         [[pos_i + 1, eos_id, 0]])
                        unrare_head = unrare_char_qa[:unrare_slice_id]
                        if len(unrare_head) > 0:
                            prompt_list1 = random.sample(
                                unrare_head,
                                random.randint(1, len(unrare_head))
                                if len(unrare_head) > 1 else 1,
                            )
                        else:
                            prompt_list1 = []
                        rare_tail = rare_char_qa[rare_slice_id:]
                        if len(rare_tail) > 0:
                            prompt_list2 = random.sample(
                                rare_tail,
                                random.randint(
                                    1,
                                    len(rare_tail)
                                    if len(rare_tail) > 1 else 1,
                                ),
                            )
                        else:
                            prompt_list2 = []
                        prompt_list.append(prompt_list1 + prompt_list2)
                        # Re-randomise the split for the next iteration.
                        random.shuffle(rare_char_qa)
                        random.shuffle(unrare_char_qa)
                        rare_slice_id = random.randint(
                            1,
                            len(rare_char_qa)) if len(rare_char_qa) > 1 else 1
                        unrare_slice_id = random.randint(
                            1, len(unrare_char_qa)
                        ) if len(unrare_char_qa) > 1 else 1
                    else:
                        ques_list.append(qa_text[slice_id:] +
                                         [[pos_i + 1, eos_id, 0]])
                        prompt_list.append(qa_text[:slice_id])
            else:
                # Two-character text: both single-character splits, then
                # padding-only tasks to reach the common task count.
                ques_list.append(qa_text[1:] + [[pos_i + 1, eos_id, 0]])
                prompt_list.append(qa_text[:1])
                ques_list.append(qa_text[:1] + [[pos_i + 1, eos_id, 0]])
                prompt_list.append(qa_text[1:])
                ques_list += [[[self.max_text_len + 2, pad_id, 0]]
                              ] * (self.k - 6)
                prompt_list += [qa_text[:0]] * (self.k - 6)

        return text_list, char_num, ques_list, prompt_list
@@ -0,0 +1,95 @@
1
+ '''
2
+ This code is adapted from:
3
+ https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/MGP-STR
4
+ '''
5
+ import numpy as np
6
+
7
+ from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
8
+
9
+
10
class MGPLabelEncode(BaseRecLabelEncode):
    """Turn a text label into padded index sequences for MGP-STR.

    Three encodings are produced: a character-level one from this class's
    dictionary, a BPE one from the GPT-2 tokenizer and a WordPiece one from
    the BERT tokenizer. The last two are skipped when ``only_char`` is set.
    """
    SPACE = '[s]'
    GO = '[GO]'
    list_token = [GO, SPACE]

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 only_char=False,
                 **kwargs):
        """Set up the character dictionary and, optionally, the tokenizers.

        Args:
            max_text_length: Maximum label length, excluding special tokens.
            character_dict_path: Forwarded to the base class.
            use_space_char: Forwarded to the base class.
            only_char: When true, the BPE and WordPiece tokenizers are not
                loaded and only the character encoding is produced.
            **kwargs: Accepted and ignored.
        """
        super().__init__(max_text_length, character_dict_path, use_space_char)
        # character (str): set of the possible characters.
        # [GO] is the start token of the attention decoder; [s] is the
        # end-of-sentence token.

        # Every encoded sequence is padded to the label length plus the
        # special tokens in ``list_token``.
        self.batch_max_length = max_text_length + len(self.list_token)
        self.only_char = only_char
        if only_char:
            return
        # transformers==4.2.1
        from transformers import BertTokenizer, GPT2Tokenizer
        self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.wp_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __call__(self, data):
        """Add the encoded labels to ``data`` in place.

        Reads ``data['label']`` and writes ``'length'`` and ``'char_label'``,
        plus ``'bpe_label'`` and ``'wp_label'`` unless ``only_char`` is set.

        Returns:
            The same ``data`` mapping, or ``None`` when the character or BPE
            encoding rejects the label.
        """
        label = data['label']
        char_ids, char_count = self.encode(label)
        if char_ids is None:
            return None
        data['length'] = np.array(char_count)
        data['char_label'] = np.array(char_ids)
        if not self.only_char:
            bpe_ids = self.bpe_encode(label)
            if bpe_ids is None:
                return None
            wp_ids = self.wp_encode(label)
            data['bpe_label'] = np.array(bpe_ids)
            data['wp_label'] = wp_ids
        return data

    def add_special_char(self, dict_character):
        """Return ``dict_character`` with ``[GO]`` and ``[s]`` prepended."""
        return self.list_token + dict_character

    def encode(self, text):
        """Convert a text label into a padded list of character indices.

        The label is wrapped as ``[GO] + chars + [s]``; characters missing
        from ``self.dict`` are dropped.

        Returns:
            ``(indices, length)`` where ``indices`` is padded with the
            ``[GO]`` index up to ``batch_max_length``, or ``(None, None)``
            when the label is empty or the encoded sequence is too long.
        """
        if len(text) == 0:
            return None, None
        if self.lower:
            text = text.lower()
        # NOTE(review): the length is taken before out-of-dictionary
        # characters are dropped, so it can exceed the number of encoded
        # characters - confirm this is what consumers of 'length' expect.
        length = len(text)
        tokens = [self.GO, *text, self.SPACE]
        indices = [self.dict[tok] for tok in tokens if tok in self.dict]
        if not indices or len(indices) > self.batch_max_length:
            return None, None
        pad_count = self.batch_max_length - len(indices)
        return indices + [self.dict[self.GO]] * pad_count, length

    def bpe_encode(self, text):
        """Convert a text label into a padded list of BPE token ids.

        The tokenizer ids are wrapped as ``[1] + ids + [2]`` and padded with
        the ``[GO]`` index up to ``batch_max_length``.

        Returns:
            The padded id list, or ``None`` when the label is empty or the
            wrapped sequence is longer than ``batch_max_length``.
        """
        if len(text) == 0:
            return None
        ids = [1, *self.bpe_tokenizer(text)['input_ids'], 2]
        if not ids or len(ids) > self.batch_max_length:
            return None
        pad_count = self.batch_max_length - len(ids)
        return ids + [self.dict[self.GO]] * pad_count

    def wp_encode(self, text):
        """Return the WordPiece ids of ``text`` as a NumPy array.

        The tokenizer pads and truncates to ``batch_max_length``.
        """
        encoded = self.wp_tokenizer([text],
                                    padding='max_length',
                                    max_length=self.batch_max_length,
                                    truncation=True,
                                    return_tensors='np')
        return encoded['input_ids'][0]
@@ -0,0 +1,150 @@
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+
18
+ import imgaug.augmenters as iaa
19
+ import numpy as np
20
+ from PIL import Image, ImageFilter
21
+
22
+ from openrec.preprocess import auto_augment
23
+ from openrec.preprocess.auto_augment import _LEVEL_DENOM, LEVEL_TO_ARG, NAME_TO_OP, _randomly_negate, rotate
24
+
25
+
26
def rotate_expand(img, degrees, **kwargs):
    """Rotate ``img`` with ``expand=True`` so the characters are not cut off.

    Any ``expand`` value supplied by the caller is overridden.
    """
    forced = {**kwargs, 'expand': True}
    return rotate(img, degrees, **forced)
31
+
32
+
33
def _signed_level_to_arg(level, hparams, key, default):
    """Scale ``level`` by a magnitude from ``hparams`` and randomly negate it.

    This helper has its own name on purpose: a later definition in this
    module reuses the name ``_level_to_arg`` with a different signature, and
    ``apply()`` must keep binding this implementation no matter when it is
    called.

    Args:
        level: Augmentation level, normalised by ``_LEVEL_DENOM``.
        hparams: Mapping searched for ``key``.
        key: Name of the magnitude entry in ``hparams``.
        default: Magnitude used when ``key`` is absent.

    Returns:
        A one-element tuple holding the scaled, randomly negated level.
    """
    magnitude = hparams.get(key, default)
    scaled = (level / _LEVEL_DENOM) * magnitude
    return (_randomly_negate(scaled),)


def apply():
    """Install this module's overrides into ``NAME_TO_OP`` and ``LEVEL_TO_ARG``.

    ``Rotate`` is replaced by ``rotate_expand``, and the rotate, shear and
    translate level converters are replaced by versions that read their
    magnitude from the hparams (falling back to the defaults below).
    """
    # Overrides
    NAME_TO_OP.update({'Rotate': rotate_expand})
    magnitude_sources = (
        ('Rotate', 'rotate_deg', 30.),
        ('ShearX', 'shear_x_pct', 0.3),
        ('ShearY', 'shear_y_pct', 0.3),
        ('TranslateXRel', 'translate_x_pct', 0.45),
        ('TranslateYRel', 'translate_y_pct', 0.45),
    )
    LEVEL_TO_ARG.update({
        op_name: partial(_signed_level_to_arg, key=key, default=default)
        for op_name, key, default in magnitude_sources
    })
55
+
56
+
57
+ apply()
58
+
59
# Ops that have already been built, keyed by a caller-chosen string.
_OP_CACHE = {}


def _get_op(key, factory):
    """Return the op cached under ``key``, building it on the first request.

    Args:
        key: Cache key identifying the op.
        factory: Zero-argument callable invoked only on a cache miss.
    """
    if key not in _OP_CACHE:
        _OP_CACHE[key] = factory()
    return _OP_CACHE[key]
69
+
70
+
71
def _get_param(level, img, max_dim_factor, min_level=1):
    """Clamp ``level`` to a cap derived from the image size and round it.

    The cap is ``max_dim_factor`` times the largest entry of ``img.size``,
    but never less than ``min_level``.
    """
    largest_dim = max(img.size)
    cap = max(min_level, max_dim_factor * largest_dim)
    clamped = min(level, cap)
    return round(clamped)
74
+
75
+
76
def gaussian_blur(img, radius, **__):
    """Apply a Gaussian blur to ``img``.

    The radius is clamped relative to the image size (factor 0.02) and the
    filter object is cached per radius.
    """
    clamped = _get_param(radius, img, 0.02)
    blur = _get_op(f'gaussian_blur_{clamped}',
                   lambda: ImageFilter.GaussianBlur(clamped))
    return img.filter(blur)
81
+
82
+
83
def motion_blur(img, k, **__):
    """Apply an imgaug motion blur to ``img`` and return a new image.

    The kernel size is clamped relative to the image size (factor 0.08,
    minimum 3) and forced to an odd value; the augmenter is cached per size.
    """
    kernel = _get_param(k, img, 0.08, 3) | 1  # set the low bit: odd values only
    augmenter = _get_op(f'motion_blur_{kernel}',
                        lambda: iaa.MotionBlur(kernel))
    blurred = augmenter(image=np.asarray(img))
    return Image.fromarray(blurred)
88
+
89
+
90
def gaussian_noise(img, scale, **_):
    """Add imgaug Gaussian noise to ``img`` and return a new image.

    The scale is clamped relative to the image size (factor 0.25) and forced
    to an odd value; the augmenter is cached per scale.
    """
    noise_scale = _get_param(scale, img, 0.25) | 1  # set the low bit: odd values only
    augmenter = _get_op(f'gaussian_noise_{noise_scale}',
                        lambda: iaa.AdditiveGaussianNoise(scale=noise_scale))
    noisy = augmenter(image=np.asarray(img))
    return Image.fromarray(noisy)
95
+
96
+
97
def poisson_noise(img, lam, **_):
    """Add imgaug Poisson noise to ``img`` and return a new image.

    ``lam`` is clamped relative to the image size (factor 0.2) and forced to
    an odd value; the augmenter is cached per value.
    """
    rate = _get_param(lam, img, 0.2) | 1  # set the low bit: odd values only
    augmenter = _get_op(f'poisson_noise_{rate}',
                        lambda: iaa.AdditivePoissonNoise(rate))
    noisy = augmenter(image=np.asarray(img))
    return Image.fromarray(noisy)
102
+
103
+
104
def _level_to_arg(level, _hparams, max):
    """Map ``level`` linearly onto ``[0, max]`` and return it as a 1-tuple.

    ``_hparams`` is accepted but unused. The parameter name ``max`` is kept
    because it is supplied by keyword.
    """
    scaled = max * level / auto_augment._LEVEL_DENOM
    return (scaled,)
107
+
108
+
109
# Work on a copy so that auto_augment's own transform list is left untouched
# by the edits below.
_RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy()
_RAND_TRANSFORMS.remove(
    'SharpnessIncreasing')  # remove, interferes with *blur ops
# Only GaussianBlur and PoissonNoise are enabled; the other two ops stay
# registered below but are not part of the sampled transform list.
_RAND_TRANSFORMS.extend([
    'GaussianBlur',
    # 'MotionBlur',
    # 'GaussianNoise',
    'PoissonNoise'
])
# Level-to-argument converters for the custom ops; ``max`` is the argument
# value produced at the maximum level.
auto_augment.LEVEL_TO_ARG.update({
    'GaussianBlur':
    partial(_level_to_arg, max=4),
    'MotionBlur':
    partial(_level_to_arg, max=20),
    'GaussianNoise':
    partial(_level_to_arg, max=0.1 * 255),
    'PoissonNoise':
    partial(_level_to_arg, max=40)
})
# Make the custom ops resolvable by name.
auto_augment.NAME_TO_OP.update({
    'GaussianBlur': gaussian_blur,
    'MotionBlur': motion_blur,
    'GaussianNoise': gaussian_noise,
    'PoissonNoise': poisson_noise
})
134
+
135
+
136
def rand_augment_transform(magnitude=5, num_layers=3):
    """Build a ``RandAugment`` transform over this module's transform list.

    Args:
        magnitude: Magnitude passed to ``auto_augment.rand_augment_ops``.
        num_layers: Number of ops applied per call, passed to ``RandAugment``.

    Returns:
        An ``auto_augment.RandAugment`` instance with uniform choice weights.
    """
    # These are tuned for magnitude=5, which means that effective magnitudes
    # are half of these values.
    hparams = dict(
        rotate_deg=30,
        shear_x_pct=0.9,
        shear_y_pct=0.2,
        translate_x_pct=0.10,
        translate_y_pct=0.30,
    )
    ops = auto_augment.rand_augment_ops(magnitude,
                                        hparams=hparams,
                                        transforms=_RAND_TRANSFORMS)
    # Supply weights to disable replacement in random selection (i.e. avoid
    # applying the same op twice).
    op_count = len(ops)
    uniform_weights = [1. / op_count for _ in range(op_count)]
    return auto_augment.RandAugment(ops, num_layers, uniform_weights)