openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,760 @@
1
+ """This code is refer from:
2
+ https://github.com/MelosY/CAM
3
+ """
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+
7
+ # All rights reserved.
8
+
9
+ # This source code is licensed under the license found in the
10
+ # LICENSE file in the root directory of this source tree.
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torch.nn.init import trunc_normal_
16
+
17
+ from .convnextv2 import ConvNeXtV2, Block, LayerNorm
18
+ from .svtrv2_lnconv_two33 import SVTRv2LNConvTwo33
19
+
20
+
21
class Swish(nn.Module):
    """Swish activation: ``sigmoid(x) * x`` (also known as SiLU)."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        # Gate each element of the input by its own sigmoid.
        return torch.sigmoid(x) * x
28
+
29
+
30
class UNetBlock(nn.Module):
    """UNet decoder block: transposed-conv up-sampling followed by two
    3x3 convolutions.

    Args:
        cin: input channels.
        cout: output channels.
        bn2d: 2-D batch-norm class applied after each conv
            (e.g. ``nn.BatchNorm2d`` or ``nn.SyncBatchNorm``).
        stride: ``(stride_h, stride_w)`` up-sampling factors; each component
            must be 1, 2 or 4.
        deformable: unused; kept for interface compatibility with callers
            that pass it positionally.

    Raises:
        ValueError: if a stride component is not 1, 2 or 4.  (The original
            code left the transposed-conv kernel size undefined in that case
            and failed later with ``UnboundLocalError``.)
    """

    # stride -> (kernel, padding) so that ConvTranspose2d up-samples by
    # exactly `stride`: out = (in - 1) * stride - 2 * padding + kernel.
    _GEOMETRY = {1: (1, 0), 2: (4, 1), 4: (4, 0)}

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        super().__init__()
        stride_h, stride_w = stride
        if stride_h not in self._GEOMETRY or stride_w not in self._GEOMETRY:
            raise ValueError(
                f'unsupported stride {stride}; each component must be 1, 2 or 4')
        kernel_h, padding_h = self._GEOMETRY[stride_h]
        kernel_w, padding_w = self._GEOMETRY[stride_w]

        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        """Up-sample then refine; (N, cin, H, W) -> (N, cout, H*sh, W*sw)."""
        x = self.up_sample(x)
        return self.conv(x)
77
+
78
+
79
class DepthWiseUNetBlock(nn.Module):
    """UNet decoder block using depthwise-separable 3x3 convolutions.

    Same interface and topology as ``UNetBlock``, but each 3x3 conv is split
    into a depthwise 3x3 (``groups=cin``) followed by a pointwise 1x1,
    reducing parameter count.

    Args:
        cin: input channels.
        cout: output channels.
        bn2d: 2-D batch-norm class applied after each separable conv pair.
        stride: ``(stride_h, stride_w)`` up-sampling factors; each component
            must be 1, 2 or 4.
        deformable: unused; kept for interface compatibility.

    Raises:
        ValueError: if a stride component is not 1, 2 or 4.  (The original
            code left the transposed-conv kernel size undefined in that case
            and failed later with ``UnboundLocalError``.)
    """

    # stride -> (kernel, padding) so that ConvTranspose2d up-samples by
    # exactly `stride`: out = (in - 1) * stride - 2 * padding + kernel.
    _GEOMETRY = {1: (1, 0), 2: (4, 1), 4: (4, 0)}

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        super().__init__()
        stride_h, stride_w = stride
        if stride_h not in self._GEOMETRY or stride_w not in self._GEOMETRY:
            raise ValueError(
                f'unsupported stride {stride}; each component must be 1, 2 or 4')
        kernel_h, padding_h = self._GEOMETRY[stride_h]
        kernel_w, padding_w = self._GEOMETRY[stride_w]

        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            # depthwise 3x3 + pointwise 1x1 (cin -> cin)
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            # depthwise 3x3 + pointwise 1x1 (cin -> cout)
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cin),
            nn.Conv2d(cin, cout, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        """Up-sample then refine; (N, cin, H, W) -> (N, cout, H*sh, W*sw)."""
        x = self.up_sample(x)
        return self.conv(x)
144
+
145
+
146
class SFTLayer(nn.Module):
    """Spatial feature transform layer.

    Takes ``x = (feature, condition)`` and modulates the feature with an
    affine transform predicted from the condition by two small two-layer
    MLPs: ``feature * (scale + 1) + shift``.
    """

    def __init__(self, dim_in, dim_out):
        super(SFTLayer, self).__init__()
        # Scale branch: dim_in -> dim_in -> dim_out.
        self.SFT_scale_conv0 = nn.Linear(dim_in, dim_in)
        self.SFT_scale_conv1 = nn.Linear(dim_in, dim_out)
        # Shift branch: dim_in -> dim_in -> dim_out.
        self.SFT_shift_conv0 = nn.Linear(dim_in, dim_in)
        self.SFT_shift_conv1 = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        # x[0]: feature to modulate; x[1]: conditioning vector.
        feature, condition = x[0], x[1]
        scale_hidden = F.leaky_relu(self.SFT_scale_conv0(condition), 0.1,
                                    inplace=True)
        shift_hidden = F.leaky_relu(self.SFT_shift_conv0(condition), 0.1,
                                    inplace=True)
        scale = self.SFT_scale_conv1(scale_hidden)
        shift = self.SFT_shift_conv1(shift_hidden)
        return feature * (scale + 1) + shift
174
+
175
+
176
class MoreUNetBlock(nn.Module):
    """Deeper UNet decoder block: up-sampling plus four depthwise-separable
    3x3 convolution stages.

    Args:
        cin: input channels.
        cout: output channels.
        bn2d: 2-D batch-norm class applied after each pointwise conv.
        stride: ``(stride_h, stride_w)`` up-sampling factors; each component
            must be 1, 2 or 4.
        deformable: unused; kept for interface compatibility.

    Raises:
        ValueError: if a stride component is not 1, 2 or 4.  (The original
            code left the transposed-conv kernel size undefined in that case
            and failed later with ``UnboundLocalError``.)
    """

    # stride -> (kernel, padding) so that ConvTranspose2d up-samples by
    # exactly `stride`: out = (in - 1) * stride - 2 * padding + kernel.
    _GEOMETRY = {1: (1, 0), 2: (4, 1), 4: (4, 0)}

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        super().__init__()
        stride_h, stride_w = stride
        if stride_h not in self._GEOMETRY or stride_w not in self._GEOMETRY:
            raise ValueError(
                f'unsupported stride {stride}; each component must be 1, 2 or 4')
        kernel_h, padding_h = self._GEOMETRY[stride_h]
        kernel_w, padding_w = self._GEOMETRY[stride_w]

        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        # Three ReLU6-activated depthwise-separable stages (the third maps
        # cin -> cout), then a final cout -> cout stage with no activation.
        self.conv = nn.Sequential(
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cin),
            nn.Conv2d(cin, cout, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cout),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cout, cout, kernel_size=3, stride=1, padding=1,
                      bias=False, groups=cout),
            nn.Conv2d(cout, cout, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        """Up-sample then refine; (N, cin, H, W) -> (N, cout, H*sh, W*sw)."""
        x = self.up_sample(x)
        return self.conv(x)
259
+
260
+
261
class BinaryDecoder(nn.Module):
    """Up-sample encoder features into per-class binary segmentation logits.

    Three UNet-style blocks (plus one final block) progressively undo the
    encoder's down-sampling ``strides``; a 1x1 conv then predicts the
    segmentation classes.

    Args:
        dim: channel width of the incoming encoder feature map.
        num_classes: recognizer class count; the head predicts
            ``num_classes - 2`` channels for the two multi-class CE-style
            losses and ``num_classes - 3`` otherwise.
        strides: sequence of per-stage ``(stride_h, stride_w)`` encoder
            down-sampling factors; walked in reverse to size the up-sampling.
        use_depthwise_unet: use ``DepthWiseUNetBlock`` instead of ``UNetBlock``.
        use_more_unet: use ``MoreUNetBlock`` (overrides ``use_depthwise_unet``).
        binary_loss_type: loss name; only used to pick the class count above.
    """

    def __init__(self,
                 dim,
                 num_classes,
                 strides,
                 use_depthwise_unet=False,
                 use_more_unet=False,
                 binary_loss_type='DiceLoss') -> None:
        super().__init__()

        # Channel schedule: halve the width at each decoder stage.
        channels = [dim // 2**i for i in range(4)]
        # 3x3 conv + norm adapting the encoder feature before decoding.
        # NOTE(review): nn.SyncBatchNorm is hard-coded here; in training it
        # needs an initialized distributed process group — confirm single-GPU
        # runs convert it (e.g. via revert_sync_batchnorm).
        self.linear_enc2binary = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
            nn.SyncBatchNorm(dim),
        )
        self.strides = strides
        self.use_deformable = False
        self.binary_decoder = nn.ModuleList()
        # Pick the up-sampling block type; use_more_unet takes precedence.
        unet = DepthWiseUNetBlock if use_depthwise_unet else UNetBlock
        unet = MoreUNetBlock if use_more_unet else unet

        # Undo the encoder strides from last stage back to first.
        for i in range(3):
            up_sample_stride = self.strides[::-1][i]
            cin, cout = channels[i], channels[i + 1]
            self.binary_decoder.append(
                unet(cin, cout, nn.SyncBatchNorm, up_sample_stride,
                     self.use_deformable))

        # Final block uses half the first-stage stride; the remaining factor
        # of 2 per axis is the patch size used by patchify/unpatchify.
        # `cout` intentionally carries over from the loop above.
        last_stride = (self.strides[0][0] // 2, self.strides[0][1] // 2)
        self.binary_decoder.append(
            unet(cout, cout, nn.SyncBatchNorm, last_stride,
                 self.use_deformable))

        if binary_loss_type == 'CrossEntropyDiceLoss' or binary_loss_type == 'BanlanceMultiClassCrossEntropyLoss':
            segm_num_cls = num_classes - 2
        else:
            segm_num_cls = num_classes - 3
        # 1x1 conv producing the per-class segmentation logits.
        self.binary_pred = nn.Conv2d(channels[-1],
                                     segm_num_cls,
                                     kernel_size=1,
                                     stride=1,
                                     bias=True)

    def patchify(self, imgs):
        """Split images into flattened patches.

        imgs: (N, 3, H, W)
        returns x: (N, L, patch_size**2 * 3) where the patch size per axis is
        half the first-stage stride and L = (H // p_h) * (W // p_w).
        """
        p_h, p_w = self.strides[0]
        p_h = p_h // 2
        p_w = p_w // 2
        h = imgs.shape[2] // p_h
        w = imgs.shape[3] // p_w

        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p_h, w, p_w))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p_h * p_w * 3))
        return x

    def unpatchify(self, x):
        """Reassemble per-patch values into a full-resolution map.

        x: (N, patch_size**2, h, w)
        returns imgs: (N, h * p_h, w * p_w) — note: single-channel, unlike
        patchify's 3-channel input.
        """
        p_h, p_w = self.strides[0]
        p_h = p_h // 2
        p_w = p_w // 2
        _, _, h, w = x.shape
        assert p_h * p_w == x.shape[1]

        x = x.permute(0, 2, 3, 1)  # [N, h, w, p_h*p_w]
        x = x.reshape(shape=(x.shape[0], h, w, p_h, p_w))
        x = torch.einsum('nhwpq->nhpwq', x)
        imgs = x.reshape(shape=(x.shape[0], h * p_h, w * p_w))
        return imgs

    def forward(self, x, time=None):
        """Decode an encoder feature map into segmentation output.

        Args:
            x: encoder feature map of shape (N, dim, H, W).
            time: unused; kept for interface compatibility.

        Returns:
            ``(pred, binary_feats)`` where ``binary_feats`` is the list of
            intermediate decoder features (input adaptation plus one entry
            per decoder block).  ``pred`` is raw logits in training mode and
            channel-softmax probabilities in eval mode.
        """

        binary_feats = []
        x = self.linear_enc2binary(x)
        binary_feats.append(x.clone())

        for i, d in enumerate(self.binary_decoder):

            x = d(x)
            binary_feats.append(x.clone())
        x = self.binary_pred(x)

        if self.training:
            return x, binary_feats
        else:
            # Eval path returns probabilities over the class dimension.
            return x.softmax(1), binary_feats
363
+
364
+
365
class LayerNormProxy(nn.Module):
    """Apply ``nn.LayerNorm`` over the channel axis of an NCHW tensor.

    The input is transposed to NHWC, normalized over the last (channel)
    dimension, and transposed back to NCHW.
    """

    def __init__(self, dim):

        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        nhwc = x.permute(0, 2, 3, 1)
        normalized = self.norm(nhwc)
        return normalized.permute(0, 3, 1, 2)
376
+
377
+
378
+ class DAttentionFuse(nn.Module):
379
+
380
    def __init__(
        self,
        q_size=(4, 32),
        kv_size=(4, 32),
        n_heads=8,
        n_head_channels=80,
        n_groups=4,
        attn_drop=0.0,
        proj_drop=0.0,
        stride=2,
        offset_range_factor=2,
        use_pe=True,
        stage_idx=0,
    ):
        '''
        Deformable-attention fusion module.

        Args:
            q_size: (H, W) of the query feature map.
            kv_size: (H, W) of the key/value feature map (used to size the
                relative-position bias table).
            n_heads: number of attention heads.
            n_head_channels: channels per head (total width = heads * chans).
            n_groups: offset groups; channels/heads are split across groups.
            attn_drop: dropout rate on the attention weights.
            proj_drop: dropout rate on the output projection.
            stride: stride of the offset-prediction convolution.
            offset_range_factor: scale applied to tanh-bounded offsets.
            use_pe: add a learned relative position bias when True.
            stage_idx from 2 to 3 — selects the offset conv kernel size.
        '''

        super().__init__()
        self.n_head_channels = n_head_channels
        # Standard scaled-dot-product attention scale: 1 / sqrt(d_head).
        self.scale = self.n_head_channels**-0.5
        self.n_heads = n_heads
        self.q_h, self.q_w = q_size
        self.kv_h, self.kv_w = kv_size
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        self.n_group_channels = self.nc // self.n_groups
        self.n_group_heads = self.n_heads // self.n_groups
        self.use_pe = use_pe
        self.offset_range_factor = offset_range_factor
        # Offset conv kernel size per stage index: 9, 7, 5, 3.
        ksizes = [9, 7, 5, 3]
        kk = ksizes[stage_idx]

        # Predicts a 2-channel (y, x) offset field from the concatenated
        # per-group query/key features.
        self.conv_offset = nn.Sequential(
            nn.Conv2d(2 * self.n_group_channels,
                      2 * self.n_group_channels,
                      kk,
                      stride,
                      kk // 2,
                      groups=self.n_group_channels),
            LayerNormProxy(2 * self.n_group_channels), nn.GELU(),
            nn.Conv2d(2 * self.n_group_channels, 2, 1, 1, 0, bias=False))

        # 1x1 projections for query, key, value and output.
        self.proj_q = nn.Conv2d(self.nc,
                                self.nc,
                                kernel_size=1,
                                stride=1,
                                padding=0)

        self.proj_k = nn.Conv2d(self.nc,
                                self.nc,
                                kernel_size=1,
                                stride=1,
                                padding=0)

        self.proj_v = nn.Conv2d(self.nc,
                                self.nc,
                                kernel_size=1,
                                stride=1,
                                padding=0)

        self.proj_out = nn.Conv2d(self.nc,
                                  self.nc,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)

        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)

        if self.use_pe:
            # Learned relative-position bias covering all (2H-1) x (2W-1)
            # displacements of the kv grid; sampled bilinearly in forward.
            self.rpe_table = nn.Parameter(
                torch.zeros(self.n_heads, self.kv_h * 2 - 1,
                            self.kv_w * 2 - 1))
            trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None
457
+
458
    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):
        """Build normalized reference sampling points for an H_key x W_key grid.

        Returns a tensor of shape ``(B * n_groups, H_key, W_key, 2)`` whose
        last dimension is ``(y, x)`` cell-center coordinates mapped to [-1, 1]
        (the convention expected by ``F.grid_sample``).
        """

        # NOTE(review): torch.meshgrid without indexing= emits a deprecation
        # warning on torch>=1.10; the default 'ij' matches the (y, x) stacking
        # below — confirm before adding the explicit keyword.
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype,
                           device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype,
                           device=device))
        ref = torch.stack((ref_y, ref_x), -1)
        # Normalize cell centers into [-1, 1]: x by W, y by H.
        ref[..., 1].div_(W_key).mul_(2).sub_(1)
        ref[..., 0].div_(H_key).mul_(2).sub_(1)
        # Expand (no copy) to one grid per (batch, group).
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1,
                                    -1)  # B * g H W 2
        return ref
472
+
473
def forward(self, x, y):
    """Deformable cross-attention: queries from `y`, keys/values sampled
    from `x` at predicted offset locations.

    Args:
        x: feature map the keys/values are sampled from, (B, C, H, W).
        y: feature map the queries are projected from; assumed to share
           x's (B, C, H, W) shape — TODO confirm against callers.

    Returns:
        (out, pos, reference):
            out: attended features, (B, C, H, W).
            pos: sampling positions, (B, n_groups, Hk, Wk, 2).
            reference: un-offset reference grid, same shape as pos.
    """
    B, C, H, W = x.size()
    dtype, device = x.dtype, x.device

    # Concatenate x and y channel-wise and fold the group dimension into
    # the batch so conv_offset predicts per-group offsets.
    q_off = torch.cat(
        (x, y), dim=1
    ).reshape(B, self.n_groups, 2 * self.n_group_channels, H, W).flatten(
        0, 1
    )  #einops.rearrange(torch.cat((x,y),dim=1), 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=2*self.n_group_channels)

    offset = self.conv_offset(q_off)  # B * g 2 Hg Wg
    Hk, Wk = offset.size(2), offset.size(3)
    n_sample = Hk * Wk
    if self.offset_range_factor > 0:
        # Bound each offset to +/- offset_range_factor cells (tanh-scaled).
        offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk],
                                    device=device).reshape(1, 2, 1, 1)
        offset = offset.tanh().mul(offset_range).mul(
            self.offset_range_factor)

    offset = offset.permute(
        0, 2, 3, 1)  #einops.rearrange(offset, 'b p h w -> b h w p')
    reference = self._get_ref_points(Hk, Wk, B, dtype, device)

    if self.offset_range_factor >= 0:
        pos = offset + reference
    else:
        # Unbounded offsets: squash the final position back into [-1, 1].
        pos = (offset + reference).tanh()

    q = self.proj_q(y)
    # Sample key/value features from x at the deformed positions.
    x_sampled = F.grid_sample(
        input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
        grid=pos[..., (1, 0)],  # y, x -> x, y
        mode='bilinear',
        align_corners=False)  # B * g, Cg, Hg, Wg

    x_sampled = x_sampled.reshape(B, C, 1, n_sample)

    q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
    k = self.proj_k(x_sampled).reshape(B * self.n_heads,
                                       self.n_head_channels, n_sample)
    v = self.proj_v(x_sampled).reshape(B * self.n_heads,
                                       self.n_head_channels, n_sample)

    attn = torch.einsum('b c m, b c n -> b m n', q, k)  # B * h, HW, Ns
    attn = attn.mul(self.scale)

    if self.use_pe:
        # Relative position bias: sample the rpe table at the (query,
        # sampled-key) displacement for every pair.
        rpe_table = self.rpe_table
        rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)

        q_grid = self._get_ref_points(H, W, B, dtype, device)

        # Pairwise (query - key) displacement, halved to map the
        # [-2, 2] difference range into grid_sample's [-1, 1].
        displacement = (
            q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) -
            pos.reshape(B * self.n_groups, n_sample,
                        2).unsqueeze(1)).mul(0.5)

        # NOTE(review): the table was allocated as (kv_h*2-1, kv_w*2-1)
        # but is reshaped here to (2H-1, 2W-1); this only holds when the
        # runtime H, W match kv_h, kv_w — verify for variable input sizes.
        attn_bias = F.grid_sample(input=rpe_bias.reshape(
            B * self.n_groups, self.n_group_heads, 2 * H - 1, 2 * W - 1),
                                  grid=displacement[..., (1, 0)],
                                  mode='bilinear',
                                  align_corners=False)

        attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)

        attn = attn + attn_bias

    attn = F.softmax(attn, dim=2)
    attn = self.attn_drop(attn)

    out = torch.einsum('b m n, b c n -> b c m', attn, v)
    out = out.reshape(B, C, H, W)
    out = self.proj_drop(self.proj_out(out))

    return out, pos.reshape(B, self.n_groups, Hk, Wk,
                            2), reference.reshape(B, self.n_groups, Hk, Wk,
                                                  2)
550
+
551
+
552
class FuseModel(nn.Module):
    """Fuses recognition features with multi-scale binary-mask features.

    The last binary feature map is pushed through a chain of
    (Block -> LayerNorm -> strided conv) stages that progressively halve
    its channels and downsample it, then a deformable cross-attention
    layer (DAttentionFuse) refines the recognition features conditioned
    on the fused binary map.
    """

    def __init__(self,
                 dim,
                 deform_stride=2,
                 stage_idx=2,
                 k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
                 q_size=(2, 32)):
        super().__init__()

        # Channel widths per stage: dim, dim/2, dim/4, dim/8.
        channels = [dim // 2**i for i in range(4)]

        refine_conv = nn.Conv2d
        self.deform_stride = deform_stride

        # (in, out) indices into `channels` (negative indexing) per stage;
        # the last stage keeps its channel count.
        in_out_ch = [(-1, -2), (-2, -3), (-3, -4), (-4, -4)]

        # Deformable cross-attention between recognition and binary feats.
        self.binary_condition_layer = DAttentionFuse(q_size=q_size,
                                                     kv_size=q_size,
                                                     stride=self.deform_stride,
                                                     n_head_channels=dim // 8,
                                                     stage_idx=stage_idx)

        # Per-stage refinement: ConvNeXt-style Block, channels-first
        # LayerNorm, then a conv whose kernel == stride (patch-merge style
        # downsampling with k_size[i]).
        self.binary2refine_linear_norm = nn.ModuleList()
        for i in range(len(k_size)):
            self.binary2refine_linear_norm.append(
                nn.Sequential(
                    Block(dim=channels[in_out_ch[i][0]]),
                    LayerNorm(channels[in_out_ch[i][0]],
                              eps=1e-6,
                              data_format='channels_first'),
                    refine_conv(channels[in_out_ch[i][0]],
                                channels[in_out_ch[i][1]],
                                kernel_size=k_size[i],
                                stride=k_size[i])),  # [8, 32]
            )

    def forward(self, recog_feat, binary_feats, dec_in=None):
        """Refine `recog_feat` using the binary decoder's feature maps.

        Args:
            recog_feat: recognition feature map (query source).
            binary_feats: sequence of binary decoder features; [-1] is the
                starting map, [0] is added back as a residual.
            dec_in: unused; kept for interface compatibility.

        Returns:
            (binary_refined_feat, binary_feat) — the attention-refined map
            (with residual) and the fused binary map.
        """
        multi_feat = []
        binary_feat = binary_feats[-1]
        for i in range(len(self.binary2refine_linear_norm)):
            binary_feat = self.binary2refine_linear_norm[i](binary_feat)
            multi_feat.append(binary_feat)
        # Residual from the earliest binary feature (new tensor).
        binary_feat = binary_feat + binary_feats[0]
        # NOTE(review): this in-place add mutates the tensor appended as
        # multi_feat[3] (the pre-residual map), but multi_feat is never
        # used afterwards — looks like dead code; confirm before removing.
        multi_feat[3] += binary_feats[0]
        binary_refined_feat, pos, _ = self.binary_condition_layer(
            recog_feat, binary_feat)
        binary_refined_feat = binary_refined_feat + binary_feat
        return binary_refined_feat, binary_feat
601
+
602
+
603
class CAMEncoder(nn.Module):
    """CAM encoder: backbone + binary-mask decoder + deformable fusion.

    Runs the configured backbone, predicts a binary (character-region)
    mask with a BinaryDecoder, fuses the mask features back into the
    recognition features via FuseModel, and projects the result to the
    decoder embedding size.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        encoder_config (dict): Backbone config; must contain a 'name' key
            naming a backbone class in scope, remaining keys are passed to
            its constructor. Default: {'name': 'ConvNeXtV2'}.
        nb_classes (int): Number of classes for the binary decoder.
        strides (list): Per-stage strides forwarded to BinaryDecoder.
        k_size (list): Per-stage kernel/stride pairs for FuseModel.
        q_size (list): Query spatial size for the deformable attention.
        deform_stride (int): Stride of the deformable offset conv.
        stage_idx (int): Stage index forwarded to DAttentionFuse.
        use_depthwise_unet / use_more_unet (bool): BinaryDecoder options.
        binary_loss_type (str): Loss name forwarded to BinaryDecoder.
        mid_size (bool): If True, halve backbone channels before decoding.
        d_embedding (int): Output embedding dimension (out_channels).
    """

    def __init__(self,
                 in_channels=3,
                 encoder_config={'name': 'ConvNeXtV2'},
                 nb_classes=71,
                 strides=[(4, 4), (2, 1), (2, 1), (1, 1)],
                 k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
                 q_size=[2, 32],
                 deform_stride=2,
                 stage_idx=2,
                 use_depthwise_unet=True,
                 use_more_unet=False,
                 binary_loss_type='BanlanceMultiClassCrossEntropyLoss',
                 mid_size=True,
                 d_embedding=384):
        super().__init__()
        # BUGFIX: copy before mutating. `encoder_config` is a mutable
        # default argument shared across calls; the pop()/assignment below
        # previously corrupted it, so a second CAMEncoder built with the
        # default config raised KeyError('name').
        encoder_config = dict(encoder_config)
        encoder_name = encoder_config.pop('name')
        encoder_config['in_channels'] = in_channels
        # NOTE(review): eval() on a config-supplied name executes arbitrary
        # code if the config is untrusted; an explicit name->class registry
        # would be safer.
        self.backbone = eval(encoder_name)(**encoder_config)
        dim = self.backbone.out_channels
        self.mid_size = mid_size
        if self.mid_size:
            # Halve channels, then a depthwise 3x3 + pointwise 1x1 pair.
            self.enc_downsample = nn.Sequential(
                nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim // 2),
                #nn.ReLU6(inplace=True),
                nn.Conv2d(dim // 2,
                          dim // 2,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False,
                          groups=dim // 2),
                nn.Conv2d(dim // 2,
                          dim // 2,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=False),
                nn.SyncBatchNorm(dim // 2),
            )
            dim = dim // 2
            # Recognition branch projection (depthwise-separable).
            self.linear_enc2recog = nn.Sequential(
                nn.Conv2d(
                    dim,
                    dim,
                    kernel_size=1,
                    stride=1,
                ),
                nn.SyncBatchNorm(dim),
                #nn.ReLU6(inplace=True),
                nn.Conv2d(dim,
                          dim,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False,
                          groups=dim),
                nn.Conv2d(dim,
                          dim,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=False),
                nn.SyncBatchNorm(dim),
            )
        else:
            # Full-width variant: bottleneck to dim/2 then back to dim.
            self.linear_enc2recog = nn.Sequential(
                nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim // 2),
                #nn.ReLU6(inplace=True),
                nn.Conv2d(dim // 2, dim, kernel_size=3, stride=1, padding=1),
                nn.SyncBatchNorm(dim),
            )

        # Final projection of fused tokens to the decoder embedding size.
        self.linear_norm = nn.Sequential(
            nn.Linear(dim, d_embedding),
            nn.LayerNorm(d_embedding, eps=1e-6),
        )
        self.out_channels = d_embedding

        self.binary_decoder = BinaryDecoder(
            dim,
            nb_classes,
            strides,
            use_depthwise_unet=use_depthwise_unet,
            use_more_unet=use_more_unet,
            binary_loss_type=binary_loss_type)
        self.fuse_model = FuseModel(dim,
                                    deform_stride=deform_stride,
                                    stage_idx=stage_idx,
                                    k_size=k_size,
                                    q_size=q_size)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Module-wise init: trunc-normal for conv/linear, kaiming for
        transposed conv, unit-gamma/zero-beta for norm layers."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        if isinstance(m, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(m.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
        elif isinstance(m, (nn.LayerNorm, nn.SyncBatchNorm, nn.BatchNorm2d)):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        """No parameters are excluded from weight decay."""
        return {}

    def forward(self, x):
        """Returns a dict with 'enc_feat', 'pred_binary', 'refined_feat'
        (B, H*W, d_embedding tokens) and 'binary_feat'."""
        output = {}
        enc_feat = self.backbone(x)
        if self.mid_size:
            enc_feat = self.enc_downsample(enc_feat)
        output['enc_feat'] = enc_feat

        # Binary (character-region) mask prediction.
        pred_binary, binary_feats = self.binary_decoder(enc_feat)
        output['pred_binary'] = pred_binary

        reg_feat = self.linear_enc2recog(enc_feat)
        B, C, H, W = reg_feat.shape
        last_feat, binary_feat = self.fuse_model(reg_feat, binary_feats)

        # Flatten to tokens: (B, H*W, C) -> project to d_embedding.
        dec_in = last_feat.reshape(B, C, H * W).permute(0, 2, 1)
        dec_in = self.linear_norm(dec_in)

        output['refined_feat'] = dec_in
        # NOTE(review): exposes binary_feats[-1], not the fused binary_feat
        # returned by fuse_model — confirm this is intentional.
        output['binary_feat'] = binary_feats[-1]
        return output