openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,631 @@
1
+ # --------------------------------------------------------
2
+ # FocalNets -- Focal Modulation Networks
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Jianwei Yang (jianwyan@microsoft.com)
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.utils.checkpoint as checkpoint
11
+ from torch.nn.init import trunc_normal_
12
+
13
+ from openrec.modeling.common import DropPath, Mlp
14
+ from openrec.modeling.encoders.svtrnet import ConvBNLayer
15
+
16
+
17
class FocalModulation(nn.Module):
    """Focal Modulation (FocalNets): attention-free token mixing.

    A linear projection splits the input into a query map, a context map
    and ``focal_level + 1`` gate maps. The context is aggregated over
    ``focal_level`` depth-wise convolutions of increasing kernel size plus
    one global-average level, weighted by the gates, then mapped through a
    1x1 conv (``h``) to form the modulator that multiplies the query.

    Args:
        dim (int): Number of input/output channels.
        focal_window (int): Kernel size of the first focal level.
        focal_level (int): Number of convolutional focal levels.
        max_kh (int, optional): Cap on the kernel height per level;
            useful for short, wide text feature maps. Default: None.
        focal_factor (int): Kernel-size increment per level. Default: 2.
        bias (bool): Use bias in the ``f`` and ``h`` projections.
        proj_drop (float): Dropout prob. after the output projection.
        use_postln_in_modulation (bool): Apply LayerNorm to the modulated
            features before the output projection.
        normalize_modulator (bool): Average the aggregated context over
            ``focal_level + 1`` levels.
    """

    def __init__(self,
                 dim,
                 focal_window,
                 focal_level,
                 max_kh=None,
                 focal_factor=2,
                 bias=True,
                 proj_drop=0.0,
                 use_postln_in_modulation=False,
                 normalize_modulator=False):
        super().__init__()

        self.dim = dim
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.focal_factor = focal_factor
        self.use_postln_in_modulation = use_postln_in_modulation
        self.normalize_modulator = normalize_modulator

        # f projects to: query (dim) + context (dim) + gates (focal_level + 1)
        self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=bias)
        # h maps the aggregated context to the modulator
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)

        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.focal_layers = nn.ModuleList()

        self.kernel_sizes = []
        for k in range(self.focal_level):
            kernel_size = self.focal_factor * k + self.focal_window
            if max_kh is not None:
                # clamp the kernel height for low-height feature maps
                k_h, k_w = [min(kernel_size, max_kh), kernel_size]
                kernel_size = [k_h, k_w]
                padding = [k_h // 2, k_w // 2]
            else:
                padding = kernel_size // 2
            self.focal_layers.append(
                nn.Sequential(
                    nn.Conv2d(dim,
                              dim,
                              kernel_size=kernel_size,
                              stride=1,
                              groups=dim,  # depth-wise
                              padding=padding,
                              bias=False),
                    nn.GELU(),
                ))
            self.kernel_sizes.append(kernel_size)
        if self.use_postln_in_modulation:
            self.ln = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, H, W, C)

        Returns:
            Tensor of shape (B, H, W, C).
        """
        C = x.shape[-1]

        # pre linear projection; to channels-first for the convolutions
        x = self.f(x).permute(0, 3, 1, 2).contiguous()
        # NOTE(review): gates/modulator are stored on the module,
        # presumably for external visualization; this retains activation
        # memory between calls.
        q, ctx, self.gates = torch.split(x, (C, C, self.focal_level + 1), 1)

        # context aggregation over the focal levels
        ctx_all = 0
        for lvl in range(self.focal_level):
            ctx = self.focal_layers[lvl](ctx)
            ctx_all = ctx_all + ctx * self.gates[:, lvl:lvl + 1]
        # the last gate weights the globally pooled context
        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
        ctx_all = ctx_all + ctx_global * self.gates[:, self.focal_level:]

        # normalize context
        if self.normalize_modulator:
            ctx_all = ctx_all / (self.focal_level + 1)

        # focal modulation
        self.modulator = self.h(ctx_all)
        x_out = q * self.modulator
        x_out = x_out.permute(0, 2, 3, 1).contiguous()
        if self.use_postln_in_modulation:
            x_out = self.ln(x_out)

        # post linear projection
        x_out = self.proj(x_out)
        x_out = self.proj_drop(x_out)
        return x_out

    def extra_repr(self) -> str:
        return f'dim={self.dim}'

    def flops(self, N):
        """Approximate FLOPs for a token sequence of length N."""
        flops = 0

        # f: linear producing q, ctx and gates
        flops += N * self.dim * (self.dim * 2 + (self.focal_level + 1))

        # focal depth-wise convolutions.
        # Fix: kernel_sizes entries are [k_h, k_w] lists when max_kh is
        # set; the original `kernel_sizes[k]**2` only handled the int
        # case and raised a TypeError for lists.
        for k in range(self.focal_level):
            ks = self.kernel_sizes[k]
            area = ks[0] * ks[1] if isinstance(ks, (list, tuple)) else ks**2
            flops += N * (area + 1) * self.dim

        # global gating
        flops += N * 1 * self.dim

        # h: 1x1 conv (dim -> dim, plus bias)
        flops += N * self.dim * (self.dim + 1)

        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
127
+
128
+
129
class FocalNetBlock(nn.Module):
    r"""One FocalNet block: focal modulation followed by an MLP, each
    wrapped with a residual connection, optional layer scale and
    stochastic depth.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution (used by
            ``extra_repr``/``flops`` only).
        mlp_ratio (float): MLP hidden dim as a multiple of ``dim``.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation for the MLP.
        norm_layer (nn.Module, optional): Normalization layer class.
        focal_level (int): Number of focal levels.
        focal_window (int): Focal window size at the first focal level.
        max_kh (int, optional): Kernel-height cap forwarded to modulation.
        use_layerscale (bool): Whether to use learnable layer scale.
        layerscale_value (float): Initial layer-scale value.
        use_postln (bool): Post-normalization (norm applied after each
            sub-layer) instead of pre-normalization.
        use_postln_in_modulation (bool): Forwarded to FocalModulation.
        normalize_modulator (bool): Forwarded to FocalModulation.
    """

    def __init__(
        self,
        dim,
        input_resolution=None,
        mlp_ratio=4.0,
        drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        focal_level=1,
        focal_window=3,
        max_kh=None,
        use_layerscale=False,
        layerscale_value=1e-4,
        use_postln=False,
        use_postln_in_modulation=False,
        normalize_modulator=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.mlp_ratio = mlp_ratio
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.use_postln = use_postln

        self.norm1 = norm_layer(dim)
        self.modulation = FocalModulation(
            dim,
            proj_drop=drop,
            focal_window=focal_window,
            focal_level=self.focal_level,
            max_kh=max_kh,
            use_postln_in_modulation=use_postln_in_modulation,
            normalize_modulator=normalize_modulator,
        )

        # identity when no stochastic depth is requested
        self.drop_path = (DropPath(drop_path)
                          if drop_path > 0.0 else nn.Identity())
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=hidden_dim,
                       act_layer=act_layer,
                       drop=drop)

        # layer scale: plain 1.0 multipliers unless learnable scaling is on
        self.gamma_1 = 1.0
        self.gamma_2 = 1.0
        if use_layerscale:
            self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)),
                                        requires_grad=True)
            self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)),
                                        requires_grad=True)

        # spatial size, assigned by the enclosing layer before each forward
        self.H = None
        self.W = None

    def forward(self, x):
        height, width = self.H, self.W
        batch, _, channels = x.shape
        residual = x

        # focal-modulation sub-layer (pre- or post-norm)
        if self.use_postln:
            y = x
        else:
            y = self.norm1(x)
        y = y.view(batch, height, width, channels)
        y = self.modulation(y).view(batch, height * width, channels)
        if self.use_postln:
            y = self.norm1(y)
        x = residual + self.drop_path(self.gamma_1 * y)

        # feed-forward sub-layer
        if self.use_postln:
            ffn = self.norm2(self.mlp(x))
        else:
            ffn = self.mlp(self.norm2(x))
        x = x + self.drop_path(self.gamma_2 * ffn)

        return x

    def extra_repr(self) -> str:
        return (
            f'dim={self.dim}, input_resolution={self.input_resolution}, '
            f'mlp_ratio={self.mlp_ratio}')

    def flops(self):
        height, width = self.input_resolution
        tokens = height * width
        # norm1
        total = self.dim * tokens
        # focal modulation
        total += self.modulation.flops(tokens)
        # mlp (two dim x hidden projections)
        total += 2 * tokens * self.dim * self.dim * self.mlp_ratio
        # norm2
        total += self.dim * tokens
        return total
240
+
241
+
242
class BasicLayer(nn.Module):
    """One FocalNet stage: a stack of FocalNetBlocks with an optional
    patch-merging downsample at the end.

    Args:
        dim (int): Number of input channels.
        out_dim (int): Number of channels after the downsample layer.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks in this stage.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth
            rate(s); a list supplies one rate per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer class.
        downsample (nn.Module | None, optional): Downsample layer class
            applied after the blocks. Default: None
        downsample_kernel (list): Patch size forwarded to the downsample
            layer (passed through unmutated).
        use_checkpoint (bool): Use activation checkpointing to save
            memory. Default: False.
        focal_level (int): Number of focal levels.
        focal_window (int): Focal window size at the first focal level.
        use_conv_embed (bool): Use convolutional patch embedding in the
            downsample layer.
        use_layerscale (bool): Whether to use layer scale.
        layerscale_value (float): Initial layer-scale value.
        use_postln (bool): Use post-normalization inside the blocks.
    """

    def __init__(
        self,
        dim,
        out_dim,
        input_resolution,
        depth,
        mlp_ratio=4.0,
        drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        downsample_kernel=[],
        use_checkpoint=False,
        focal_level=1,
        focal_window=1,
        use_conv_embed=False,
        use_layerscale=False,
        layerscale_value=1e-4,
        use_postln=False,
        use_postln_in_modulation=False,
        normalize_modulator=False,
    ):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build the block stack; a per-block drop-path rate is honored
        # when drop_path is given as a list
        blocks = []
        for i in range(depth):
            rate = drop_path[i] if isinstance(drop_path, list) else drop_path
            blocks.append(
                FocalNetBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    mlp_ratio=mlp_ratio,
                    drop=drop,
                    drop_path=rate,
                    norm_layer=norm_layer,
                    focal_level=focal_level,
                    focal_window=focal_window,
                    use_layerscale=use_layerscale,
                    layerscale_value=layerscale_value,
                    use_postln=use_postln,
                    use_postln_in_modulation=use_postln_in_modulation,
                    normalize_modulator=normalize_modulator,
                ))
        self.blocks = nn.ModuleList(blocks)

        self.downsample = None
        if downsample is not None:
            self.downsample = downsample(
                img_size=input_resolution,
                patch_size=downsample_kernel,
                in_chans=dim,
                embed_dim=out_dim,
                use_conv_embed=use_conv_embed,
                norm_layer=norm_layer,
                is_stem=False,
            )

    def forward(self, x, H, W):
        for block in self.blocks:
            # blocks read their spatial size from attributes
            block.H, block.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(block, x)
            else:
                x = block(x)

        if self.downsample is None:
            return x, H, W
        # tokens -> feature map for the downsample layer
        feat = x.transpose(1, 2).reshape(x.shape[0], -1, H, W)
        x, out_h, out_w = self.downsample(feat)
        return x, out_h, out_w

    def extra_repr(self) -> str:
        return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'

    def flops(self):
        total = sum(block.flops() for block in self.blocks)
        if self.downsample is not None:
            total += self.downsample.flops()
        return total
352
+
353
+
354
class PatchEmbed(nn.Module):
    r"""Image-to-patch embedding via a strided convolution.

    Args:
        img_size (tuple[int]): Image size. Default: (224, 224).
        patch_size (list[int]): Patch token size. Default: [4, 4].
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of projection output channels. Default: 96.
        use_conv_embed (bool): Use an overlapping convolution instead of a
            non-overlapping patch projection. Default: False.
        norm_layer (nn.Module, optional): Normalization layer. Default: None.
        is_stem (bool): With ``use_conv_embed``, selects the larger stem
            kernel (7/4/2) over the stage kernel (3/2/1). Default: False.
    """

    def __init__(self,
                 img_size=(224, 224),
                 patch_size=[4, 4],
                 in_chans=3,
                 embed_dim=96,
                 use_conv_embed=False,
                 norm_layer=None,
                 is_stem=False):
        super().__init__()
        grid = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if use_conv_embed:
            # overlapping conv embedding: the stem and later stages use
            # different kernel/stride/padding
            if is_stem:
                conv_cfg = dict(kernel_size=7, stride=4, padding=2)
            else:
                conv_cfg = dict(kernel_size=3, stride=2, padding=1)
            self.proj = nn.Conv2d(in_chans, embed_dim, **conv_cfg)
        else:
            self.proj = nn.Conv2d(in_chans,
                                  embed_dim,
                                  kernel_size=patch_size,
                                  stride=patch_size)

        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        # expects a 4-D (B, C, H, W) image tensor
        _, _, _, _ = x.shape

        x = self.proj(x)
        out_h, out_w = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x, out_h, out_w

    def flops(self):
        grid_h, grid_w = self.patches_resolution
        # projection conv (per output token: in_chans x patch area MACs)
        total = grid_h * grid_w * self.embed_dim * self.in_chans * (
            self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            total += grid_h * grid_w * self.embed_dim
        return total
429
+
430
+
431
class FocalSVTR(nn.Module):
    r"""Focal Modulation Networks (FocalNets) backbone, SVTR-style variant
    for text-shaped inputs (default 32x128).

    Args:
        img_size (int | tuple(int)): Input image size. Default [32, 128]
        patch_size (int | tuple(int)): Patch size. Default: [4, 4]
        out_channels (int): Output channels of the optional last-stage head. Default: 256
        out_char_num (int): Output sequence length for the (currently
            commented-out) pooled head. Default: 25
        in_channels (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension; doubled at each stage. Default: 96
        depths (tuple(int)): Depth of each Focal Transformer layer. Default: [3, 6, 3]
        sub_k (list): Per-stage [h, w] downsampling factors applied between
            stages. Default: [[2, 1], [2, 1], [1, 1]]
        last_stage (bool): Whether to append the linear projection head. Default: False
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        drop_rate (float): Dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level. Default: [6, 6, 6]
        focal_windows (list): The focal window size at all stages. Default: [3, 3, 3]
        use_conv_embed (bool): Whether use convolutional embedding. We noted that using convolutional embedding usually improve the performance,
            but we do not use it by default. Default: False
        use_layerscale (bool): Whether use layerscale proposed in CaiT. Default: False
        layerscale_value (float): Value for layer scale. Default: 1e-4
        use_postln (bool): Whether use layernorm after modulation (it helps stablize training of large models)
        use_postln_in_modulation (bool): Forwarded to each BasicLayer. Default: False
        normalize_modulator (bool): Forwarded to each BasicLayer. Default: False
        feat2d (bool): If True, reshape the output tokens back into a 2-D
            feature map (B, C, H, W). Default: False
    """

    def __init__(
        self,
        img_size=[32, 128],
        patch_size=[4, 4],
        out_channels=256,
        out_char_num=25,
        in_channels=3,
        embed_dim=96,
        depths=[3, 6, 3],
        sub_k=[[2, 1], [2, 1], [1, 1]],
        last_stage=False,
        mlp_ratio=4.0,
        drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        patch_norm=True,
        use_checkpoint=False,
        focal_levels=[6, 6, 6],
        focal_windows=[3, 3, 3],
        use_conv_embed=False,
        use_layerscale=False,
        layerscale_value=1e-4,
        use_postln=False,
        use_postln_in_modulation=False,
        normalize_modulator=False,
        feat2d=False,
        **kwargs,
    ):
        super().__init__()

        self.num_layers = len(depths)
        # Channel width doubles every stage: [C, 2C, 4C, ...].
        embed_dim = [embed_dim * (2**i) for i in range(self.num_layers)]
        self.feat2d = feat2d
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = embed_dim[-1]
        self.mlp_ratio = mlp_ratio

        # Stem: two stride-2 ConvBNLayers -> 4x spatial reduction, matching
        # the default patch_size of [4, 4].
        self.patch_embed = nn.Sequential(
            ConvBNLayer(
                in_channels=in_channels,
                out_channels=embed_dim[0] // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                act=nn.GELU,
                bias=None,
            ),
            ConvBNLayer(
                in_channels=embed_dim[0] // 2,
                out_channels=embed_dim[0],
                kernel_size=3,
                stride=2,
                padding=1,
                act=nn.GELU,
                bias=None,
            ),
        )

        patches_resolution = [
            img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        ]
        self.patches_resolution = patches_resolution
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):

            layer = BasicLayer(
                dim=embed_dim[i_layer],
                # Last stage keeps its width and has no downsampler.
                out_dim=embed_dim[i_layer + 1] if
                (i_layer < self.num_layers - 1) else None,
                input_resolution=patches_resolution,
                depth=depths[i_layer],
                mlp_ratio=self.mlp_ratio,
                drop=drop_rate,
                # This stage's slice of the global drop-path schedule.
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchEmbed if
                (i_layer < self.num_layers - 1) else None,
                downsample_kernel=sub_k[i_layer],
                focal_level=focal_levels[i_layer],
                focal_window=focal_windows[i_layer],
                use_conv_embed=use_conv_embed,
                use_checkpoint=use_checkpoint,
                use_layerscale=use_layerscale,
                layerscale_value=layerscale_value,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                normalize_modulator=normalize_modulator,
            )
            # Track the token-grid resolution after this stage's sub_k
            # downsampling so the next stage sees the right input size.
            patches_resolution = [
                patches_resolution[0] // sub_k[i_layer][0],
                patches_resolution[1] // sub_k[i_layer][1]
            ]
            self.layers.append(layer)
        self.out_channels = self.num_features
        self.last_stage = last_stage
        if last_stage:
            # Optional head: pool over height, then project to out_channels.
            self.out_channels = out_channels
            self.last_conv = nn.Linear(self.num_features,
                                       self.out_channels,
                                       bias=False)
            self.hardswish = nn.Hardswish()
            self.dropout = nn.Dropout(p=0.1)
            # self.avg_pool = nn.AdaptiveAvgPool2d([1, out_char_num])
            # self.last_conv = nn.Conv2d(
            #     in_channels=self.num_features,
            #     out_channels=self.out_channels,
            #     kernel_size=1,
            #     stride=1,
            #     padding=0,
            #     bias=False,
            # )
            # self.hardswish = nn.Hardswish()
            # self.dropout = nn.Dropout(p=0.1)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module weight init: trunc-normal Linear, unit LayerNorm,
        Kaiming Conv2d."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter-name fragments excluded from weight decay by the trainer.
        return {'patch_embed', 'downsample'}

    def forward(self, x):
        # Merge an extra leading dim (e.g. multi-view batches) into the batch.
        if len(x.shape) == 5:
            x = x.flatten(0, 1)
        x = self.patch_embed(x)
        H, W = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        x = self.pos_drop(x)

        for layer in self.layers:
            x, H, W = layer(x, H, W)

        if self.feat2d:
            x = x.transpose(1, 2).reshape(-1, self.num_features, H, W)

        if self.last_stage:

            # Average-pool over the height axis -> one token per column.
            x = x.reshape(-1, H, W, self.num_features).mean(1)
            x = self.last_conv(x)
            x = self.hardswish(x)
            x = self.dropout(x)
            # x = self.avg_pool(x.transpose(1, 2).reshape(-1, self.num_features, H, W))
            # x = self.last_conv(x)
            # x = self.hardswish(x)
            # x = self.dropout(x)
            # x = x.flatten(2).transpose(1, 2)
        return x

    def flops(self):
        """Aggregate FLOPs estimate across stem, stages, and final pooling."""
        flops = 0
        # NOTE(review): self.patch_embed is an nn.Sequential of ConvBNLayer;
        # nn.Sequential itself has no flops() — verify ConvBNLayer/Sequential
        # provides one, otherwise this raises AttributeError.
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[
            0] * self.patches_resolution[1] // (2**self.num_layers)
        return flops
@@ -0,0 +1,28 @@
1
+ from torch import nn
2
+
3
+
4
class NRTREncoder(nn.Module):
    """Two-layer convolutional front end for the NRTR recognizer.

    Downsamples the input image 4x in both spatial dimensions, then
    rearranges the feature map into a width-major sequence of H*C-dim
    vectors (B, W, H*C) for the downstream transformer.
    """

    def __init__(self, in_channels):
        super(NRTREncoder, self).__init__()
        # 64 channels * 8 rows (for a 32-pixel-high input) = 512.
        self.out_channels = 512  # 64*H
        stages = []
        for c_in, c_out in ((in_channels, 32), (32, 64)):
            stages.extend((
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
                nn.ReLU(),
                nn.BatchNorm2d(c_out),
            ))
        self.block = nn.Sequential(*stages)

    def forward(self, images):
        """Encode images (B, C, H, W) into a sequence (B, W, H*C)."""
        feats = self.block(images)
        # (B, C, H, W) -> (B, W, H, C) -> (B, W, H*C)
        return feats.permute(0, 3, 2, 1).flatten(2)