pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
pyxlpr/ai/xltorch.py
DELETED
@@ -1,705 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
# @Author : 陈坤泽
|
4
|
-
# @Email : 877362867@qq.com
|
5
|
-
# @Date : 2021/06/06 23:10
|
6
|
-
|
7
|
-
|
8
|
-
from pyxllib.xlcv import *
|
9
|
-
|
10
|
-
import torch
|
11
|
-
from torch import nn, optim
|
12
|
-
import torch.utils.data
|
13
|
-
|
14
|
-
import torchvision
|
15
|
-
from torchvision import transforms
|
16
|
-
|
17
|
-
# 把pytorch等常用的导入写了
|
18
|
-
import torch.utils.data
|
19
|
-
from torchvision.datasets import VisionDataset
|
20
|
-
|
21
|
-
from pyxlpr.ai.specialist import ClasEvaluater, NvmDevice
|
22
|
-
|
23
|
-
__base = """
|
24
|
-
"""
|
25
|
-
|
26
|
-
|
27
|
-
def get_most_free_torch_gpu_device():
    """ Build a torch device for the GPU that NvmDevice reports as the most free.

    :return: ``torch.device('cuda:<id>')`` when ``NvmDevice().get_most_free_gpu_id()``
        returns an id, otherwise ``None``
    """
    gpu_id = NvmDevice().get_most_free_gpu_id()
    # Compare against None explicitly: id 0 is falsy but still a valid answer.
    return None if gpu_id is None else torch.device(f'cuda:{gpu_id}')
|
31
|
-
|
32
|
-
|
33
|
-
def get_device():
    """ Pick a usable device automatically.

    :return: whatever get_most_free_torch_gpu_device() returns when that value is truthy,
        otherwise ``torch.device('cpu')``
    """
    device = get_most_free_torch_gpu_device()
    if not device:
        device = torch.device('cpu')
    return device
|
37
|
-
|
38
|
-
|
39
|
-
__data = """
|
40
|
-
"""
|
41
|
-
|
42
|
-
|
43
|
-
class TinyDataset(torch.utils.data.Dataset):
    """ Ultra-lightweight Dataset backed by the lines of a label file.

    Every item is obtained by applying ``label_transform`` to one raw line.
    """

    def __init__(self, labelfile, label_transform, maxn=None):
        """
        :param labelfile: label file; its content is read through ``File(labelfile).read()``
            and split into lines
        :param label_transform: callable applied to a raw label line in ``__getitem__``
            (per the original note, usually supplied by an external ProjectData class)
        :param maxn: optional upper bound for the length reported by ``__len__``;
            a falsy value (None, 0) means "no bound"
        """
        self.labels = File(labelfile).read().splitlines()
        self.label_transform = label_transform

        total = len(self.labels)
        self.number = min(total, maxn) if maxn else total

    def __len__(self):
        """ Number of samples exposed, i.e. the line count capped by ``maxn``. """
        return self.number

    def __getitem__(self, idx):
        """ Return the transformed label line stored at position ``idx``. """
        raw_line = self.labels[idx]
        return self.label_transform(raw_line)
|
57
|
-
|
58
|
-
|
59
|
-
class InputDataset(torch.utils.data.Dataset):
    """ General-purpose Dataset wrapper around arbitrary raw input. """

    def __init__(self, raw_in, transform=None, *, y_placeholder=...):
        """ Wrap ``raw_in`` as a dataset; anything that is not a list/tuple becomes a one-item list.

        :param raw_in: a list/tuple of samples, or a single sample
        :param transform: optional callable applied to each sample when it is fetched
        :param y_placeholder: when set to anything other than ``...`` (Ellipsis), every item
            is returned as ``(x, y_placeholder)``; with the default only ``x`` is returned
        """
        self.data = raw_in if isinstance(raw_in, (list, tuple)) else [raw_in]
        self.transform = transform
        self.y_placeholder = y_placeholder

    def __len__(self):
        """ Number of wrapped samples. """
        return len(self.data)

    def __getitem__(self, idx):
        """ Return sample ``idx`` (transformed if a transform is set), paired with the placeholder if any. """
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)

        # Ellipsis is the "no placeholder" sentinel, so None / 0 / -1 remain usable placeholders.
        if self.y_placeholder is ...:
            return sample
        return sample, self.y_placeholder
