openocr-python 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. openocr/__init__.py +11 -0
  2. openocr/configs/det/dbnet/repvit_db.yml +173 -0
  3. openocr/configs/rec/abinet/resnet45_trans_abinet_lang.yml +94 -0
  4. openocr/configs/rec/abinet/resnet45_trans_abinet_wo_lang.yml +93 -0
  5. openocr/configs/rec/abinet/svtrv2_abinet_lang.yml +130 -0
  6. openocr/configs/rec/abinet/svtrv2_abinet_wo_lang.yml +128 -0
  7. openocr/configs/rec/aster/resnet31_lstm_aster_tps_on.yml +93 -0
  8. openocr/configs/rec/aster/svtrv2_aster.yml +127 -0
  9. openocr/configs/rec/aster/svtrv2_aster_tps_on.yml +102 -0
  10. openocr/configs/rec/autostr/autostr_lstm_aster_tps_on.yml +95 -0
  11. openocr/configs/rec/busnet/svtrv2_busnet.yml +135 -0
  12. openocr/configs/rec/busnet/svtrv2_busnet_pretraining.yml +134 -0
  13. openocr/configs/rec/busnet/vit_busnet.yml +104 -0
  14. openocr/configs/rec/busnet/vit_busnet_pretraining.yml +104 -0
  15. openocr/configs/rec/cam/convnextv2_cam_tps_on.yml +118 -0
  16. openocr/configs/rec/cam/convnextv2_tiny_cam_tps_on.yml +118 -0
  17. openocr/configs/rec/cam/svtrv2_cam_tps_on.yml +123 -0
  18. openocr/configs/rec/cdistnet/resnet45_trans_cdistnet.yml +93 -0
  19. openocr/configs/rec/cdistnet/svtrv2_cdistnet.yml +139 -0
  20. openocr/configs/rec/cppd/svtr_base_cppd.yml +123 -0
  21. openocr/configs/rec/cppd/svtr_base_cppd_ch.yml +126 -0
  22. openocr/configs/rec/cppd/svtr_base_cppd_h8.yml +123 -0
  23. openocr/configs/rec/cppd/svtr_base_cppd_syn.yml +124 -0
  24. openocr/configs/rec/cppd/svtrv2_cppd.yml +150 -0
  25. openocr/configs/rec/dan/resnet45_fpn_dan.yml +98 -0
  26. openocr/configs/rec/dan/svtrv2_dan.yml +130 -0
  27. openocr/configs/rec/focalsvtr/focalsvtr_ctc.yml +137 -0
  28. openocr/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +168 -0
  29. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +151 -0
  30. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +150 -0
  31. openocr/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +152 -0
  32. openocr/configs/rec/igtr/svtr_base_ds_igtr.yml +157 -0
  33. openocr/configs/rec/lister/focalsvtr_lister_wo_fem_maxratio12.yml +133 -0
  34. openocr/configs/rec/lister/svtrv2_lister_wo_fem_maxratio12.yml +138 -0
  35. openocr/configs/rec/lpv/svtr_base_lpv.yml +124 -0
  36. openocr/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +123 -0
  37. openocr/configs/rec/lpv/svtrv2_lpv.yml +147 -0
  38. openocr/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +146 -0
  39. openocr/configs/rec/maerec/vit_nrtr.yml +116 -0
  40. openocr/configs/rec/matrn/resnet45_trans_matrn.yml +95 -0
  41. openocr/configs/rec/matrn/svtrv2_matrn.yml +130 -0
  42. openocr/configs/rec/mgpstr/svtrv2_mgpstr_only_char.yml +140 -0
  43. openocr/configs/rec/mgpstr/vit_base_mgpstr_only_char.yml +111 -0
  44. openocr/configs/rec/mgpstr/vit_large_mgpstr_only_char.yml +110 -0
  45. openocr/configs/rec/mgpstr/vit_mgpstr.yml +110 -0
  46. openocr/configs/rec/mgpstr/vit_mgpstr_only_char.yml +110 -0
  47. openocr/configs/rec/moran/resnet31_lstm_moran.yml +92 -0
  48. openocr/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +145 -0
  49. openocr/configs/rec/nrtr/nrtr.yml +107 -0
  50. openocr/configs/rec/nrtr/svtr_base_nrtr.yml +118 -0
  51. openocr/configs/rec/nrtr/svtr_base_nrtr_syn.yml +119 -0
  52. openocr/configs/rec/nrtr/svtrv2_nrtr.yml +146 -0
  53. openocr/configs/rec/ote/svtr_base_h8_ote.yml +117 -0
  54. openocr/configs/rec/ote/svtr_base_ote.yml +116 -0
  55. openocr/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +140 -0
  56. openocr/configs/rec/parseq/svrtv2_parseq.yml +136 -0
  57. openocr/configs/rec/parseq/vit_parseq.yml +100 -0
  58. openocr/configs/rec/robustscanner/resnet31_robustscanner.yml +102 -0
  59. openocr/configs/rec/robustscanner/svtrv2_robustscanner.yml +134 -0
  60. openocr/configs/rec/sar/resnet31_lstm_sar.yml +94 -0
  61. openocr/configs/rec/sar/svtrv2_sar.yml +128 -0
  62. openocr/configs/rec/seed/resnet31_lstm_seed_tps_on.yml +96 -0
  63. openocr/configs/rec/smtr/focalsvtr_smtr.yml +150 -0
  64. openocr/configs/rec/smtr/focalsvtr_smtr_long.yml +133 -0
  65. openocr/configs/rec/smtr/svtrv2_smtr.yml +150 -0
  66. openocr/configs/rec/smtr/svtrv2_smtr_bi.yml +136 -0
  67. openocr/configs/rec/srn/resnet50_fpn_srn.yml +97 -0
  68. openocr/configs/rec/srn/svtrv2_srn.yml +131 -0
  69. openocr/configs/rec/svtrs/convnextv2_ctc.yml +105 -0
  70. openocr/configs/rec/svtrs/convnextv2_h8_ctc.yml +105 -0
  71. openocr/configs/rec/svtrs/convnextv2_h8_rctc.yml +106 -0
  72. openocr/configs/rec/svtrs/convnextv2_rctc.yml +106 -0
  73. openocr/configs/rec/svtrs/convnextv2_tiny_h8_ctc.yml +105 -0
  74. openocr/configs/rec/svtrs/convnextv2_tiny_h8_rctc.yml +106 -0
  75. openocr/configs/rec/svtrs/crnn_ctc.yml +99 -0
  76. openocr/configs/rec/svtrs/crnn_ctc_long.yml +116 -0
  77. openocr/configs/rec/svtrs/focalnet_base_ctc.yml +108 -0
  78. openocr/configs/rec/svtrs/focalnet_base_rctc.yml +109 -0
  79. openocr/configs/rec/svtrs/focalsvtr_ctc.yml +106 -0
  80. openocr/configs/rec/svtrs/focalsvtr_rctc.yml +107 -0
  81. openocr/configs/rec/svtrs/resnet45_trans_ctc.yml +103 -0
  82. openocr/configs/rec/svtrs/resnet45_trans_rctc.yml +104 -0
  83. openocr/configs/rec/svtrs/svtr_base_ctc.yml +110 -0
  84. openocr/configs/rec/svtrs/svtr_base_rctc.yml +111 -0
  85. openocr/configs/rec/svtrs/svtrnet_ctc_syn.yml +111 -0
  86. openocr/configs/rec/svtrs/vit_ctc.yml +103 -0
  87. openocr/configs/rec/svtrs/vit_rctc.yml +103 -0
  88. openocr/configs/rec/svtrv2/repsvtr_ch.yml +121 -0
  89. openocr/configs/rec/svtrv2/svtrv2_ch.yml +133 -0
  90. openocr/configs/rec/svtrv2/svtrv2_ctc.yml +136 -0
  91. openocr/configs/rec/svtrv2/svtrv2_rctc.yml +135 -0
  92. openocr/configs/rec/svtrv2/svtrv2_small_rctc.yml +135 -0
  93. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +162 -0
  94. openocr/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +153 -0
  95. openocr/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +135 -0
  96. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LA.yml +103 -0
  97. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_1.yml +102 -0
  98. openocr/configs/rec/visionlan/resnet45_trans_visionlan_LF_2.yml +103 -0
  99. openocr/configs/rec/visionlan/svtrv2_visionlan_LA.yml +112 -0
  100. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_1.yml +111 -0
  101. openocr/configs/rec/visionlan/svtrv2_visionlan_LF_2.yml +112 -0
  102. openocr/demo_gradio.py +128 -0
  103. openocr/opendet/modeling/__init__.py +11 -0
  104. openocr/opendet/modeling/backbones/__init__.py +14 -0
  105. openocr/opendet/modeling/backbones/repvit.py +340 -0
  106. openocr/opendet/modeling/base_detector.py +69 -0
  107. openocr/opendet/modeling/heads/__init__.py +14 -0
  108. openocr/opendet/modeling/heads/db_head.py +73 -0
  109. openocr/opendet/modeling/necks/__init__.py +14 -0
  110. openocr/opendet/modeling/necks/db_fpn.py +609 -0
  111. openocr/opendet/postprocess/__init__.py +18 -0
  112. openocr/opendet/postprocess/db_postprocess.py +273 -0
  113. openocr/opendet/preprocess/__init__.py +154 -0
  114. openocr/opendet/preprocess/crop_resize.py +121 -0
  115. openocr/opendet/preprocess/db_resize_for_test.py +135 -0
  116. openocr/openrec/losses/__init__.py +62 -0
  117. openocr/openrec/losses/abinet_loss.py +42 -0
  118. openocr/openrec/losses/ar_loss.py +23 -0
  119. openocr/openrec/losses/cam_loss.py +48 -0
  120. openocr/openrec/losses/cdistnet_loss.py +34 -0
  121. openocr/openrec/losses/ce_loss.py +68 -0
  122. openocr/openrec/losses/cppd_loss.py +77 -0
  123. openocr/openrec/losses/ctc_loss.py +33 -0
  124. openocr/openrec/losses/igtr_loss.py +12 -0
  125. openocr/openrec/losses/lister_loss.py +14 -0
  126. openocr/openrec/losses/lpv_loss.py +30 -0
  127. openocr/openrec/losses/mgp_loss.py +34 -0
  128. openocr/openrec/losses/parseq_loss.py +12 -0
  129. openocr/openrec/losses/robustscanner_loss.py +20 -0
  130. openocr/openrec/losses/seed_loss.py +46 -0
  131. openocr/openrec/losses/smtr_loss.py +12 -0
  132. openocr/openrec/losses/srn_loss.py +40 -0
  133. openocr/openrec/losses/visionlan_loss.py +58 -0
  134. openocr/openrec/metrics/__init__.py +19 -0
  135. openocr/openrec/metrics/rec_metric.py +270 -0
  136. openocr/openrec/metrics/rec_metric_gtc.py +58 -0
  137. openocr/openrec/metrics/rec_metric_long.py +142 -0
  138. openocr/openrec/metrics/rec_metric_mgp.py +93 -0
  139. openocr/openrec/modeling/__init__.py +11 -0
  140. openocr/openrec/modeling/base_recognizer.py +69 -0
  141. openocr/openrec/modeling/common.py +238 -0
  142. openocr/openrec/modeling/decoders/__init__.py +109 -0
  143. openocr/openrec/modeling/decoders/abinet_decoder.py +283 -0
  144. openocr/openrec/modeling/decoders/aster_decoder.py +170 -0
  145. openocr/openrec/modeling/decoders/bus_decoder.py +133 -0
  146. openocr/openrec/modeling/decoders/cam_decoder.py +43 -0
  147. openocr/openrec/modeling/decoders/cdistnet_decoder.py +334 -0
  148. openocr/openrec/modeling/decoders/cppd_decoder.py +393 -0
  149. openocr/openrec/modeling/decoders/ctc_decoder.py +203 -0
  150. openocr/openrec/modeling/decoders/dan_decoder.py +203 -0
  151. openocr/openrec/modeling/decoders/igtr_decoder.py +815 -0
  152. openocr/openrec/modeling/decoders/lister_decoder.py +535 -0
  153. openocr/openrec/modeling/decoders/lpv_decoder.py +119 -0
  154. openocr/openrec/modeling/decoders/matrn_decoder.py +236 -0
  155. openocr/openrec/modeling/decoders/mgp_decoder.py +99 -0
  156. openocr/openrec/modeling/decoders/nrtr_decoder.py +439 -0
  157. openocr/openrec/modeling/decoders/ote_decoder.py +205 -0
  158. openocr/openrec/modeling/decoders/parseq_decoder.py +504 -0
  159. openocr/openrec/modeling/decoders/rctc_decoder.py +70 -0
  160. openocr/openrec/modeling/decoders/robustscanner_decoder.py +749 -0
  161. openocr/openrec/modeling/decoders/sar_decoder.py +236 -0
  162. openocr/openrec/modeling/decoders/smtr_decoder.py +621 -0
  163. openocr/openrec/modeling/decoders/smtr_decoder_nattn.py +521 -0
  164. openocr/openrec/modeling/decoders/srn_decoder.py +283 -0
  165. openocr/openrec/modeling/decoders/visionlan_decoder.py +321 -0
  166. openocr/openrec/modeling/encoders/__init__.py +39 -0
  167. openocr/openrec/modeling/encoders/autostr_encoder.py +327 -0
  168. openocr/openrec/modeling/encoders/cam_encoder.py +760 -0
  169. openocr/openrec/modeling/encoders/convnextv2.py +213 -0
  170. openocr/openrec/modeling/encoders/focalsvtr.py +631 -0
  171. openocr/openrec/modeling/encoders/nrtr_encoder.py +28 -0
  172. openocr/openrec/modeling/encoders/rec_hgnet.py +346 -0
  173. openocr/openrec/modeling/encoders/rec_lcnetv3.py +488 -0
  174. openocr/openrec/modeling/encoders/rec_mobilenet_v3.py +132 -0
  175. openocr/openrec/modeling/encoders/rec_mv1_enhance.py +254 -0
  176. openocr/openrec/modeling/encoders/rec_nrtr_mtb.py +37 -0
  177. openocr/openrec/modeling/encoders/rec_resnet_31.py +213 -0
  178. openocr/openrec/modeling/encoders/rec_resnet_45.py +183 -0
  179. openocr/openrec/modeling/encoders/rec_resnet_fpn.py +216 -0
  180. openocr/openrec/modeling/encoders/rec_resnet_vd.py +252 -0
  181. openocr/openrec/modeling/encoders/repvit.py +338 -0
  182. openocr/openrec/modeling/encoders/resnet31_rnn.py +123 -0
  183. openocr/openrec/modeling/encoders/svtrnet.py +574 -0
  184. openocr/openrec/modeling/encoders/svtrnet2dpos.py +616 -0
  185. openocr/openrec/modeling/encoders/svtrv2.py +470 -0
  186. openocr/openrec/modeling/encoders/svtrv2_lnconv.py +503 -0
  187. openocr/openrec/modeling/encoders/svtrv2_lnconv_two33.py +517 -0
  188. openocr/openrec/modeling/encoders/vit.py +120 -0
  189. openocr/openrec/modeling/transforms/__init__.py +15 -0
  190. openocr/openrec/modeling/transforms/aster_tps.py +262 -0
  191. openocr/openrec/modeling/transforms/moran.py +136 -0
  192. openocr/openrec/modeling/transforms/tps.py +246 -0
  193. openocr/openrec/optimizer/__init__.py +73 -0
  194. openocr/openrec/optimizer/lr.py +227 -0
  195. openocr/openrec/postprocess/__init__.py +72 -0
  196. openocr/openrec/postprocess/abinet_postprocess.py +37 -0
  197. openocr/openrec/postprocess/ar_postprocess.py +63 -0
  198. openocr/openrec/postprocess/ce_postprocess.py +43 -0
  199. openocr/openrec/postprocess/char_postprocess.py +108 -0
  200. openocr/openrec/postprocess/cppd_postprocess.py +42 -0
  201. openocr/openrec/postprocess/ctc_postprocess.py +119 -0
  202. openocr/openrec/postprocess/igtr_postprocess.py +100 -0
  203. openocr/openrec/postprocess/lister_postprocess.py +59 -0
  204. openocr/openrec/postprocess/mgp_postprocess.py +143 -0
  205. openocr/openrec/postprocess/nrtr_postprocess.py +75 -0
  206. openocr/openrec/postprocess/smtr_postprocess.py +73 -0
  207. openocr/openrec/postprocess/srn_postprocess.py +80 -0
  208. openocr/openrec/postprocess/visionlan_postprocess.py +81 -0
  209. openocr/openrec/preprocess/__init__.py +173 -0
  210. openocr/openrec/preprocess/abinet_aug.py +473 -0
  211. openocr/openrec/preprocess/abinet_label_encode.py +36 -0
  212. openocr/openrec/preprocess/ar_label_encode.py +36 -0
  213. openocr/openrec/preprocess/auto_augment.py +1012 -0
  214. openocr/openrec/preprocess/cam_label_encode.py +141 -0
  215. openocr/openrec/preprocess/ce_label_encode.py +116 -0
  216. openocr/openrec/preprocess/char_label_encode.py +36 -0
  217. openocr/openrec/preprocess/cppd_label_encode.py +173 -0
  218. openocr/openrec/preprocess/ctc_label_encode.py +124 -0
  219. openocr/openrec/preprocess/ep_label_encode.py +38 -0
  220. openocr/openrec/preprocess/igtr_label_encode.py +360 -0
  221. openocr/openrec/preprocess/mgp_label_encode.py +95 -0
  222. openocr/openrec/preprocess/parseq_aug.py +150 -0
  223. openocr/openrec/preprocess/rec_aug.py +211 -0
  224. openocr/openrec/preprocess/resize.py +534 -0
  225. openocr/openrec/preprocess/smtr_label_encode.py +125 -0
  226. openocr/openrec/preprocess/srn_label_encode.py +37 -0
  227. openocr/openrec/preprocess/visionlan_label_encode.py +67 -0
  228. openocr/tools/create_lmdb_dataset.py +118 -0
  229. openocr/tools/data/__init__.py +94 -0
  230. openocr/tools/data/collate_fn.py +100 -0
  231. openocr/tools/data/lmdb_dataset.py +142 -0
  232. openocr/tools/data/lmdb_dataset_test.py +166 -0
  233. openocr/tools/data/multi_scale_sampler.py +177 -0
  234. openocr/tools/data/ratio_dataset.py +217 -0
  235. openocr/tools/data/ratio_dataset_test.py +273 -0
  236. openocr/tools/data/ratio_dataset_tvresize.py +213 -0
  237. openocr/tools/data/ratio_dataset_tvresize_test.py +276 -0
  238. openocr/tools/data/ratio_sampler.py +190 -0
  239. openocr/tools/data/simple_dataset.py +263 -0
  240. openocr/tools/data/strlmdb_dataset.py +143 -0
  241. openocr/tools/engine/__init__.py +5 -0
  242. openocr/tools/engine/config.py +158 -0
  243. openocr/tools/engine/trainer.py +621 -0
  244. openocr/tools/eval_rec.py +41 -0
  245. openocr/tools/eval_rec_all_ch.py +184 -0
  246. openocr/tools/eval_rec_all_en.py +206 -0
  247. openocr/tools/eval_rec_all_long.py +119 -0
  248. openocr/tools/eval_rec_all_long_simple.py +122 -0
  249. openocr/tools/export_rec.py +118 -0
  250. openocr/tools/infer/onnx_engine.py +65 -0
  251. openocr/tools/infer/predict_rec.py +140 -0
  252. openocr/tools/infer/utility.py +234 -0
  253. openocr/tools/infer_det.py +449 -0
  254. openocr/tools/infer_e2e.py +462 -0
  255. openocr/tools/infer_e2e_parallel.py +184 -0
  256. openocr/tools/infer_rec.py +371 -0
  257. openocr/tools/train_rec.py +37 -0
  258. openocr/tools/utility.py +45 -0
  259. openocr/tools/utils/EN_symbol_dict.txt +94 -0
  260. openocr/tools/utils/__init__.py +0 -0
  261. openocr/tools/utils/ckpt.py +87 -0
  262. openocr/tools/utils/dict/ar_dict.txt +117 -0
  263. openocr/tools/utils/dict/arabic_dict.txt +161 -0
  264. openocr/tools/utils/dict/be_dict.txt +145 -0
  265. openocr/tools/utils/dict/bg_dict.txt +140 -0
  266. openocr/tools/utils/dict/chinese_cht_dict.txt +8421 -0
  267. openocr/tools/utils/dict/cyrillic_dict.txt +163 -0
  268. openocr/tools/utils/dict/devanagari_dict.txt +167 -0
  269. openocr/tools/utils/dict/en_dict.txt +63 -0
  270. openocr/tools/utils/dict/fa_dict.txt +136 -0
  271. openocr/tools/utils/dict/french_dict.txt +136 -0
  272. openocr/tools/utils/dict/german_dict.txt +143 -0
  273. openocr/tools/utils/dict/hi_dict.txt +162 -0
  274. openocr/tools/utils/dict/it_dict.txt +118 -0
  275. openocr/tools/utils/dict/japan_dict.txt +4399 -0
  276. openocr/tools/utils/dict/ka_dict.txt +153 -0
  277. openocr/tools/utils/dict/kie_dict/xfund_class_list.txt +4 -0
  278. openocr/tools/utils/dict/korean_dict.txt +3688 -0
  279. openocr/tools/utils/dict/latex_symbol_dict.txt +111 -0
  280. openocr/tools/utils/dict/latin_dict.txt +185 -0
  281. openocr/tools/utils/dict/layout_dict/layout_cdla_dict.txt +10 -0
  282. openocr/tools/utils/dict/layout_dict/layout_publaynet_dict.txt +5 -0
  283. openocr/tools/utils/dict/layout_dict/layout_table_dict.txt +1 -0
  284. openocr/tools/utils/dict/mr_dict.txt +153 -0
  285. openocr/tools/utils/dict/ne_dict.txt +153 -0
  286. openocr/tools/utils/dict/oc_dict.txt +96 -0
  287. openocr/tools/utils/dict/pu_dict.txt +130 -0
  288. openocr/tools/utils/dict/rs_dict.txt +91 -0
  289. openocr/tools/utils/dict/rsc_dict.txt +134 -0
  290. openocr/tools/utils/dict/ru_dict.txt +125 -0
  291. openocr/tools/utils/dict/spin_dict.txt +68 -0
  292. openocr/tools/utils/dict/ta_dict.txt +128 -0
  293. openocr/tools/utils/dict/table_dict.txt +277 -0
  294. openocr/tools/utils/dict/table_master_structure_dict.txt +39 -0
  295. openocr/tools/utils/dict/table_structure_dict.txt +28 -0
  296. openocr/tools/utils/dict/table_structure_dict_ch.txt +48 -0
  297. openocr/tools/utils/dict/te_dict.txt +151 -0
  298. openocr/tools/utils/dict/ug_dict.txt +114 -0
  299. openocr/tools/utils/dict/uk_dict.txt +142 -0
  300. openocr/tools/utils/dict/ur_dict.txt +137 -0
  301. openocr/tools/utils/dict/xi_dict.txt +110 -0
  302. openocr/tools/utils/dict90.txt +90 -0
  303. openocr/tools/utils/e2e_metric/Deteval.py +802 -0
  304. openocr/tools/utils/e2e_metric/polygon_fast.py +70 -0
  305. openocr/tools/utils/e2e_utils/extract_batchsize.py +86 -0
  306. openocr/tools/utils/e2e_utils/extract_textpoint_fast.py +479 -0
  307. openocr/tools/utils/e2e_utils/extract_textpoint_slow.py +582 -0
  308. openocr/tools/utils/e2e_utils/pgnet_pp_utils.py +159 -0
  309. openocr/tools/utils/e2e_utils/visual.py +152 -0
  310. openocr/tools/utils/en_dict.txt +95 -0
  311. openocr/tools/utils/gen_label.py +68 -0
  312. openocr/tools/utils/ic15_dict.txt +36 -0
  313. openocr/tools/utils/logging.py +56 -0
  314. openocr/tools/utils/poly_nms.py +132 -0
  315. openocr/tools/utils/ppocr_keys_v1.txt +6623 -0
  316. openocr/tools/utils/stats.py +58 -0
  317. openocr/tools/utils/utility.py +165 -0
  318. openocr/tools/utils/visual.py +117 -0
  319. openocr_python-0.0.2.dist-info/LICENCE +201 -0
  320. openocr_python-0.0.2.dist-info/METADATA +98 -0
  321. openocr_python-0.0.2.dist-info/RECORD +323 -0
  322. openocr_python-0.0.2.dist-info/WHEEL +5 -0
  323. openocr_python-0.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,488 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from openrec.modeling.common import Activation
