pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (306) hide show
  1. pyxllib/algo/geo.py +12 -0
  2. pyxllib/algo/intervals.py +1 -1
  3. pyxllib/algo/matcher.py +78 -0
  4. pyxllib/algo/pupil.py +187 -19
  5. pyxllib/algo/specialist.py +2 -1
  6. pyxllib/algo/stat.py +38 -2
  7. {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
  8. pyxllib/autogui/activewin.py +246 -0
  9. pyxllib/autogui/all.py +9 -0
  10. pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
  11. pyxllib/autogui/uiautolib.py +362 -0
  12. pyxllib/autogui/wechat.py +827 -0
  13. pyxllib/autogui/wechat_msg.py +421 -0
  14. pyxllib/autogui/wxautolib.py +84 -0
  15. pyxllib/cv/slidercaptcha.py +137 -0
  16. pyxllib/data/echarts.py +123 -12
  17. pyxllib/data/jsonlib.py +89 -0
  18. pyxllib/data/pglib.py +514 -30
  19. pyxllib/data/sqlite.py +231 -4
  20. pyxllib/ext/JLineViewer.py +14 -1
  21. pyxllib/ext/drissionlib.py +277 -0
  22. pyxllib/ext/kq5034lib.py +0 -1594
  23. pyxllib/ext/robustprocfile.py +497 -0
  24. pyxllib/ext/unixlib.py +6 -5
  25. pyxllib/ext/utools.py +108 -95
  26. pyxllib/ext/webhook.py +32 -14
  27. pyxllib/ext/wjxlib.py +88 -0
  28. pyxllib/ext/wpsapi.py +124 -0
  29. pyxllib/ext/xlwork.py +9 -0
  30. pyxllib/ext/yuquelib.py +1003 -71
  31. pyxllib/file/docxlib.py +1 -1
  32. pyxllib/file/libreoffice.py +165 -0
  33. pyxllib/file/movielib.py +9 -0
  34. pyxllib/file/packlib/__init__.py +112 -75
  35. pyxllib/file/pdflib.py +1 -1
  36. pyxllib/file/pupil.py +1 -1
  37. pyxllib/file/specialist/dirlib.py +1 -1
  38. pyxllib/file/specialist/download.py +10 -3
  39. pyxllib/file/specialist/filelib.py +266 -55
  40. pyxllib/file/xlsxlib.py +205 -50
  41. pyxllib/file/xlsyncfile.py +341 -0
  42. pyxllib/prog/cachetools.py +64 -0
  43. pyxllib/prog/filelock.py +42 -0
  44. pyxllib/prog/multiprogs.py +940 -0
  45. pyxllib/prog/newbie.py +9 -2
  46. pyxllib/prog/pupil.py +129 -60
  47. pyxllib/prog/specialist/__init__.py +176 -2
  48. pyxllib/prog/specialist/bc.py +5 -2
  49. pyxllib/prog/specialist/browser.py +11 -2
  50. pyxllib/prog/specialist/datetime.py +68 -0
  51. pyxllib/prog/specialist/tictoc.py +12 -13
  52. pyxllib/prog/specialist/xllog.py +5 -5
  53. pyxllib/prog/xlosenv.py +7 -0
  54. pyxllib/text/airscript.js +744 -0
  55. pyxllib/text/charclasslib.py +17 -5
  56. pyxllib/text/jiebalib.py +6 -3
  57. pyxllib/text/jinjalib.py +32 -0
  58. pyxllib/text/jsa_ai_prompt.md +271 -0
  59. pyxllib/text/jscode.py +159 -4
  60. pyxllib/text/nestenv.py +1 -1
  61. pyxllib/text/newbie.py +12 -0
  62. pyxllib/text/pupil/common.py +26 -0
  63. pyxllib/text/specialist/ptag.py +2 -2
  64. pyxllib/text/templates/echart_base.html +11 -0
  65. pyxllib/text/templates/highlight_code.html +17 -0
  66. pyxllib/text/templates/latex_editor.html +103 -0
  67. pyxllib/text/xmllib.py +76 -14
  68. pyxllib/xl.py +2 -1
  69. pyxllib-0.3.197.dist-info/METADATA +48 -0
  70. pyxllib-0.3.197.dist-info/RECORD +126 -0
  71. {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
  72. pyxllib/ext/autogui/__init__.py +0 -8
  73. pyxllib-0.3.96.dist-info/METADATA +0 -51
  74. pyxllib-0.3.96.dist-info/RECORD +0 -333
  75. pyxllib-0.3.96.dist-info/top_level.txt +0 -2
  76. pyxlpr/ai/__init__.py +0 -5
  77. pyxlpr/ai/clientlib.py +0 -1281
  78. pyxlpr/ai/specialist.py +0 -286
  79. pyxlpr/ai/torch_app.py +0 -172
  80. pyxlpr/ai/xlpaddle.py +0 -655
  81. pyxlpr/ai/xltorch.py +0 -705
  82. pyxlpr/data/__init__.py +0 -11
  83. pyxlpr/data/coco.py +0 -1325
  84. pyxlpr/data/datacls.py +0 -365
  85. pyxlpr/data/datasets.py +0 -200
  86. pyxlpr/data/gptlib.py +0 -1291
  87. pyxlpr/data/icdar/__init__.py +0 -96
  88. pyxlpr/data/icdar/deteval.py +0 -377
  89. pyxlpr/data/icdar/icdar2013.py +0 -341
  90. pyxlpr/data/icdar/iou.py +0 -340
  91. pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
  92. pyxlpr/data/imtextline.py +0 -473
  93. pyxlpr/data/labelme.py +0 -866
  94. pyxlpr/data/removeline.py +0 -179
  95. pyxlpr/data/specialist.py +0 -57
  96. pyxlpr/eval/__init__.py +0 -85
  97. pyxlpr/paddleocr.py +0 -776
  98. pyxlpr/ppocr/__init__.py +0 -15
  99. pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
  100. pyxlpr/ppocr/data/__init__.py +0 -135
  101. pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
  102. pyxlpr/ppocr/data/imaug/__init__.py +0 -67
  103. pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
  104. pyxlpr/ppocr/data/imaug/east_process.py +0 -437
  105. pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
  106. pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
  107. pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
  108. pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
  109. pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
  110. pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
  111. pyxlpr/ppocr/data/imaug/operators.py +0 -433
  112. pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
  113. pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
  114. pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
  115. pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
  116. pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
  117. pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
  118. pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
  119. pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
  120. pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
  121. pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
  122. pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
  123. pyxlpr/ppocr/data/simple_dataset.py +0 -372
  124. pyxlpr/ppocr/losses/__init__.py +0 -61
  125. pyxlpr/ppocr/losses/ace_loss.py +0 -52
  126. pyxlpr/ppocr/losses/basic_loss.py +0 -135
  127. pyxlpr/ppocr/losses/center_loss.py +0 -88
  128. pyxlpr/ppocr/losses/cls_loss.py +0 -30
  129. pyxlpr/ppocr/losses/combined_loss.py +0 -67
  130. pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
  131. pyxlpr/ppocr/losses/det_db_loss.py +0 -80
  132. pyxlpr/ppocr/losses/det_east_loss.py +0 -63
  133. pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
  134. pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
  135. pyxlpr/ppocr/losses/distillation_loss.py +0 -272
  136. pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
  137. pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
  138. pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
  139. pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
  140. pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
  141. pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
  142. pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
  143. pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
  144. pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
  145. pyxlpr/ppocr/losses/table_att_loss.py +0 -109
  146. pyxlpr/ppocr/metrics/__init__.py +0 -44
  147. pyxlpr/ppocr/metrics/cls_metric.py +0 -45
  148. pyxlpr/ppocr/metrics/det_metric.py +0 -82
  149. pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
  150. pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
  151. pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
  152. pyxlpr/ppocr/metrics/kie_metric.py +0 -70
  153. pyxlpr/ppocr/metrics/rec_metric.py +0 -75
  154. pyxlpr/ppocr/metrics/table_metric.py +0 -50
  155. pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
  156. pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
  157. pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
  158. pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
  159. pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
  160. pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
  161. pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
  162. pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
  163. pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
  164. pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
  165. pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
  166. pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
  167. pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
  168. pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
  169. pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
  170. pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
  171. pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
  172. pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
  173. pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
  174. pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
  175. pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
  176. pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
  177. pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
  178. pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
  179. pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
  180. pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
  181. pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
  182. pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
  183. pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
  184. pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
  185. pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
  186. pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
  187. pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
  188. pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
  189. pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
  190. pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
  191. pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
  192. pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
  193. pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
  194. pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
  195. pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
  196. pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
  197. pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
  198. pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
  199. pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
  200. pyxlpr/ppocr/optimizer/__init__.py +0 -61
  201. pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
  202. pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
  203. pyxlpr/ppocr/optimizer/optimizer.py +0 -160
  204. pyxlpr/ppocr/optimizer/regularizer.py +0 -52
  205. pyxlpr/ppocr/postprocess/__init__.py +0 -55
  206. pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
  207. pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
  208. pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
  209. pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
  210. pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
  211. pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
  212. pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
  213. pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
  214. pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
  215. pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
  216. pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
  217. pyxlpr/ppocr/tools/__init__.py +0 -14
  218. pyxlpr/ppocr/tools/eval.py +0 -83
  219. pyxlpr/ppocr/tools/export_center.py +0 -77
  220. pyxlpr/ppocr/tools/export_model.py +0 -129
  221. pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
  222. pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
  223. pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
  224. pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
  225. pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
  226. pyxlpr/ppocr/tools/infer/utility.py +0 -629
  227. pyxlpr/ppocr/tools/infer_cls.py +0 -83
  228. pyxlpr/ppocr/tools/infer_det.py +0 -134
  229. pyxlpr/ppocr/tools/infer_e2e.py +0 -122
  230. pyxlpr/ppocr/tools/infer_kie.py +0 -153
  231. pyxlpr/ppocr/tools/infer_rec.py +0 -146
  232. pyxlpr/ppocr/tools/infer_table.py +0 -107
  233. pyxlpr/ppocr/tools/program.py +0 -596
  234. pyxlpr/ppocr/tools/test_hubserving.py +0 -117
  235. pyxlpr/ppocr/tools/train.py +0 -163
  236. pyxlpr/ppocr/tools/xlprog.py +0 -748
  237. pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
  238. pyxlpr/ppocr/utils/__init__.py +0 -24
  239. pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
  240. pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
  241. pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
  242. pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
  243. pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
  244. pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
  245. pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
  246. pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
  247. pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
  248. pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
  249. pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
  250. pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
  251. pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
  252. pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
  253. pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
  254. pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
  255. pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
  256. pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
  257. pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
  258. pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
  259. pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
  260. pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
  261. pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
  262. pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
  263. pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
  264. pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
  265. pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
  266. pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
  267. pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
  268. pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
  269. pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
  270. pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
  271. pyxlpr/ppocr/utils/dict90.txt +0 -90
  272. pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
  273. pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
  274. pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
  275. pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
  276. pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
  277. pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
  278. pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
  279. pyxlpr/ppocr/utils/en_dict.txt +0 -95
  280. pyxlpr/ppocr/utils/gen_label.py +0 -81
  281. pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
  282. pyxlpr/ppocr/utils/iou.py +0 -54
  283. pyxlpr/ppocr/utils/logging.py +0 -69
  284. pyxlpr/ppocr/utils/network.py +0 -84
  285. pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
  286. pyxlpr/ppocr/utils/profiler.py +0 -110
  287. pyxlpr/ppocr/utils/save_load.py +0 -150
  288. pyxlpr/ppocr/utils/stats.py +0 -72
  289. pyxlpr/ppocr/utils/utility.py +0 -80
  290. pyxlpr/ppstructure/__init__.py +0 -13
  291. pyxlpr/ppstructure/predict_system.py +0 -187
  292. pyxlpr/ppstructure/table/__init__.py +0 -13
  293. pyxlpr/ppstructure/table/eval_table.py +0 -72
  294. pyxlpr/ppstructure/table/matcher.py +0 -192
  295. pyxlpr/ppstructure/table/predict_structure.py +0 -136
  296. pyxlpr/ppstructure/table/predict_table.py +0 -221
  297. pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
  298. pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
  299. pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
  300. pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
  301. pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
  302. pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
  303. pyxlpr/ppstructure/utility.py +0 -71
  304. pyxlpr/xlai.py +0 -10
  305. /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
  306. {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
pyxlpr/ai/xltorch.py DELETED
@@ -1,705 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # @Author : 陈坤泽
4
- # @Email : 877362867@qq.com
5
- # @Date : 2021/06/06 23:10
6
-
7
-
8
- from pyxllib.xlcv import *
9
-
10
- import torch
11
- from torch import nn, optim
12
- import torch.utils.data
13
-
14
- import torchvision
15
- from torchvision import transforms
16
-
17
- # 把pytorch等常用的导入写了
18
- import torch.utils.data
19
- from torchvision.datasets import VisionDataset
20
-
21
- from pyxlpr.ai.specialist import ClasEvaluater, NvmDevice
22
-
23
- __base = """
24
- """
25
-
26
-
27
- def get_most_free_torch_gpu_device():
28
- gpu_id = NvmDevice().get_most_free_gpu_id()
29
- if gpu_id is not None:
30
- return torch.device(f'cuda:{gpu_id}')
31
-
32
-
33
- def get_device():
34
- """ 自动获得一个可用的设备
35
- """
36
- return get_most_free_torch_gpu_device() or torch.device('cpu')
37
-
38
-
39
- __data = """
40
- """
41
-
42
-
43
- class TinyDataset(torch.utils.data.Dataset):
44
- def __init__(self, labelfile, label_transform, maxn=None):
45
- """ 超轻量级的Dataset类,一般由外部ProjectData类指定每行label的转换规则 """
46
- self.labels = File(labelfile).read().splitlines()
47
- self.label_transform = label_transform
48
-
49
- self.number = len(self.labels)
50
- if maxn: self.number = min(self.number, maxn)
51
-
52
- def __len__(self):
53
- return self.number
54
-
55
- def __getitem__(self, idx):
56
- return self.label_transform(self.labels[idx])
57
-
58
-
59
- class InputDataset(torch.utils.data.Dataset):
60
- def __init__(self, raw_in, transform=None, *, y_placeholder=...):
61
- """ 将非list、tuple数据转为list,并生成一个dataset类的万用类
62
- :param raw_in:
63
- """
64
- if not isinstance(raw_in, (list, tuple)):
65
- raw_in = [raw_in]
66
-
67
- self.data = raw_in
68
- self.transform = transform
69
- self.y_placeholder = y_placeholder
70
-
71
- def __len__(self):
72
- return len(self.data)
73
-
74
- def __getitem__(self, idx):
75
- x = self.data[idx]
76
- if self.transform:
77
- x = self.transform(x)
78
-
79
- if self.y_placeholder is not ...:
80
- return x, self.y_placeholder
81
- else:
82
- return x
83
-
84
-
85
- __model = """
86
- """
87
-
88
-
89
- class LeNet5(nn.Module):
90
- """ https://towardsdatascience.com/implementing-yann-lecuns-lenet-5-in-pytorch-5e05a0911320 """
91
-
92
- def __init__(self, n_classes):
93
- super().__init__()
94
-
95
- self.feature_extractor = nn.Sequential(
96
- nn.Conv2d(1, 6, kernel_size=5, stride=1),
97
- nn.Tanh(),
98
- nn.AvgPool2d(kernel_size=2),
99
- nn.Conv2d(6, 16, kernel_size=5, stride=1),
100
- nn.Tanh(),
101
- nn.AvgPool2d(kernel_size=2),
102
- nn.Conv2d(16, 120, kernel_size=5, stride=1),
103
- nn.Tanh()
104
- )
105
-
106
- self.classifier = nn.Sequential(
107
- nn.Linear(120, 84),
108
- nn.Tanh(),
109
- nn.Linear(84, n_classes),
110
- )
111
-
112
- def forward(self, batched_inputs):
113
- device = next(self.parameters()).device
114
-
115
- x = batched_inputs[0].to(device)
116
- x = self.feature_extractor(x)
117
- x = torch.flatten(x, 1)
118
- logits = self.classifier(x)
119
-
120
- if self.training:
121
- y = batched_inputs[1].to(device)
122
- return nn.functional.cross_entropy(logits, y)
123
- else:
124
- return logits.argmax(dim=1)
125
-
126
-
127
- __train = """
128
- """
129
-
130
-
131
- class Trainer:
132
- """ 对pytorch模型的训练、测试等操作的进一步封装
133
-
134
- # TODO log变成可选项,可以关掉
135
- """
136
-
137
- def __init__(self, log_dir, device, data, model, optimizer,
138
- loss_func=None, pred_func=None, accuracy_func=None):
139
- # 0 初始化成员变量
140
- self.log_dir, self.device = log_dir, device
141
- self.data, self.model, self.optimizer = data, model, optimizer
142
- if loss_func: self.loss_func = loss_func
143
- if pred_func: self.pred_func = pred_func
144
- if accuracy_func: self.accuracy_func = accuracy_func
145
-
146
- # 1 日志
147
- timetag = datetime.datetime.now().strftime('%Y%m%d.%H%M%S')
148
- # self.curlog_dir = Dir(self.log_dir / timetag) # 本轮运行,实际log位置,是存放在一个子目录里
149
- self.curlog_dir = Dir(self.log_dir)
150
- self.curlog_dir.ensure_dir()
151
- self.log = get_xllog(log_file=self.curlog_dir / 'log.txt')
152
- self.log.info(f'1/4 log_dir={self.curlog_dir}')
153
-
154
- # 2 设备
155
- self.log.info(f'2/4 use_device={self.device}')
156
-
157
- # 3 数据
158
- self.train_dataloader = self.data.get_train_dataloader()
159
- self.val_dataloader = self.data.get_val_dataloader()
160
- self.train_data_number = len(self.train_dataloader.dataset)
161
- self.val_data_number = len(self.val_dataloader.dataset)
162
- self.log.info(f'3/4 get data, train_data_number={self.train_data_number}(batch={len(self.train_dataloader)}), '
163
- f'val_data_number={self.val_data_number}(batch={len(self.val_dataloader)}), '
164
- f'batch_size={self.data.batch_size}')
165
-
166
- # 4 模型
167
- parasize = sum(map(lambda p: p.numel(), self.model.parameters()))
168
- self.log.info(
169
- f'4/4 model parameters size: {parasize}* 4 Bytes per float ≈ {humanfriendly.format_size(parasize * 4)}')
170
-
171
- # 5 其他辅助变量
172
- self.min_total_loss = math.inf # 目前epoch中总损失最小的值(训练损失,训练过程损失)
173
- self.min_train_loss = math.inf # 训练集损失
174
- self.max_val_accuracy = 0 # 验证集精度
175
-
176
- @classmethod
177
- def loss_func(cls, model_out, y):
178
- """ 自定义损失函数 """
179
- # return loss
180
- raise NotImplementedError
181
-
182
- @classmethod
183
- def pred_func(cls, model_out):
184
- """ 自定义模型输出到预测结果 """
185
- # return y_hat
186
- raise NotImplementedError
187
-
188
- @classmethod
189
- def accuracy_func(cls, y_hat, y):
190
- """ 自定义预测结果和实际标签y之间的精度
191
-
192
- 返回"正确的样本数"(在非分类任务中,需要抽象出这个数量关系)
193
- """
194
- # return accuracy
195
- raise NotImplementedError
196
-
197
- def loss_values_stat(self, loss_vales):
198
- """ 一组loss损失的统计分析
199
-
200
- :param loss_vales: 一次batch中,多份样本产生的误差数据
201
- :return: 统计信息文本字符串
202
- """
203
- if not loss_vales:
204
- raise ValueError
205
-
206
- data = np.array(loss_vales, dtype=float)
207
- n, sum_ = len(data), data.sum()
208
- mean, std = data.mean(), data.std()
209
- msg = f'total_loss={sum_:.3f}, mean±std={mean:.3f}±{std:.3f}({max(data):.3f}->{min(data):.3f})'
210
- if sum_ < self.min_total_loss:
211
- self.min_total_loss = sum_
212
- msg = '*' + msg
213
- return msg
214
-
215
- @classmethod
216
- def sample_size(cls, data):
217
- """ 单个样本占用的空间大小,返回字节数 """
218
- x, label = data.dataset[0] # 取第0个样本作为参考
219
- return getasizeof(x.numpy()) + getasizeof(label)
220
-
221
- def save_model_state(self, file, if_exists='error'):
222
- """ 保存模型参数值
223
- 一般存储model.state_dict,而不是直接存储model,确保灵活性
224
-
225
- # TODO 和path结合,增加if_exists参数
226
- """
227
- f = File(file, self.curlog_dir)
228
- if f.exist_preprcs(if_exists=if_exists):
229
- f.ensure_parent()
230
- torch.save(self.model.state_dict(), str(f))
231
-
232
- def load_model_state(self, file):
233
- """ 读取模型参数值
234
-
235
- 注意load和save的root差异! load的默认父目录是在log_dir,而save默认是在curlog_dir!
236
- """
237
- f = File(file, self.log_dir)
238
- self.model.load_state_dict(torch.load(str(f), map_location=self.device))
239
-
240
- def viz_data(self):
241
- """ 用visdom显示样本数据
242
-
243
- TODO 增加一些自定义格式参数
244
- TODO 不能使用\n、\r\n、<br/>实现文本换行,有时间可以研究下,结合nrow、图片宽度,自动推算,怎么美化展示效果
245
- """
246
- from visdom import Visdom
247
-
248
- viz = Visdom()
249
- if not viz: return
250
-
251
- x, label = next(iter(self.train_dataloader))
252
- viz.one_batch_images(x, label, 'train data')
253
-
254
- x, label = next(iter(self.val_dataloader))
255
- viz.one_batch_images(x, label, 'val data')
256
-
257
- def training_one_epoch(self):
258
- # 1 检查模式
259
- if not self.model.training:
260
- self.model.train(True)
261
-
262
- # 2 训练一轮
263
- loss_values = []
264
- for x, y in self.train_dataloader:
265
- # 每个batch可能很大,所以每个batch依次放到cuda,而不是一次性全放入
266
- x, y = x.to(self.device), y.to(self.device)
267
-
268
- y_hat = self.model(x)
269
- loss = self.loss_func(y_hat, y)
270
- loss_values.append(float(loss))
271
-
272
- self.optimizer.zero_grad()
273
- loss.backward()
274
- self.optimizer.step()
275
-
276
- # 3 训练阶段只看loss,不看实际预测准确度,默认每个epoch都会输出
277
- return loss_values
278
-
279
- def calculate_accuracy(self, dataloader):
280
- """ 测试验证集等数据上的精度 """
281
- # 1 eval模式
282
- if self.model.training:
283
- self.model.train(False)
284
-
285
- # 2 关闭梯度,可以节省显存和加速
286
- with torch.no_grad():
287
- tt = TicToc()
288
-
289
- # 预测结果,计算正确率
290
- loss, correct, number = [], 0, len(dataloader.dataset)
291
- for x, y in dataloader:
292
- x, y = x.to(self.device), y.to(self.device)
293
- model_out = self.model(x)
294
- loss.append(self.loss_func(model_out, y))
295
- y_hat = self.pred_func(model_out)
296
- correct += self.accuracy_func(y_hat, y) # 预测正确的数量
297
- elapsed_time, mean_loss = tt.tocvalue(), np.mean(loss, dtype=float)
298
- accuracy = correct / number
299
- info = f'accuracy={correct:.0f}/{number} ({accuracy:.2%})\t' \
300
- f'mean_loss={mean_loss:.3f}\telapsed_time={elapsed_time:.0f}s'
301
- return accuracy, mean_loss, info
302
-
303
- def train_accuracy(self):
304
- accuracy, mean_loss, info = self.calculate_accuracy(self.train_dataloader)
305
- info = 'train ' + info
306
- if mean_loss < self.min_train_loss:
307
- # 如果是best ever,则log换成debug模式输出
308
- self.log.debug('*' + info)
309
- self.min_train_loss = mean_loss
310
- else:
311
- self.log.info(info)
312
- return accuracy
313
-
314
- def val_accuracy(self, save_model=None):
315
- """
316
- :param save_model: 如果验证集精度best ever,则保存当前epoch模型
317
- 如果精度不是最好的,哪怕指定save_model也不会保存的
318
- :return:
319
- """
320
- accuracy, mean_loss, info = self.calculate_accuracy(self.val_dataloader)
321
- info = ' val ' + info
322
- if accuracy > self.max_val_accuracy:
323
- self.log.debug('*' + info)
324
- if save_model:
325
- self.save_model_state(save_model, if_exists='replace')
326
- self.max_val_accuracy = accuracy
327
- else:
328
- self.log.info(info)
329
- return accuracy
330
-
331
- def training(self, epochs, *, start_epoch=0, log_interval=1):
332
- """ 主要训练接口
333
-
334
- :param epochs: 训练代数,输出时从1开始编号
335
- :param start_epoch: 直接从现有的第几个epoch的模型读取参数
336
- 使用该参数,需要在self.save_dir有对应名称的model文件
337
- :param log_interval: 每隔几个epoch输出当前epoch的训练情况,损失值
338
- 每个几轮epoch进行一次监控
339
- 且如果总损失是训练以来最好的结果,则会保存模型
340
- 并对训练集、测试集进行精度测试
341
- TODO 看到其他框架,包括智财的框架,对保存的模型文件,都有更规范的一套命名方案,有空要去学一下
342
- :return:
343
- """
344
- from visdom import Visdom
345
-
346
- # 1 配置参数
347
- tag = self.model.__class__.__name__
348
- epoch_time_tag = f'elapsed_time' if log_interval == 1 else f'{log_interval}*epoch_time'
349
- viz = Visdom() # 其实这里不是用原生的Visdom,而是我封装过的,但是我封装的那个也没太大作用意义,删掉了
350
-
351
- # 2 加载之前的模型继续训练
352
- if start_epoch:
353
- self.load_model_state(f'{tag} epoch{start_epoch}.pth')
354
-
355
- # 3 训练
356
- tt = TicToc()
357
- for epoch in range(start_epoch + 1, epochs + 1):
358
- loss_values = self.training_one_epoch()
359
- # 3.1 训练损失可视化
360
- if viz: viz.loss_line(loss_values, epoch, 'train_loss')
361
- # 3.2 显示epoch训练效果
362
- if log_interval and epoch % log_interval == 0:
363
- # 3.2.1 显示训练用时、训练损失
364
- msg = self.loss_values_stat(loss_values)
365
- elapsed_time = tt.tocvalue(restart=True)
366
- info = f'epoch={epoch}, {epoch_time_tag}={elapsed_time:.0f}s\t{msg.lstrip("*")}'
367
- # 3.2.2 截止目前训练损失最小的结果
368
- if msg[0] == '*':
369
- self.log.debug('*' + info)
370
- # 3.2.2.1 测试训练集、验证集上的精度
371
- accuracy1 = self.train_accuracy()
372
- accuracy2 = self.val_accuracy(save_model=f'{tag} epoch{epoch}.pth')
373
- # 3.2.2.2 可视化图表
374
- if viz: viz.plot_line([[accuracy1, accuracy2]], [epoch], 'accuracy', legend=['train', 'val'])
375
- else:
376
- self.log.info(info)
377
-
378
-
379
- @deprecated(reason='推荐使用XlPredictor实现')
380
- def gen_classification_func(model, *, state_file=None, transform=None, pred_func=None,
381
- device=None):
382
- """ 工厂函数,生成一个分类器函数
383
-
384
- 用这个函数做过渡的一个重要目的,也是避免重复加载模型
385
-
386
- :param model: 模型结构
387
- :param state_file: 存储参数的文件
388
- :param transform: 每一个输入样本的预处理函数
389
- :param pred_func: model 结果的参数的后处理
390
- :return: 返回的函数结构见下述 cls_func
391
- """
392
- if state_file: model.load_state_dict(torch.load(str(state_file), map_location=get_device()))
393
- model.train(False)
394
- device = device or get_device()
395
- model.to(device)
396
-
397
- def cls_func(raw_in):
398
- """
399
- :param raw_in: 输入可以是路径、np.ndarray、PIL图片等,都为转为batch结构的tensor
400
- im,一张图片路径、np.ndarray、PIL图片
401
- [im1, im2, ...],多张图片清单
402
- :return: 输入如果只有一张图片,则返回一个结果
403
- 否则会存在list,返回多个结果
404
- """
405
- dataset = InputDataset(raw_in, transform)
406
- # TODO batch_size根据device空间大小自适应设置
407
- xs = torch.utils.data.DataLoader(dataset, batch_size=8)
408
- res = None
409
- for x in xs:
410
- # 每个batch可能很大,所以每个batch依次放到cuda,而不是一次性全放入
411
- x = x.to(device)
412
- y = model(x)
413
- if pred_func: y = pred_func(y)
414
- res = y if res is None else (res + y)
415
- return res
416
-
417
- return cls_func
418
-
419
-
420
- class XlPredictor:
421
- """ 生成一个类似函数用法的推断功能类
422
-
423
- 这是一个通用的生成器,不同的业务可以继承开发,进一步设计细则
424
-
425
- 这里默认写的结构是兼容detectron2框架的分类模型,即model.forward:
426
- 输入:list,第1个是batch_x,第2个是batch_y
427
- 输出:training是logits,eval是(batch)y_hat
428
- """
429
-
430
- def __init__(self, model, state_file=None, device=None, *, batch_size=1, y_placeholder=...):
431
- """
432
- :param model: 基于d2框架的模型结构
433
- :param state_file: 存储权重的文件
434
- 一般写某个本地文件路径
435
- 也可以写url地址,会下载存储到临时目录中
436
- 可以不传入文件,直接给到初始化好权重的model,该模式常用语训练阶段的model
437
- :param batch_size: 支持每次最多几个样本一起推断
438
- 具体运作细节参见 XlPredictor.inputs2loader的解释
439
- TODO batch_size根据device空间大小自适应设置
440
- :param y_placeholder: 参见XlPredictor.inputs2loader的解释
441
- """
442
- # 默认使用model所在的device
443
- if device is None:
444
- self.device = next(model.parameters()).device
445
- else:
446
- self.device = device
447
-
448
- if state_file is not None:
449
- if is_url(state_file):
450
- state_file = download(state_file, XlPath.tempdir() / 'xlpr')
451
- state = torch.load(str(state_file), map_location=self.device)
452
- if 'model' in state:
453
- state = state['model']
454
- model = model.to(device)
455
- model.load_state_dict(state)
456
-
457
- self.model = model
458
- self.model.train(False)
459
-
460
- self.batch_size = batch_size
461
- self.y_placeholder = y_placeholder
462
-
463
- self.transform = self.build_transform()
464
- self.target_transform = self.build_target_transform()
465
-
466
- @classmethod
467
- def build_transform(cls):
468
- """ 单个数据的转换规则,进入模型前的读取、格式转换
469
-
470
- 为了效率性能,建议比较特殊的不同初始化策略,可以额外定义函数接口,例如:def from_paths()
471
- """
472
- return None
473
-
474
- @classmethod
475
- def build_target_transform(cls):
476
- """ 单个结果的转换的规则,模型预测完的结果,到最终结果的转换方式
477
-
478
- 一些简单的情况直接返回y即可,但还有些复杂的任务可能要增加后处理
479
- """
480
- return None
481
-
482
- def inputs2loader(self, raw_in, *, batch_size=None, y_placeholder=..., sampler=None, **kwargs):
483
- """ 将各种类列表数据,转成torch.utils.data.DataLoader类型
484
-
485
- :param raw_in: 各种类列表数据格式,或者单个数据,都为转为batch结构的tensor
486
- torch.util.data.DataLoader
487
- 此时XlPredictor自定义参数全部无效:transform、batch_size、y_placeholder,sampler
488
- 因为这些在loader里都有配置了
489
- torch.util.data.Dataset
490
- 此时可以定制扩展的参数有:batch_size,sampler
491
- [data1, data2, ...],列表表示批量处理多个数据
492
- 此时所有配置参数均可用:transform、batch_size、y_placeholder, sampler
493
- 通常是图片文件路径清单
494
- XlPredictor原生并没有扩展图片读取功能,但可以通过transform增加CvPrcs.read来支持
495
- single_data,单个数据
496
- 通常是单个图片文件路径,注意transfrom要增加xlcv.read或xlpil.read来支持路径读取
497
- 注意:有时候单个数据就是list格式,此时需要麻烦点,再套一层list避免歧义
498
- :param batch_size: 支持每次最多几个样本一起推断
499
- TODO batch_size根据device空间大小自适应设置
500
- :param y_placeholder: 常见的model.forward,是只输入batch_x就行,这时候就默认值处理机制就行
501
- 但我从d2框架模仿的写法,forward需要补一个y的真实值,输入是[batch_x, batch_y]
502
- 实际预测数据可能没有y,此时需要填充一个batch_y=None来对齐,即设置y_placeholder=None
503
- 或者y_placeholder=0,则所有的y用0作为占位符填充
504
- 不过用None、0、False这些填充都很诡异,容易误导开发者,建议需要设置的时候使用-1
505
-
506
- 如果读者写的model.forward前传机制不同,本来batch_inputs就只输入x没有y,则这里不用设置y_placeholder参数
507
- :param sampler: 有时候只是要简单抽样部分数据测试下,可以设置该参数
508
- 比如random.sample(range(10), 5):可以从前10个数据中,无放回随机抽取5个数据
509
- """
510
- if isinstance(raw_in, torch.utils.data.DataLoader):
511
- loader = raw_in
512
- else:
513
- if not isinstance(raw_in, torch.utils.data.Dataset):
514
- y_placeholder = first_nonnone([y_placeholder, self.y_placeholder], lambda x: x is not ...)
515
- dataset = InputDataset(raw_in, self.transform, y_placeholder=y_placeholder)
516
- else:
517
- if not isinstance(raw_in, (list, tuple)):
518
- raw_in = [raw_in]
519
- dataset = raw_in
520
- batch_size = first_nonnone([batch_size, self.batch_size])
521
- loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler, **kwargs)
522
-
523
- return loader
524
-
525
- def forward(self, loader, *, print_mode=False, return_gt=True):
526
- """ 前向传播
527
-
528
- 改功能是__call__的子部分,常在train、eval阶段单独调用
529
- 因为eval阶段,已经有预设好的train_loader、val_loader,不需要使用inputs2loader智能生成一个loader
530
-
531
- :param torch.utils.data.DataLoader loader: 标准的DataLoader类型,每次能获取[batch_x, batch_y]
532
- :param print_mode: 有时候数据量比较大,可能会需要看推断进度条
533
- :param return_gt: 注意跟__call__的不同,这里默认是True,__call__默认是False
534
- 前者常用于评价阶段,后者常用于部署阶段,应用场景不同,常见配置有区别
535
- :return:
536
- return_gt=True(默认):[(y1, y_hat1), (y2, y_hat2), ...]
537
- return_gt=False:[y_hat1, y_hat2, ...]
538
- """
539
- preds = []
540
- with torch.no_grad():
541
- for batched_inputs in tqdm(loader, 'eval batch', disable=not print_mode):
542
- # 有的模型forward里没有处理input的device问题,则需要在这里使用self.device设置
543
- # batched_inputs = batched_inputs.to(self.device) # 这一步可能不应该写在这里,还是先注释掉吧
544
- batch_y = self.model(batched_inputs).tolist()
545
- if self.target_transform:
546
- batch_y = [self.target_transform(y) for y in batch_y]
547
- if return_gt:
548
- gt = batched_inputs[1].tolist()
549
- preds += list(zip(*[gt, batch_y]))
550
- else:
551
- preds += batch_y
552
- return preds
553
-
554
- def __call__(self, raw_in, *, batch_size=None, y_placeholder=...,
555
- print_mode=False, return_gt=False):
556
- """ 前传推断结果
557
-
558
- :param batch_size: 具体运行中可以重新指定batch_size
559
- :param return_gt: 使用该功能,必须确保每次loader都含有[x,y],可能是raw_in自带,也可以用y_placeholder设置默认值
560
- 单样本:y, y_hat
561
- 多样本:[(y1, y_hat1), (y2, y_hat2), ...]
562
- :return:
563
- 单样本:y_hat
564
- 多样表:[y_hat1, y_hat2, ...]
565
-
566
- 根据不同model结构特殊性
567
- """
568
- loader = self.inputs2loader(raw_in, batch_size=batch_size, y_placeholder=y_placeholder)
569
- preds = self.forward(loader, print_mode=print_mode, return_gt=return_gt)
570
- # 返回结果,单样本的时候作简化
571
- if len(preds) == 1 and not isinstance(raw_in, (list, tuple, set)):
572
- return preds[0]
573
- else:
574
- return preds
575
-
576
-
577
def setup_seed(seed):
    """Seed every RNG source needed for (mostly) reproducible experiments.

    In personal experience this still does not always help — various other
    sources of nondeterminism can keep a model from being reproduced exactly.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
-
591
-
592
class TrainingSampler:
    """ Infinite-stream index sampler, taken from detectron2.
    Simplified here: single-GPU training only (the original supports multi-GPU).

    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.

    The samplers in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)
    """

    def __init__(self, size: int, shuffle: bool = True):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
        """
        self._size = size
        assert size > 0
        self._shuffle = shuffle

    def __iter__(self):
        gen = torch.Generator()
        while True:
            if self._shuffle:
                order = torch.randperm(self._size, generator=gen)
            else:
                order = torch.arange(self._size)
            yield from order.tolist()
-
624
-
625
class ZcPredictor:
    """ Wrapper interface for the Zhicai ocrwork framework.

    This was originally a special-purpose feature that does not really belong
    here, but there are no non-public technical details in it, so it lives
    here for convenience of use.
    """

    def __init__(self, config_file, *, gpu=None, batch_size=None, opts=None):
        """
        :param config_file: path of a yaml config file, or the configuration
            content itself as a string
        :param gpu: may be left unset; the GPU with the most free memory is
            picked by default. Note the config file also has a gpu entry,
            which is discarded in this interface mode.
        :param batch_size: max number of images recognized at once.
            config_file also contains a batch_size, but that one is a training
            parameter with no necessary relation to deployment — better to set
            batch_size explicitly here. If unset, each call batches however
            many images were passed in.
        :param opts: a dict of extra values overriding the config file entries
        """
        from easydict import EasyDict

        # 1 load the base configuration
        if isinstance(config_file, str) and config_file[-5:].lower() == '.yaml':
            # OCRWORK_DEPLOY env var customizes where deploy configs/models live
            deploy_path = os.environ.get('OCRWORK_DEPLOY', '.')
            config_file = os.path.join(deploy_path, config_file)
            f = open(config_file, "r")
        elif isinstance(config_file, str):
            f = io.StringIO(config_file)
        else:
            raise TypeError
        # close the handle even when yaml parsing raises
        # (the original only closed it on the success path)
        with f:
            prepare_args = EasyDict(list(yaml.load_all(f, Loader=yaml.FullLoader))[0])

        # 2 special configuration entries
        opts = opts or {}
        if gpu is not None:
            opts['gpu'] = str(gpu)
        if 'gpu' not in opts:
            # when gpu is unset, default to the card with the most free memory
            opts['gpu'] = NvmDevice().get_most_free_gpu_id()
        if 'gpu' in opts:  # the Zhicai config apparently requires a string here
            opts['gpu'] = str(opts['gpu'])
        prepare_args.update(opts)

        # 3 initialize the components
        self.prepare_args = prepare_args
        self.batch_size = batch_size
        self.transform = lambda x: xlcv.read(x, 1)  # default: unify to cv2 image format
        # self.transform = lambda x: PilPrcs.read(x, 1)  # PIL format also possible

    def forward(self, imgs):
        # BUGFIX: the original raised NotImplemented, which is not an exception
        # class — calling it raises a confusing TypeError instead. Subclasses
        # must implement forward.
        raise NotImplementedError('子类必须实现forward方法')

    def __call__(self, raw_in, *, batch_size=None, progress=False):
        """ In the Zhicai framework the dataloader needs no alignment by
        default, so collate_fn is effectively reset.
        (More precisely, its augment component takes care of the alignment.)

        :return: for the multi-result case:
            preds is a list
            pred = preds[0]
            pred is also a list — all detection boxes of image 0, e.g. 8 boxes
            each box is a 4*2 integer numpy matrix
        """
        # 1 decide single vs batch input
        # NOTE(review): detection via __len__ — a numpy image itself has
        # __len__, so a bare ndarray is treated as a batch of rows; confirm
        # callers only pass paths/PIL for the single-sample case.
        if not getattr(raw_in, '__len__', None):
            imgs = [raw_in]
        else:
            imgs = raw_in
        n = len(imgs)
        batch_size = first_nonnone([batch_size, self.batch_size, n])

        # 2 process chunk by chunk
        preds = []
        # close the progress bar deterministically (the original leaked it)
        with tqdm(desc='forward', total=n, disable=not progress) as t:
            for i in range(0, n, batch_size):
                inputs = imgs[i:i + batch_size]
                preds += self.forward([self.transform(img) for img in inputs])
                t.update(len(inputs))

        # 3 return; simplify the single-sample case
        if len(preds) == 1 and not getattr(raw_in, '__len__', None):
            return preds[0]
        else:
            return preds