|
83
|
-
|
84
|
-
|
85
|
-
__model = """
|
86
|
-
"""
|
87
|
-
|
88
|
-
|
89
|
-
class LeNet5(nn.Module):
    """ LeNet-5 style classifier.

    Reference: https://towardsdatascience.com/implementing-yann-lecuns-lenet-5-in-pytorch-5e05a0911320
    """

    def __init__(self, n_classes):
        """
        :param n_classes: size of the final linear layer, i.e. number of output classes
        """
        super().__init__()

        # Three 5x5 convolutions (1 -> 6 -> 16 -> 120 channels) with tanh activations,
        # and 2x2 average pooling after the first two.
        conv_layers = [
            nn.Conv2d(1, 6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(6, 16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(16, 120, kernel_size=5, stride=1),
            nn.Tanh(),
        ]
        self.feature_extractor = nn.Sequential(*conv_layers)

        # The first linear layer expects exactly 120 flattened features
        # (presumably 1x32x32 inputs, which shrink to 120x1x1 -- confirm with callers).
        head_layers = [
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, n_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, batched_inputs):
        """ Forward pass over ``batched_inputs``.

        :param batched_inputs: indexable; item 0 is the input batch, item 1 the targets
            (item 1 is only read in training mode)
        :return: the cross-entropy loss in training mode, otherwise the argmax class index per sample
        """
        # Inputs are moved to wherever the model parameters live.
        device = next(self.parameters()).device

        features = self.feature_extractor(batched_inputs[0].to(device))
        logits = self.classifier(torch.flatten(features, 1))

        if not self.training:
            return logits.argmax(dim=1)

        target = batched_inputs[1].to(device)
        return nn.functional.cross_entropy(logits, target)
|
125
|
-
|
126
|
-
|
127
|
-
__train = """
|
128
|
-
"""
|
129
|
-
|
130
|
-
|
131
|
-
class Trainer:
    """ A further layer of wrapping around training / evaluating a pytorch model.

    The hooks ``loss_func``, ``pred_func`` and ``accuracy_func`` raise NotImplementedError
    by default; supply them through ``__init__`` or override them in a subclass.

    # TODO make logging optional so that it can be switched off
    """

    def __init__(self, log_dir, device, data, model, optimizer,
                 loss_func=None, pred_func=None, accuracy_func=None):
        """
        :param log_dir: directory that receives ``log.txt`` and the saved model states
        :param device: device every batch is moved to before the forward pass
            (the model itself is used as given; it is not moved by this class)
        :param data: object providing ``get_train_dataloader()``, ``get_val_dataloader()``
            and a ``batch_size`` attribute
        :param model: the model to train
        :param optimizer: optimizer stepped once per training batch
        :param loss_func: optional callable ``(model_out, y) -> loss``
        :param pred_func: optional callable ``(model_out) -> y_hat``
        :param accuracy_func: optional callable ``(y_hat, y) -> number of correct samples``
        """
        # 0 member variables
        self.log_dir, self.device = log_dir, device
        self.data, self.model, self.optimizer = data, model, optimizer
        # Instance attributes shadow the NotImplementedError placeholders defined on the class.
        if loss_func:
            self.loss_func = loss_func
        if pred_func:
            self.pred_func = pred_func
        if accuracy_func:
            self.accuracy_func = accuracy_func

        # 1 logging: the log is written directly into log_dir (no per-run sub directory)
        self.curlog_dir = Dir(self.log_dir)
        self.curlog_dir.ensure_dir()
        self.log = get_xllog(log_file=self.curlog_dir / 'log.txt')
        self.log.info(f'1/4 log_dir={self.curlog_dir}')

        # 2 device
        self.log.info(f'2/4 use_device={self.device}')

        # 3 data
        self.train_dataloader = self.data.get_train_dataloader()
        self.val_dataloader = self.data.get_val_dataloader()
        self.train_data_number = len(self.train_dataloader.dataset)
        self.val_data_number = len(self.val_dataloader.dataset)
        self.log.info(f'3/4 get data, train_data_number={self.train_data_number}(batch={len(self.train_dataloader)}), '
                      f'val_data_number={self.val_data_number}(batch={len(self.val_dataloader)}), '
                      f'batch_size={self.data.batch_size}')

        # 4 model: report the parameter count and its size at 4 bytes per parameter
        parasize = sum(p.numel() for p in self.model.parameters())
        self.log.info(
            f'4/4 model parameters size: {parasize}* 4 Bytes per float ≈ {humanfriendly.format_size(parasize * 4)}')

        # 5 bookkeeping for "best so far" detection
        self.min_total_loss = math.inf  # smallest per-report total training loss seen so far
        self.min_train_loss = math.inf  # smallest mean loss measured on the training set
        self.max_val_accuracy = 0  # best accuracy measured on the validation set

    @classmethod
    def loss_func(cls, model_out, y):
        """ Custom loss function; must return the loss for ``model_out`` against ``y``. """
        raise NotImplementedError

    @classmethod
    def pred_func(cls, model_out):
        """ Custom conversion from raw model output to predictions ``y_hat``. """
        raise NotImplementedError

    @classmethod
    def accuracy_func(cls, y_hat, y):
        """ Custom accuracy between predictions ``y_hat`` and the true labels ``y``.

        Must return the "number of correct samples" (non-classification tasks have to
        define what that count means for them).
        """
        raise NotImplementedError

    def loss_values_stat(self, loss_vales):
        """ Summarise a group of loss values.

        Also updates ``self.min_total_loss`` when the sum of the values is a new minimum.

        :param loss_vales: the loss values to summarise (``training`` passes the per-batch
            losses of one epoch)
        :return: summary text; it is prefixed with ``'*'`` when the total is the best so far
        :raises ValueError: if ``loss_vales`` is empty
        """
        if not loss_vales:
            raise ValueError

        data = np.array(loss_vales, dtype=float)
        sum_ = data.sum()
        mean, std = data.mean(), data.std()
        msg = f'total_loss={sum_:.3f}, mean±std={mean:.3f}±{std:.3f}({data.max():.3f}->{data.min():.3f})'
        if sum_ < self.min_total_loss:
            self.min_total_loss = sum_
            msg = '*' + msg
        return msg

    @classmethod
    def sample_size(cls, data):
        """ Size in bytes of a single sample, measured on sample 0 of ``data.dataset``. """
        x, label = data.dataset[0]  # sample 0 serves as the reference
        return getasizeof(x.numpy()) + getasizeof(label)

    def save_model_state(self, file, if_exists='error'):
        """ Save the model parameters.

        ``model.state_dict()`` is stored rather than the model object, to stay flexible.

        :param file: target file, resolved relative to ``self.curlog_dir``
        :param if_exists: forwarded to ``File.exist_preprcs``; the state is only written
            when that call returns a truthy value
        """
        f = File(file, self.curlog_dir)
        if f.exist_preprcs(if_exists=if_exists):
            f.ensure_parent()
            torch.save(self.model.state_dict(), str(f))

    def load_model_state(self, file):
        """ Load the model parameters.

        Mind the different roots: loading resolves ``file`` against ``self.log_dir`` while
        saving resolves against ``self.curlog_dir``.
        """
        f = File(file, self.log_dir)
        self.model.load_state_dict(torch.load(str(f), map_location=self.device))

    def viz_data(self):
        """ Show one batch of training and validation samples with visdom.

        TODO add some custom formatting parameters
        TODO line breaks via \\n, \\r\\n or <br/> do not work; work out a nicer layout
             from nrow and the image width when there is time
        """
        from visdom import Visdom

        viz = Visdom()
        if not viz:
            return

        # NOTE(review): one_batch_images looks like a method of a custom Visdom wrapper
        # (see the remark in training()); confirm it exists on the object used here.
        x, label = next(iter(self.train_dataloader))
        viz.one_batch_images(x, label, 'train data')

        x, label = next(iter(self.val_dataloader))
        viz.one_batch_images(x, label, 'val data')

    def training_one_epoch(self):
        """ Run one pass over the training data.

        :return: list with the loss of every batch, as python floats
        """
        # 1 make sure the model is in training mode
        if not self.model.training:
            self.model.train(True)

        # 2 one epoch
        loss_values = []
        for x, y in self.train_dataloader:
            # Batches can be large, so each one is moved to the device on its own.
            x, y = x.to(self.device), y.to(self.device)

            y_hat = self.model(x)
            loss = self.loss_func(y_hat, y)
            loss_values.append(float(loss))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # 3 training only tracks the loss, not the prediction accuracy
        return loss_values

    def calculate_accuracy(self, dataloader):
        """ Measure accuracy and mean loss on the data of ``dataloader``.

        :return: ``(accuracy, mean_loss, info)`` where ``info`` is a printable summary.
            An empty dataset yields an accuracy of 0.0 instead of a division error.
        """
        # 1 eval mode
        if self.model.training:
            self.model.train(False)

        # 2 gradients are not needed here; disabling them saves memory and time
        with torch.no_grad():
            tt = TicToc()

            losses, correct, number = [], 0, len(dataloader.dataset)
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)
                model_out = self.model(x)
                # Keep plain floats (as training_one_epoch does): a list of CUDA tensors
                # cannot be converted by np.mean.
                losses.append(float(self.loss_func(model_out, y)))
                y_hat = self.pred_func(model_out)
                correct += self.accuracy_func(y_hat, y)  # number of correct predictions
            elapsed_time, mean_loss = tt.tocvalue(), np.mean(losses, dtype=float)
            accuracy = correct / number if number else 0.0
            info = f'accuracy={correct:.0f}/{number} ({accuracy:.2%})\t' \
                   f'mean_loss={mean_loss:.3f}\telapsed_time={elapsed_time:.0f}s'
        return accuracy, mean_loss, info

    def train_accuracy(self):
        """ Evaluate on the training set, log the result and return the accuracy. """
        accuracy, mean_loss, info = self.calculate_accuracy(self.train_dataloader)
        info = 'train ' + info
        if mean_loss < self.min_train_loss:
            # best ever: logged at debug level with a leading '*'
            self.log.debug('*' + info)
            self.min_train_loss = mean_loss
        else:
            self.log.info(info)
        return accuracy

    def val_accuracy(self, save_model=None):
        """ Evaluate on the validation set, log the result and return the accuracy.

        :param save_model: file name used to save the current model state when the
            validation accuracy is the best ever; nothing is saved otherwise, even if
            a name is given
        :return: validation accuracy
        """
        accuracy, mean_loss, info = self.calculate_accuracy(self.val_dataloader)
        info = '  val ' + info
        if accuracy > self.max_val_accuracy:
            self.log.debug('*' + info)
            if save_model:
                self.save_model_state(save_model, if_exists='replace')
            self.max_val_accuracy = accuracy
        else:
            self.log.info(info)
        return accuracy

    def training(self, epochs, *, start_epoch=0, log_interval=1):
        """ Main training entry point.

        :param epochs: number of the last epoch to run; epochs are numbered from 1 in the output
        :param start_epoch: resume from the saved state of this epoch; a file named
            '<model class name> epoch<start_epoch>.pth' must exist under ``self.log_dir``
        :param log_interval: report the training loss every this many epochs (a falsy value
            disables the report). When the reported total loss is the best so far, accuracy
            is also measured on the training and validation sets, and the model state is
            saved if the validation accuracy is the best so far.
            TODO other frameworks use a more disciplined naming scheme for saved model
                 files; worth studying
        :return: None
        """
        from visdom import Visdom

        # 1 configuration
        tag = self.model.__class__.__name__
        epoch_time_tag = 'elapsed_time' if log_interval == 1 else f'{log_interval}*epoch_time'
        # Original author's remark: this used to be a custom wrapper around Visdom rather
        # than the stock class; the wrapper has since been removed.
        # NOTE(review): loss_line / plot_line below presumably belonged to that wrapper -- confirm.
        viz = Visdom()

        # 2 continue from a previously saved model
        if start_epoch:
            self.load_model_state(f'{tag} epoch{start_epoch}.pth')

        # 3 train
        tt = TicToc()
        for epoch in range(start_epoch + 1, epochs + 1):
            loss_values = self.training_one_epoch()
            # 3.1 visualise the training loss
            if viz:
                viz.loss_line(loss_values, epoch, 'train_loss')
            # 3.2 report this epoch
            if log_interval and epoch % log_interval == 0:
                # 3.2.1 elapsed time and training loss
                msg = self.loss_values_stat(loss_values)
                elapsed_time = tt.tocvalue(restart=True)
                info = f'epoch={epoch}, {epoch_time_tag}={elapsed_time:.0f}s\t{msg.lstrip("*")}'
                # 3.2.2 smallest training loss so far
                if msg.startswith('*'):
                    self.log.debug('*' + info)
                    # 3.2.2.1 accuracy on the training and validation sets
                    accuracy1 = self.train_accuracy()
                    accuracy2 = self.val_accuracy(save_model=f'{tag} epoch{epoch}.pth')
                    # 3.2.2.2 chart
                    if viz:
                        viz.plot_line([[accuracy1, accuracy2]], [epoch], 'accuracy', legend=['train', 'val'])
                else:
                    self.log.info(info)
|
377
|
-
|
378
|
-
|
379
|
-
@deprecated(reason='推荐使用XlPredictor实现')
def gen_classification_func(model, *, state_file=None, transform=None, pred_func=None,
                            device=None):
    """ Factory that builds a classifier function around ``model``.

    One important purpose of going through this factory is to avoid loading the
    model repeatedly.

    :param model: the model structure
    :param state_file: file holding the parameters; loaded into ``model`` when given
    :param transform: preprocessing applied to every input sample
    :param pred_func: post-processing applied to the model output of each batch
    :param device: device to run on; chosen by get_device() when omitted
    :return: a function with the structure of ``cls_func`` below
    """
    # Resolve the device once so the state is mapped to the same device the model runs on.
    device = device or get_device()
    if state_file:
        model.load_state_dict(torch.load(str(state_file), map_location=device))
    model.train(False)
    model.to(device)

    def cls_func(raw_in):
        """
        :param raw_in: a single sample or a list/tuple of samples (paths, np.ndarray,
            PIL images, ... -- whatever ``transform`` turns into batchable tensors)
        :return: the per-batch results joined in input order: tensor results are
            concatenated along dim 0, any other result type is joined with ``+``;
            ``None`` when there is no input at all
        """
        dataset = InputDataset(raw_in, transform)
        # TODO choose batch_size adaptively from the memory available on the device
        loader = torch.utils.data.DataLoader(dataset, batch_size=8)
        res = None
        # Pure inference: no gradients needed.
        with torch.no_grad():
            for x in loader:
                # Batches can be large, so each one is moved to the device on its own.
                x = x.to(device)
                y = model(x)
                if pred_func:
                    y = pred_func(y)

                if res is None:
                    res = y
                elif isinstance(y, torch.Tensor) and y.dim() > 0:
                    # `res + y` would add tensors element-wise instead of appending the batch.
                    res = torch.cat([res, y])
                else:
                    res = res + y
        return res

    return cls_func
|
418
|
-
|
419
|
-
|
420
|
-
class XlPredictor:
|
421
|
-
""" 生成一个类似函数用法的推断功能类
|
422
|
-
|
423
|
-
这是一个通用的生成器,不同的业务可以继承开发,进一步设计细则
|
424
|
-
|
425
|
-
这里默认写的结构是兼容detectron2框架的分类模型,即model.forward:
|
426
|
-
输入:list,第1个是batch_x,第2个是batch_y
|
427
|
-
输出:training是logits,eval是(batch)y_hat
|
428
|
-
"""
|
429
|
-
|
430
|
-
def __init__(self, model, state_file=None, device=None, *, batch_size=1, y_placeholder=...):
|
431
|
-
"""
|
432
|
-
:param model: 基于d2框架的模型结构
|
433
|
-
:param state_file: 存储权重的文件
|
434
|
-
一般写某个本地文件路径
|
435
|
-
也可以写url地址,会下载存储到临时目录中
|
436
|
-
可以不传入文件,直接给到初始化好权重的model,该模式常用语训练阶段的model
|
437
|
-
:param batch_size: 支持每次最多几个样本一起推断
|
438
|
-
具体运作细节参见 XlPredictor.inputs2loader的解释
|
439
|
-
TODO batch_size根据device空间大小自适应设置
|
440
|
-
:param y_placeholder: 参见XlPredictor.inputs2loader的解释
|
441
|
-
"""
|
442
|
-
# 默认使用model所在的device
|
443
|
-
if device is None:
|
444
|
-
self.device = next(model.parameters()).device
|
445
|
-
else:
|
446
|
-
self.device = device
|
447
|
-
|
448
|
-
if state_file is not None:
|
449
|
-
if is_url(state_file):
|
450
|
-
state_file = download(state_file, XlPath.tempdir() / 'xlpr')
|
451
|
-
state = torch.load(str(state_file), map_location=self.device)
|
452
|
-
if 'model' in state:
|
453
|
-
state = state['model']
|
454
|
-
model = model.to(device)
|
455
|
-
model.load_state_dict(state)
|
456
|
-
|
457
|
-
self.model = model
|
458
|
-
self.model.train(False)
|
459
|
-
|
460
|
-
self.batch_size = batch_size
|
461
|
-
self.y_placeholder = y_placeholder
|
462
|
-
|
463
|
-
self.transform = self.build_transform()
|
464
|
-
self.target_transform = self.build_target_transform()
|
465
|
-
|
466
|
-
@classmethod
|
467
|
-
def build_transform(cls):
|
468
|
-
""" 单个数据的转换规则,进入模型前的读取、格式转换
|
469
|
-
|
470
|
-
为了效率性能,建议比较特殊的不同初始化策略,可以额外定义函数接口,例如:def from_paths()
|
471
|
-
"""
|
472
|
-
return None
|
473
|
-
|
474
|
-
@classmethod
|
475
|
-
def build_target_transform(cls):
|
476
|
-
""" 单个结果的转换的规则,模型预测完的结果,到最终结果的转换方式
|
477
|
-
|
478
|
-
一些简单的情况直接返回y即可,但还有些复杂的任务可能要增加后处理
|
479
|
-
"""
|
480
|
-
return None
|
481
|
-
|
482
|
-
def inputs2loader(self, raw_in, *, batch_size=None, y_placeholder=..., sampler=None, **kwargs):
|
483
|
-
""" 将各种类列表数据,转成torch.utils.data.DataLoader类型
|
484
|
-
|
485
|
-
:param raw_in: 各种类列表数据格式,或者单个数据,都为转为batch结构的tensor
|
486
|
-
torch.util.data.DataLoader
|
487
|
-
此时XlPredictor自定义参数全部无效:transform、batch_size、y_placeholder,sampler
|
488
|
-
因为这些在loader里都有配置了
|
489
|
-
torch.util.data.Dataset
|
490
|
-
此时可以定制扩展的参数有:batch_size,sampler
|
491
|
-
[data1, data2, ...],列表表示批量处理多个数据
|
492
|
-
此时所有配置参数均可用:transform、batch_size、y_placeholder, sampler
|
493
|
-
通常是图片文件路径清单
|
494
|
-
XlPredictor原生并没有扩展图片读取功能,但可以通过transform增加CvPrcs.read来支持
|
495
|
-
single_data,单个数据
|
496
|
-
通常是单个图片文件路径,注意transfrom要增加xlcv.read或xlpil.read来支持路径读取
|
497
|
-
注意:有时候单个数据就是list格式,此时需要麻烦点,再套一层list避免歧义
|
498
|
-
:param batch_size: 支持每次最多几个样本一起推断
|
499
|
-
TODO batch_size根据device空间大小自适应设置
|
500
|
-
:param y_placeholder: 常见的model.forward,是只输入batch_x就行,这时候就默认值处理机制就行
|
501
|
-
但我从d2框架模仿的写法,forward需要补一个y的真实值,输入是[batch_x, batch_y]
|
502
|
-
实际预测数据可能没有y,此时需要填充一个batch_y=None来对齐,即设置y_placeholder=None
|
503
|
-
或者y_placeholder=0,则所有的y用0作为占位符填充
|
504
|
-
不过用None、0、False这些填充都很诡异,容易误导开发者,建议需要设置的时候使用-1
|
505
|
-
|
506
|
-
如果读者写的model.forward前传机制不同,本来batch_inputs就只输入x没有y,则这里不用设置y_placeholder参数
|
507
|
-
:param sampler: 有时候只是要简单抽样部分数据测试下,可以设置该参数
|
508
|
-
比如random.sample(range(10), 5):可以从前10个数据中,无放回随机抽取5个数据
|
509
|
-
"""
|
510
|
-
if isinstance(raw_in, torch.utils.data.DataLoader):
|
511
|
-
loader = raw_in
|
512
|
-
else:
|
513
|
-
if not isinstance(raw_in, torch.utils.data.Dataset):
|
514
|
-
y_placeholder = first_nonnone([y_placeholder, self.y_placeholder], lambda x: x is not ...)
|
515
|
-
dataset = InputDataset(raw_in, self.transform, y_placeholder=y_placeholder)
|
516
|
-
else:
|
517
|
-
if not isinstance(raw_in, (list, tuple)):
|
518
|
-
raw_in = [raw_in]
|
519
|
-
dataset = raw_in
|
520
|
-
batch_size = first_nonnone([batch_size, self.batch_size])
|
521
|
-
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler, **kwargs)
|
522
|
-
|
523
|
-
return loader
|
524
|
-
|
525
|
-
def forward(self, loader, *, print_mode=False, return_gt=True):
    """ Run the model over every batch of a loader and collect the predictions

    This is the inner part of __call__ and is often called on its own during
    train/eval, where a train_loader/val_loader already exists and there is no
    need for inputs2loader to build one.

    :param torch.utils.data.DataLoader loader: a standard DataLoader; each batch
        is expected to look like [batch_x, batch_y]
    :param print_mode: show a progress bar, useful when there is a lot of data
    :param return_gt: defaults to True here, whereas __call__ defaults to False;
        this method is mostly used for evaluation, __call__ for deployment
    :return:
        return_gt=True (default): [(y1, y_hat1), (y2, y_hat2), ...]
        return_gt=False: [y_hat1, y_hat2, ...]
    """
    results = []
    with torch.no_grad():
        batches = tqdm(loader, 'eval batch', disable=not print_mode)
        for batch in batches:
            # The batch is handed to the model untouched: moving it to
            # self.device is deliberately not done here, so a model whose
            # forward ignores device placement must deal with it itself.
            y_hat = self.model(batch).tolist()
            if self.target_transform:
                y_hat = [self.target_transform(item) for item in y_hat]

            if not return_gt:
                results.extend(y_hat)
                continue

            # Second element of the batch is taken as the ground truth
            y_true = batch[1].tolist()
            results.extend(zip(y_true, y_hat))
    return results