6
+
7
# Per-stage block specifications for PP-LCNetV3.
# Each entry is [kernel_size, in_channels, out_channels, stride, use_se];
# channel counts are later multiplied by the backbone's ``scale``.
# ``stride`` may be an int or an (h, w) tuple.

# Detection variant: plain stride-2 downsampling at the start of stages 3-6.
NET_CONFIG_det = {
    'blocks2': [[3, 16, 32, 1, False]],
    'blocks3': [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
    'blocks4': [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
    'blocks5': [[3, 128, 256, 2, False]]
    + [[5, 256, 256, 1, False] for _ in range(4)],
    'blocks6': [
        [5, 256, 512, 2, True],
        [5, 512, 512, 1, True],
        [5, 512, 512, 1, False],
        [5, 512, 512, 1, False],
    ],
}

# Recognition variant: anisotropic (h, w) strides that shrink height faster
# than width, keeping a long horizontal feature sequence.
NET_CONFIG_rec = {
    'blocks2': [[3, 16, 32, 1, False]],
    'blocks3': [[3, 32, 64, 1, False], [3, 64, 64, 1, False]],
    'blocks4': [[3, 64, 128, (2, 1), False], [3, 128, 128, 1, False]],
    'blocks5': [[3, 128, 256, (1, 2), False]]
    + [[5, 256, 256, 1, False] for _ in range(4)],
    'blocks6': [
        [5, 256, 512, (2, 1), True],
        [5, 512, 512, 1, True],
        [5, 512, 512, (2, 1), False],
        [5, 512, 512, 1, False],
    ],
}
48
+
49
+
50
def make_divisible(v, divisor=16, min_value=None):
    """Round ``v`` to the nearest multiple of ``divisor``.

    The result is never below ``min_value`` (which defaults to ``divisor``).
    If rounding would drop more than 10% below ``v``, one extra ``divisor``
    is added.

    Args:
        v: value to round (e.g. a scaled channel count).
        divisor: the multiple to round to.
        min_value: lower bound for the result; ``None`` means ``divisor``.

    Returns:
        The rounded integer.
    """
    floor = divisor if min_value is None else min_value
    nearest = int(v + divisor / 2) // divisor * divisor
    candidate = max(floor, nearest)
    # Guard against rounding down by more than 10%.
    return candidate + divisor if candidate < 0.9 * v else candidate
57
+
58
+
59
class LearnableAffineBlock(nn.Module):
    """Element-wise ``scale * x + bias`` with a learnable scalar scale and bias.

    Args:
        scale_value: initial value of the scalar ``scale`` parameter.
        bias_value: initial value of the scalar ``bias`` parameter.
        lr_mult: accepted for config compatibility; unused in this module.
        lab_lr: accepted for config compatibility; unused in this module.
    """

    def __init__(self,
                 scale_value=1.0,
                 bias_value=0.0,
                 lr_mult=1.0,
                 lab_lr=0.1):
        super().__init__()
        # Attribute names are part of the state_dict layout; keep them stable.
        self.scale = nn.Parameter(torch.Tensor([scale_value]))
        self.bias = nn.Parameter(torch.Tensor([bias_value]))

    def forward(self, x):
        scaled = self.scale * x
        return scaled + self.bias
72
+
73
+
74
class ConvBNLayer(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` (no activation).

    Padding is ``(kernel_size - 1) // 2``, so odd kernels preserve spatial
    size at stride 1.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        kernel_size: integer kernel size.
        stride: conv stride (int or tuple, passed straight to ``nn.Conv2d``).
        groups: conv groups.
        lr_mult: accepted for config compatibility; unused in this module.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups=1,
                 lr_mult=1.0):
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size,
                              stride=stride,
                              padding=same_padding,
                              groups=groups,
                              bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))
100
+
101
+
102
class Act(nn.Module):
    """Apply a named activation, then a ``LearnableAffineBlock``.

    Args:
        act: activation name; only ``'hard_swish'`` and ``'relu'`` are accepted.
        lr_mult: forwarded to ``LearnableAffineBlock``.
        lab_lr: forwarded to ``LearnableAffineBlock``.
    """

    def __init__(self, act='hard_swish', lr_mult=1.0, lab_lr=0.1):
        super().__init__()
        assert act in ('hard_swish', 'relu')
        self.act = Activation(act)
        self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)

    def forward(self, x):
        activated = self.act(x)
        return self.lab(activated)
112
+
113
+
114
class LearnableRepLayer(nn.Module):
    """Structurally re-parameterizable conv layer used by PP-LCNetV3.

    Before ``rep()`` is called, the output is the sum of these branches, all
    applied to the same input:

    * ``identity``: a ``BatchNorm2d``, present only when
      ``out_channels == in_channels`` and ``stride == 1``;
    * ``conv_1x1``: a 1x1 ``ConvBNLayer``, present only when
      ``kernel_size > 1``;
    * ``conv_kxk``: ``num_conv_branches`` kxk ``ConvBNLayer`` instances.

    The sum goes through a ``LearnableAffineBlock`` (``lab``). Unless
    ``stride != 2`` is false, it then also goes through ``Act``. A tuple
    stride such as ``(2, 1)`` compares unequal to ``2``, so it is activated.

    ``rep()`` folds all branches into one biased ``reparam_conv``, using the
    BatchNorm running statistics. ``forward`` then uses that conv instead.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        kernel_size: integer kernel size of the kxk branches.
        stride: conv stride (int or tuple).
        groups: conv groups (``in_channels`` for depthwise use).
        num_conv_branches: number of parallel kxk branches.
        lr_mult: forwarded to sub-layers.
        lab_lr: forwarded to ``LearnableAffineBlock`` / ``Act``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        groups=1,
        num_conv_branches=1,
        lr_mult=1.0,
        lab_lr=0.1,
    ):
        super().__init__()
        self.is_repped = False
        self.groups = groups
        self.stride = stride
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches
        self.padding = (kernel_size - 1) // 2

        self.identity = (nn.BatchNorm2d(in_channels) if
                         out_channels == in_channels and stride == 1 else None)

        self.conv_kxk = nn.ModuleList([
            ConvBNLayer(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                groups=groups,
                lr_mult=lr_mult,
            ) for _ in range(self.num_conv_branches)
        ])

        self.conv_1x1 = (ConvBNLayer(in_channels,
                                     out_channels,
                                     1,
                                     stride,
                                     groups=groups,
                                     lr_mult=lr_mult)
                         if kernel_size > 1 else None)

        self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
        self.act = Act(lr_mult=lr_mult, lab_lr=lab_lr)

    def forward(self, x):
        # After rep() (e.g. for export) a single fused conv replaces all branches.
        if self.is_repped:
            out = self.lab(self.reparam_conv(x))
            if self.stride != 2:
                out = self.act(out)
            return out

        out = 0
        if self.identity is not None:
            out += self.identity(x)

        if self.conv_1x1 is not None:
            out += self.conv_1x1(x)

        for conv in self.conv_kxk:
            out += conv(x)

        out = self.lab(out)
        if self.stride != 2:
            out = self.act(out)
        return out

    def rep(self):
        """Fuse all branches into ``self.reparam_conv`` (idempotent).

        This reads the BatchNorm running statistics, so call it once those
        statistics are final (after training).
        """
        if self.is_repped:
            return
        # Fusion is pure arithmetic on existing weights; do not record autograd history.
        with torch.no_grad():
            kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            groups=self.groups,
        )
        # Assigning .data also adopts kernel's device/dtype.
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias
        self.is_repped = True

    def _pad_kernel_1x1_to_kxk(self, kernel1x1, pad):
        """Zero-pad a 1x1 kernel to kxk. A non-tensor (absent branch) gives 0."""
        if not isinstance(kernel1x1, torch.Tensor):
            return 0
        return nn.functional.pad(kernel1x1, [pad, pad, pad, pad])

    def _get_kernel_bias(self):
        """Return the fused (kernel, bias) equivalent to the sum of all branches."""
        kernel_conv_1x1, bias_conv_1x1 = self._fuse_bn_tensor(self.conv_1x1)
        kernel_conv_1x1 = self._pad_kernel_1x1_to_kxk(kernel_conv_1x1,
                                                      self.kernel_size // 2)

        kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)

        kernel_conv_kxk = 0
        bias_conv_kxk = 0
        for conv in self.conv_kxk:
            kernel, bias = self._fuse_bn_tensor(conv)
            kernel_conv_kxk += kernel
            bias_conv_kxk += bias

        kernel_reparam = kernel_conv_kxk + kernel_conv_1x1 + kernel_identity
        bias_reparam = bias_conv_kxk + bias_conv_1x1 + bias_identity
        return kernel_reparam, bias_reparam

    def _fuse_bn_tensor(self, branch):
        """Fold a branch's BatchNorm into an equivalent (kernel, bias) pair.

        Args:
            branch: ``None``, a ``ConvBNLayer``, or the identity ``BatchNorm2d``.

        Returns:
            ``(0, 0)`` for ``None``, otherwise a ``(kernel, bias)`` tensor pair.
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, ConvBNLayer):
            kernel = branch.conv.weight
            bn = branch.bn
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                # Build a kxk kernel that maps each channel to itself (respecting groups).
                input_dim = self.in_channels // self.groups
                kernel_value = torch.zeros(
                    (self.in_channels, input_dim, self.kernel_size,
                     self.kernel_size),
                    dtype=branch.weight.dtype,
                )
                centre = self.kernel_size // 2
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, centre, centre] = 1
                self.id_tensor = kernel_value
            # Match the BN parameters' device/dtype; the cache may be stale after .to().
            kernel = self.id_tensor.to(device=branch.weight.device,
                                       dtype=branch.weight.dtype)
            bn = branch
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape((-1, 1, 1, 1))
        return kernel * t, bn.bias - bn.running_mean * bn.weight / std
257
+
258
+
259
class SELayer(nn.Module):
    """Squeeze-and-excitation channel gating.

    Computes a per-channel gate as avg-pool -> 1x1 conv -> ReLU -> 1x1 conv
    -> ``Activation('hard_sigmoid')``, then multiplies it into the input.

    Args:
        channel: number of input/output channels.
        reduction: bottleneck reduction factor for the hidden 1x1 conv.
        lr_mult: accepted for config compatibility; unused in this module.
    """

    def __init__(self, channel, reduction=4, lr_mult=1.0):
        super().__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(channel, squeezed, kernel_size=1, stride=1,
                               padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(squeezed, channel, kernel_size=1, stride=1,
                               padding=0)
        self.hardsigmoid = Activation('hard_sigmoid')

    def forward(self, x):
        gate = self.relu(self.conv1(self.avg_pool(x)))
        gate = self.hardsigmoid(self.conv2(gate))
        return gate * x
290
+
291
+
292
class LCNetV3Block(nn.Module):
    """Depthwise rep-conv, optional SE gating, then pointwise rep-conv.

    Args:
        in_channels: input channel count (also the depthwise width).
        out_channels: output channel count of the pointwise conv.
        stride: stride of the depthwise conv (int or tuple).
        dw_size: depthwise kernel size.
        use_se: insert an ``SELayer`` between the two convs.
        conv_kxk_num: number of parallel kxk branches per ``LearnableRepLayer``.
        lr_mult: forwarded to sub-layers.
        lab_lr: forwarded to the ``LearnableRepLayer`` sub-layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        dw_size,
        use_se=False,
        conv_kxk_num=4,
        lr_mult=1.0,
        lab_lr=0.1,
    ):
        super().__init__()
        self.use_se = use_se
        shared = dict(num_conv_branches=conv_kxk_num,
                      lr_mult=lr_mult,
                      lab_lr=lab_lr)
        self.dw_conv = LearnableRepLayer(in_channels=in_channels,
                                         out_channels=in_channels,
                                         kernel_size=dw_size,
                                         stride=stride,
                                         groups=in_channels,
                                         **shared)
        if use_se:
            self.se = SELayer(in_channels, lr_mult=lr_mult)
        self.pw_conv = LearnableRepLayer(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=1,
                                         stride=1,
                                         **shared)

    def forward(self, x):
        out = self.dw_conv(x)
        if self.use_se:
            out = self.se(out)
        return self.pw_conv(out)
335
+
336
+
337
class PPLCNetV3(nn.Module):
    """PP-LCNetV3 backbone for text detection or recognition.

    Structure: a stride-2 ``ConvBNLayer`` stem (``conv1``), then five stages
    ``blocks2``..``blocks6`` of ``LCNetV3Block``. Stage specs come from
    ``NET_CONFIG_det`` when ``det`` is true, otherwise from
    ``NET_CONFIG_rec``.

    Args:
        scale: width multiplier applied to every channel count via
            ``make_divisible``.
        conv_kxk_num: number of parallel kxk branches per
            ``LearnableRepLayer``.
        lr_mult_list: six per-layer learning-rate multipliers (stem, then
            blocks2..blocks6). ``None`` (the default) means all ``1.0``.
            They are forwarded to sub-layers; the layers defined in this file
            accept but do not use them.
        lab_lr: forwarded to the blocks.
        det: select detection mode (see ``forward``).
        **kwargs: ignored.

    Attributes:
        out_channels: int channel count of the final feature map in
            recognition mode; list of four ints in detection mode.
    """

    def __init__(self,
                 scale=1.0,
                 conv_kxk_num=4,
                 lr_mult_list=None,
                 lab_lr=0.1,
                 det=False,
                 **kwargs):
        super().__init__()
        # Avoid a shared mutable default: each instance gets its own list.
        if lr_mult_list is None:
            lr_mult_list = [1.0] * 6
        self.scale = scale
        self.lr_mult_list = lr_mult_list
        self.det = det

        self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec

        assert isinstance(
            self.lr_mult_list,
            (list, tuple
             )), 'lr_mult_list should be in (list, tuple) but got {}'.format(
                 type(self.lr_mult_list))
        assert len(self.lr_mult_list
                   ) == 6, 'lr_mult_list length should be 6 but got {}'.format(
                       len(self.lr_mult_list))

        self.conv1 = ConvBNLayer(
            in_channels=3,
            out_channels=make_divisible(16 * scale),
            kernel_size=3,
            stride=2,
            lr_mult=self.lr_mult_list[0],
        )

        # Attribute names blocks2..blocks6 are part of the state_dict layout.
        self.blocks2 = self._make_stage('blocks2', self.lr_mult_list[1],
                                        conv_kxk_num, lab_lr)
        self.blocks3 = self._make_stage('blocks3', self.lr_mult_list[2],
                                        conv_kxk_num, lab_lr)
        self.blocks4 = self._make_stage('blocks4', self.lr_mult_list[3],
                                        conv_kxk_num, lab_lr)
        self.blocks5 = self._make_stage('blocks5', self.lr_mult_list[4],
                                        conv_kxk_num, lab_lr)
        self.blocks6 = self._make_stage('blocks6', self.lr_mult_list[5],
                                        conv_kxk_num, lab_lr)
        # Final stage width (512 in both shipped configs), scaled.
        self.out_channels = make_divisible(
            self.net_config['blocks6'][-1][2] * scale)

        if self.det:
            mv_c = [16, 24, 56, 480]
            stage_channels = [
                make_divisible(self.net_config[name][-1][2] * scale)
                for name in ('blocks3', 'blocks4', 'blocks5', 'blocks6')
            ]
            # 1x1 projections from each stage's output to the neck's widths.
            self.layer_list = nn.ModuleList([
                nn.Conv2d(in_c, int(mv * scale), 1, 1, 0)
                for in_c, mv in zip(stage_channels, mv_c)
            ])
            self.out_channels = [int(mv * scale) for mv in mv_c]

    def _make_stage(self, stage_name, lr_mult, conv_kxk_num, lab_lr):
        """Build one ``nn.Sequential`` of ``LCNetV3Block`` from ``self.net_config[stage_name]``."""
        return nn.Sequential(*[
            LCNetV3Block(
                in_channels=make_divisible(in_c * self.scale),
                out_channels=make_divisible(out_c * self.scale),
                dw_size=k,
                stride=s,
                use_se=se,
                conv_kxk_num=conv_kxk_num,
                lr_mult=lr_mult,
                lab_lr=lab_lr,
            ) for k, in_c, out_c, s, se in self.net_config[stage_name]
        ])

    def forward(self, x):
        """Run the backbone.

        Detection mode returns a list of four tensors: the outputs of
        blocks3..blocks6, each passed through its 1x1 projection in
        ``layer_list``.

        Recognition mode returns the blocks6 output, pooled. In training this
        uses ``adaptive_avg_pool2d`` to (1, 40). In eval it uses
        ``avg_pool2d`` with kernel (3, 2). For example, a 48x320 input gives a
        3x80 map here, which both reduce to 1x40.
        """
        x = self.conv1(x)
        x = self.blocks2(x)

        if self.det:
            features = []
            stages = (self.blocks3, self.blocks4, self.blocks5, self.blocks6)
            for stage, proj in zip(stages, self.layer_list):
                x = stage(x)
                features.append(proj(x))
            return features

        x = self.blocks3(x)
        x = self.blocks4(x)
        x = self.blocks5(x)
        x = self.blocks6(x)

        if self.training:
            return F.adaptive_avg_pool2d(x, [1, 40])
        return F.avg_pool2d(x, [3, 2])
@@ -0,0 +1,132 @@
1
+ import torch.nn as nn
2
+
3
+ from .det_mobilenet_v3 import ConvBNLayer, ResidualUnit, make_divisible
4
+
5
+
6
class MobileNetV3(nn.Module):
    """MobileNetV3 backbone for text recognition.

    Layout: a 3x3 stride-2 ``ConvBNLayer`` stem, a stack of
    ``ResidualUnit`` blocks taken from the ``'large'`` or ``'small'``
    table, a 1x1 ``ConvBNLayer``, and a 2x2/stride-2 max-pool.

    Most down-sampling entries in the block tables use a ``(s, 1)`` tuple
    stride (presumably (height, width), i.e. height-only down-sampling --
    confirm in ``ResidualUnit``).

    Args:
        in_channels: Number of channels of the input tensor.
        model_name: Which block table to use, ``'large'`` or ``'small'``.
        scale: Channel width multiplier; must be one of
            ``_SUPPORTED_SCALES``.
        large_stride: Four stride values plugged into the ``'large'``
            table; defaults to ``[1, 2, 2, 2]``. Must be a ``list``.
        small_stride: Four stride values plugged into the ``'small'``
            table; defaults to ``[2, 2, 2, 2]``. Must be a ``list``.
        **kwargs: Accepted and ignored.

    Raises:
        AssertionError: If either stride argument is not a 4-element
            ``list`` (both are checked regardless of ``model_name``), or if
            ``scale`` is not supported.
        NotImplementedError: If ``model_name`` is not ``'large'`` or
            ``'small'``.
    """

    # Width multipliers accepted for ``scale``.
    _SUPPORTED_SCALES = [0.35, 0.5, 0.75, 1.0, 1.25]

    def __init__(self,
                 in_channels=3,
                 model_name='small',
                 scale=0.5,
                 large_stride=None,
                 small_stride=None,
                 **kwargs):
        super(MobileNetV3, self).__init__()
        if small_stride is None:
            small_stride = [2, 2, 2, 2]
        if large_stride is None:
            large_stride = [1, 2, 2, 2]
        self._validate_strides(large_stride, small_stride)

        cfg, cls_ch_squeeze = self._build_cfg(model_name, large_stride,
                                              small_stride)

        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if scale not in self._SUPPORTED_SCALES:
            raise AssertionError(
                'supported scales are {} but input scale is {}'.format(
                    self._SUPPORTED_SCALES, scale))

        inplanes = make_divisible(16 * scale)
        # conv1: stride-2 stem.
        self.conv1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=inplanes,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=1,
            if_act=True,
            act='hard_swish',
        )

        block_list = []
        for i, (k, exp, c, se, nl, s) in enumerate(cfg):
            block_out = make_divisible(scale * c)
            block_list.append(
                ResidualUnit(
                    in_channels=inplanes,
                    mid_channels=make_divisible(scale * exp),
                    out_channels=block_out,
                    kernel_size=k,
                    stride=s,
                    use_se=se,
                    act=nl,
                    # Names continue after the stem: conv2, conv3, ...
                    name='conv' + str(i + 2),
                ))
            # Each block feeds the next, so its output width becomes the
            # next block's input width.
            inplanes = block_out
        self.blocks = nn.Sequential(*block_list)

        out_channels = make_divisible(scale * cls_ch_squeeze)
        self.conv2 = ConvBNLayer(
            in_channels=inplanes,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            if_act=True,
            act='hard_swish',
        )

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.out_channels = out_channels

    @staticmethod
    def _validate_strides(large_stride, small_stride):
        """Check that both stride arguments are 4-element lists.

        Raises ``AssertionError`` explicitly (rather than using ``assert``)
        so the checks still run under ``python -O``. All type checks run
        before any length check.
        """
        named = (('large_stride', large_stride), ('small_stride',
                                                  small_stride))
        for name, stride in named:
            if not isinstance(stride, list):
                raise AssertionError(
                    '{} type must be list but got {}'.format(
                        name, type(stride)))
        for name, stride in named:
            if len(stride) != 4:
                raise AssertionError(
                    '{} length must be 4 but got {}'.format(
                        name, len(stride)))

    @staticmethod
    def _build_cfg(model_name, large_stride, small_stride):
        """Return ``(cfg, cls_ch_squeeze)`` for ``model_name``.

        Each ``cfg`` row is ``[k, exp, c, se, nl, s]``: kernel size,
        expanded (mid) channels, output channels, use-SE flag, activation
        name, and stride, before ``scale`` is applied.

        Raises:
            NotImplementedError: If ``model_name`` is unknown.
        """
        if model_name == 'large':
            cfg = [
                # k, exp, c, se, nl, s,
                [3, 16, 16, False, 'relu', large_stride[0]],
                [3, 64, 24, False, 'relu', (large_stride[1], 1)],
                [3, 72, 24, False, 'relu', 1],
                [5, 72, 40, True, 'relu', (large_stride[2], 1)],
                [5, 120, 40, True, 'relu', 1],
                [5, 120, 40, True, 'relu', 1],
                [3, 240, 80, False, 'hard_swish', 1],
                [3, 200, 80, False, 'hard_swish', 1],
                [3, 184, 80, False, 'hard_swish', 1],
                [3, 184, 80, False, 'hard_swish', 1],
                [3, 480, 112, True, 'hard_swish', 1],
                [3, 672, 112, True, 'hard_swish', 1],
                [5, 672, 160, True, 'hard_swish', (large_stride[3], 1)],
                [5, 960, 160, True, 'hard_swish', 1],
                [5, 960, 160, True, 'hard_swish', 1],
            ]
            return cfg, 960
        if model_name == 'small':
            cfg = [
                # k, exp, c, se, nl, s,
                [3, 16, 16, True, 'relu', (small_stride[0], 1)],
                [3, 72, 24, False, 'relu', (small_stride[1], 1)],
                [3, 88, 24, False, 'relu', 1],
                [5, 96, 40, True, 'hard_swish', (small_stride[2], 1)],
                [5, 240, 40, True, 'hard_swish', 1],
                [5, 240, 40, True, 'hard_swish', 1],
                [5, 120, 48, True, 'hard_swish', 1],
                [5, 144, 48, True, 'hard_swish', 1],
                [5, 288, 96, True, 'hard_swish', (small_stride[3], 1)],
                [5, 576, 96, True, 'hard_swish', 1],
                [5, 576, 96, True, 'hard_swish', 1],
            ]
            return cfg, 576
        raise NotImplementedError('mode[' + model_name +
                                  '_model] is not implemented!')

    def forward(self, x):
        """Apply stem, residual blocks, 1x1 conv and 2x2 max-pool in order."""
        x = self.conv1(x)
        x = self.blocks(x)
        x = self.conv2(x)
        x = self.pool(x)
        return x