|
553
|
-
|
554
|
-
def __call__(self, raw_in, *, batch_size=None, y_placeholder=...,
             print_mode=False, return_gt=False):
    """ Build a loader from raw input, run inference and return the predictions

    :param raw_in: data accepted by inputs2loader
    :param batch_size: overrides the batch size for this call
    :param y_placeholder: forwarded to inputs2loader
    :param print_mode: forwarded to forward (progress bar switch)
    :param return_gt: when enabled every batch of the loader must contain [x, y],
        either because raw_in carries y or because y_placeholder supplies one
            single sample: y, y_hat
            several samples: [(y1, y_hat1), (y2, y_hat2), ...]
    :return:
        single sample: y_hat
        several samples: [y_hat1, y_hat2, ...]

        The exact shape of each prediction depends on the model.
    """
    data_loader = self.inputs2loader(raw_in, batch_size=batch_size, y_placeholder=y_placeholder)
    outputs = self.forward(data_loader, print_mode=print_mode, return_gt=return_gt)

    # Unwrap the result only when exactly one prediction came back and the
    # caller did not pass a list/tuple/set
    if len(outputs) != 1 or isinstance(raw_in, (list, tuple, set)):
        return outputs
    return outputs[0]
|
575
|
-
|
576
|
-
|
577
|
-
def setup_seed(seed):
    """ Apply one seed to every random number generator used by this module

    Calls the seeding functions of torch, torch.cuda, numpy and the stdlib
    random module, then switches cuDNN benchmark mode off and deterministic
    mode on.

    In the author's experience this is not always sufficient: other factors can
    still prevent a model from being reproduced exactly.

    :param seed: the seed value passed to each generator
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
|
590
|
-
|
591
|
-
|
592
|
-
class TrainingSampler:
    """ Endless index sampler, adapted from detectron2

    This is a simplified version that only supports single-device training;
    the detectron2 original also handles multi-device training.

    Iterating yields an infinite stream of indices built as
        `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
        or `range(size) + range(size) + ...` (if shuffle is False)
    """

    def __init__(self, size: int, shuffle: bool = True):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
        """
        assert size > 0
        self._size, self._shuffle = size, shuffle

    def __iter__(self):
        # One generator is created per iterator and reused for every pass
        rng = torch.Generator()
        while True:
            if not self._shuffle:
                order = torch.arange(self._size)
            else:
                order = torch.randperm(self._size, generator=rng)
            yield from order.tolist()
|
623
|
-
|
624
|
-
|
625
|
-
class ZcPredictor:
    """ Wrapper interface around the Zhicai "ocrwork" framework

    This is a special-purpose helper that strictly does not belong in a general
    library; it is kept here for convenience since it exposes no confidential
    technical detail.

    Subclasses must implement ``forward``; ``__call__`` takes care of reading
    the inputs and feeding them to ``forward`` in batches.
    """

    def __init__(self, config_file, *, gpu=None, batch_size=None, opts=None):
        """ Load the configuration and store the prediction settings

        :param str config_file: either a path ending in ``.yaml`` (joined onto the
            directory named by the ``OCRWORK_DEPLOY`` environment variable, which
            defaults to ``'.'``), or a string holding the YAML configuration itself
        :param gpu: GPU id to use. When neither this nor ``opts['gpu']`` is given,
            ``NvmDevice().get_most_free_gpu_id()`` chooses one.
            Any ``gpu`` entry in the config file is overridden in this interface.
        :param batch_size: maximum number of images recognised at once.
            The config file also has a batch_size, but that one is a training
            parameter and unrelated; for deployment set this one explicitly.
            When left unset, all images passed in one call form a single batch.
        :param dict opts: extra values that override entries of the loaded
            configuration. The caller's dict is copied, never modified.
        :raises TypeError: if ``config_file`` is not a str
        """
        from easydict import EasyDict

        # 1 Load the configuration
        if not isinstance(config_file, str):
            raise TypeError('config_file must be a str (a .yaml path or yaml text)')
        if config_file[-5:].lower() == '.yaml':
            # The deployment directory holding configs/models can be customised
            # through an environment variable
            deploy_path = os.environ.get('OCRWORK_DEPLOY', '.')
            config_file = os.path.join(deploy_path, config_file)
            f = open(config_file, "r")
        else:
            f = io.StringIO(config_file)
        # The handle is closed even if parsing fails.
        with f:
            # NOTE(review): yaml.FullLoader is not meant for untrusted input;
            # only load configuration files from a trusted source.
            prepare_args = EasyDict(list(yaml.load_all(f, Loader=yaml.FullLoader))[0])

        # 2 Special settings (work on a copy so the caller's dict stays untouched)
        opts = dict(opts) if opts else {}
        if gpu is not None:
            opts['gpu'] = gpu
        elif 'gpu' not in opts:
            # No gpu configured anywhere: pick the card with the most free memory
            opts['gpu'] = NvmDevice().get_most_free_gpu_id()
        # The framework's configuration apparently requires the gpu id as a string
        opts['gpu'] = str(opts['gpu'])
        prepare_args.update(opts)

        # 3 Store the components
        self.prepare_args = prepare_args
        self.batch_size = batch_size
        # By default every input is read into the cv2 image format;
        # a PIL-based reader could be plugged in here instead.
        self.transform = lambda x: xlcv.read(x, 1)

    def forward(self, imgs):
        """ Run the model on a list of already-read images

        :param imgs: list of images produced by ``self.transform``
        :raises NotImplementedError: always; subclasses must override this method
        """
        raise NotImplementedError('子类必须实现forward方法')

    def __call__(self, raw_in, *, batch_size=None, progress=False):
        """ Read the inputs, run ``forward`` batch by batch and gather the results

        In the Zhicai framework the samples of a batch do not have to be aligned
        beforehand (more precisely, the framework's augment component handles
        that), so plain Python lists are passed to ``forward``.

        :param raw_in: one sample, or a sized and sliceable collection of samples.
            A ``str`` or ``bytes`` value is always treated as one sample
            (e.g. a single image path), never as a sequence of characters.
        :param batch_size: maximum samples per ``forward`` call; falls back to
            ``self.batch_size`` and then to the number of samples
        :param progress: show a progress bar
        :return: for several inputs, a list ``preds`` with everything ``forward``
            returned, in order; an empty list for empty input.
            For a single (non-collection) input that produced exactly one
            result, that result itself.
            Example given by the original author: ``preds[0]`` is the list of
            detection boxes of image 0, each box a 4*2 integer numpy matrix.
        :raises ValueError: if the resolved batch size is smaller than 1
        """
        # 1 Normalise the input to a collection and measure it
        is_single = isinstance(raw_in, (str, bytes)) or not hasattr(raw_in, '__len__')
        imgs = [raw_in] if is_single else raw_in
        n = len(imgs)
        if n == 0:
            # Nothing to predict; also avoids range() with a zero step below
            return []
        batch_size = first_nonnone([batch_size, self.batch_size, n])
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer')

        # 2 Process chunk by chunk
        preds = []
        with tqdm(desc='forward', total=n, disable=not progress) as t:
            for i in range(0, n, batch_size):
                inputs = imgs[i:i + batch_size]
                preds += self.forward([self.transform(img) for img in inputs])
                t.update(len(inputs))

        # 3 Simplify the result for a single sample
        if is_single and len(preds) == 1:
            return preds[0]
        return preds
|