pyxllib 0.3.96-py3-none-any.whl → 0.3.197-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
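Taken together, the renamed entries show the autogui helpers moving from pyxllib/ext/autogui to a top-level pyxllib/autogui package, while the whole pyxlpr namespace (ai, data, ppocr, ppstructure, ...) is dropped from the wheel. Below is a minimal migration sketch for downstream code; only the module paths listed above are confirmed by this diff, the surrounding code is illustrative:

    # Keep the last release that still ships the pyxlpr subpackage:
    #   pip install "pyxllib==0.3.96"

    # For the relocated autogui module, prefer the 0.3.197 layout and fall back to the old path:
    try:
        from pyxllib.autogui import autogui      # new location in 0.3.197
    except ImportError:
        from pyxllib.ext.autogui import autogui  # location in 0.3.96 and earlier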
pyxlpr/ppocr/tools/program.py
DELETED
@@ -1,596 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import platform
import yaml
import time
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter

from pyxlpr.ppocr.utils.stats import TrainingStats
from pyxlpr.ppocr.utils.save_load import save_model
from pyxlpr.ppocr.utils.utility import print_dict
from pyxlpr.ppocr.utils.logging import get_logger
from pyxlpr.ppocr.utils import profiler
from pyxlpr.ppocr.data import build_dataloader


class ArgsParser(ArgumentParser):
    def __init__(self):
        """ PaddlePaddle's (pp) custom command-line argument parser """

        ''' RawDescriptionHelpFormatter

        formatter_class: overrides the format of the help output. Available options:
        HelpFormatter, ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter, RawTextHelpFormatter
        See the argparse module introduction: https://mp.weixin.qq.com/s/s49awBykc7pFEV4XnFNO6g

        The default is HelpFormatter; this is just an alternative help style provided by argparse.
        With --help the output looks like the normal prompt, no obvious difference.
        Error cases were also tested and showed no visible difference from HelpFormatter, so leave it for now.
        '''
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)

        self.add_argument("-c", "--config", help="configuration file to use")

        # argparse nargs usage: https://docs.python.org/3/library/argparse.html?highlight=argparse%20nargs#nargs
        # '+' means that when -o is used, at least one value must be provided; several values are allowed,
        # but not zero. The values are collected into a list object.
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
        )

    def parse_args(self, argv=None):
        """ Note that parse_args is overridden here """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """ Convert the list-format opt values into a dict
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            k, v = s.split('=')
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


class AttrDict(dict):
    """Single level attribute dict, NOT recursive

    AttrDict is just an ordinary dict class, nothing special
    """

    def __init__(self, **kwargs):
        super(AttrDict, self).__init__()
        super(AttrDict, self).update(kwargs)

    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError("object has no attribute '{}'".format(key))


# A global configuration dict
global_config = AttrDict()

default_config = {'Global': {'debug': False, }}


def load_config(file_path):
    """ Parse the given yaml config file
    The parameters from the config file are merged into the global config, which is also the return value

    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    """
    merge_config(default_config)
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
    return global_config


def merge_config(config):
    """ Recursively merge a config update into the global config

    Merge config into global config.
    Args:
        config (dict): Config to be merged.
    Returns: global config
    """
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in global_config
            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                global_config.keys(), sub_keys[0])
            cur = global_config[sub_keys[0]]
            for idx, sub_key in enumerate(sub_keys[1:]):
                if idx == len(sub_keys) - 2:
                    cur[sub_key] = value
                else:
                    cur = cur[sub_key]


def check_gpu(use_gpu):
    """
    Log error and exit when set use_gpu=true in paddlepaddle
    cpu version.
    """
    err = "Config use_gpu cannot be set as true while you are " \
          "using paddlepaddle cpu version ! \nPlease try: \n" \
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
          "\t2. Set use_gpu as false in config file to run " \
          "model on CPU"

    try:
        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err)
            sys.exit(1)
    except Exception as e:
        pass


def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          vdl_writer=None,
          scaler=None):
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']
    profiler_options = config['profiler_options']

    global_step = 0
    if 'global_step' in pre_best_model_dict:
        global_step = pre_best_model_dict['global_step']
    start_eval_step = 0
    if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        if len(valid_dataloader) == 0:
            logger.info(
                'No Images in eval dataset, evaluation during training will be disabled'
            )
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    if not os.path.exists(save_model_dir):
        os.makedirs(save_model_dir)
    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"
    extra_input = config['Architecture'][
        'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
    try:
        model_type = config['Architecture']['model_type']
    except:
        model_type = None
    algorithm = config['Architecture']['algorithm']

    if 'start_epoch' in best_model_dict:
        start_epoch = best_model_dict['start_epoch']
    else:
        start_epoch = 1

    for epoch in range(start_epoch, epoch_num + 1):
        # the dataloader is rebuilt at every epoch
        train_dataloader = build_dataloader(
            config, 'Train', device, logger, seed=epoch)
        train_reader_cost = 0.0
        train_run_cost = 0.0
        total_samples = 0
        reader_start = time.time()
        max_iter = len(train_dataloader) - 1 if platform.system(
        ) == "Windows" else len(train_dataloader)
        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                model_average = True

            train_start = time.time()
            # use amp
            if scaler:
                with paddle.amp.auto_cast():
                    if model_type == 'table' or extra_input:
                        preds = model(images, data=batch[1:])
                    else:
                        preds = model(images)
            else:
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
                elif model_type == "kie":
                    preds = model(batch)
                else:
                    preds = model(images)
            loss = loss_class(preds, batch)
            avg_loss = loss['loss']

            if scaler:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                avg_loss.backward()
                optimizer.step()
            optimizer.clear_grad()

            train_run_cost += time.time() - train_start
            total_samples += len(images)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if cal_metric_during_train and (model_type != "det"):  # only rec and cls need
                batch = [item.numpy() for item in batch]
                if model_type in ['table', 'kie']:
                    eval_class(preds, batch)
                else:
                    post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            if vdl_writer is not None and dist.get_rank() == 0:
                for k, v in train_stats.get().items():
                    vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
                vdl_writer.add_scalar('TRAIN/lr', lr, global_step)

            if dist.get_rank() == 0 and (
                    (global_step > 0 and global_step % print_batch_step == 0) or
                    (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()
                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
                    epoch, epoch_num, global_step, logs, train_reader_cost /
                    print_batch_step, (train_reader_cost + train_run_cost) /
                    print_batch_step, total_samples,
                    total_samples / (train_reader_cost + train_run_cost))
                logger.info(strs)
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
                if model_average:
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if vdl_writer is not None:
                    for k, v in cur_metric.items():
                        if isinstance(v, (float, int)):
                            vdl_writer.add_scalar('EVAL/{}'.format(k),
                                                  cur_metric[k], global_step)
                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if vdl_writer is not None:
                    vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
                                          best_model_dict[main_indicator],
                                          global_step)
            global_step += 1
            optimizer.clear_grad()
            reader_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and vdl_writer is not None:
        vdl_writer.close()
    return


def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False):
    model.eval()
    with paddle.no_grad():
        total_frame = 0.0
        total_time = 0.0
        pbar = tqdm(
            total=len(valid_dataloader),
            desc='eval model:',
            position=0,
            leave=True)
        max_iter = len(valid_dataloader) - 1 if platform.system(
        ) == "Windows" else len(valid_dataloader)
        for idx, batch in enumerate(valid_dataloader):
            # if idx >= max_iter:
            #     break
            images = batch[0]
            start = time.time()
            if model_type == 'table' or extra_input:
                preds = model(images, data=batch[1:])
            elif model_type == "kie":
                preds = model(batch)
            else:
                preds = model(images)
            batch = [item.numpy() for item in batch]
            # Obtain usable results from post-processing methods
            total_time += time.time() - start
            # Evaluate the results of the current batch
            if model_type in ['table', 'kie']:
                eval_class(preds, batch)
            else:
                post_result = post_process_class(preds, batch[1])
                # print(post_result)
                eval_class(post_result, batch)

            pbar.update(1)
            total_frame += len(images)
        # Get final metric, e.g. acc or hmean
        metric = eval_class.get_metric()

    pbar.close()
    model.train()
    metric['total_frame'] = int(total_frame)
    metric['fps'] = total_frame / total_time
    return metric


def update_center(char_center, post_result, preds):
    result, label = post_result
    feats, logits = preds
    logits = paddle.argmax(logits, axis=-1)
    feats = feats.numpy()
    logits = logits.numpy()

    for idx_sample in range(len(label)):
        if result[idx_sample][0] == label[idx_sample][0]:
            feat = feats[idx_sample]
            logit = logits[idx_sample]
            for idx_time in range(len(logit)):
                index = logit[idx_time]
                if index in char_center.keys():
                    char_center[index][0] = (
                        char_center[index][0] * char_center[index][1] +
                        feat[idx_time]) / (char_center[index][1] + 1)
                    char_center[index][1] += 1
                else:
                    char_center[index] = [feat[idx_time], 1]
    return char_center


def get_center(model, eval_dataloader, post_process_class):
    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
    max_iter = len(eval_dataloader) - 1 if platform.system(
    ) == "Windows" else len(eval_dataloader)
    char_center = dict()
    for idx, batch in enumerate(eval_dataloader):
        if idx >= max_iter:
            break
        images = batch[0]
        start = time.time()
        preds = model(images)

        batch = [item.numpy() for item in batch]
        # Obtain usable results from post-processing methods
        post_result = post_process_class(preds, batch[1])

        # update char_center
        char_center = update_center(char_center, post_result, preds)
        pbar.update(1)

    pbar.close()
    for key in char_center.keys():
        char_center[key] = char_center[key][0]
    return char_center


def preprocess(is_train=False, *, use_visualdl=True, from_dict=None):
    """ Helper that returns the config, device, logger and visualdl objects

    :param use_visualdl: visualdl is only opened when both the config file enables vdl and this flag is True;
        useful when preprocess is needed for the first three objects without opening another vdl writer
    """

    # 1 config
    if from_dict:
        config = global_config
        merge_config(default_config)
        merge_config(from_dict)
        profile_dic = {"profiler_options": None}
    else:
        # global_config/config <-- default_config + config file FLAGS.config + command-line options FLAGS.opt
        FLAGS = ArgsParser().parse_args()
        profiler_options = FLAGS.profiler_options
        config = load_config(FLAGS.config)  # returns the global variable global_config
        # recursively merge the (command-line) overrides into the global config
        merge_config(FLAGS.opt)  # this modifies the global variable, so config is updated as well
        profile_dic = {"profiler_options": FLAGS.profiler_options}
    merge_config(profile_dic)

    ''' PaddlePaddle (pp) handles this a bit differently from detectron2 (d2). d2 configures a very complex set of defaults under the hood.
    pp has almost nothing: just a very small default config, overlaid with the config file, then updated with the command-line options.
    The advantage over d2 is that pp's yaml is a pure yaml config file with no special dependency requirements,
    so the framework only needs a few structural entries and it is easy to extend with custom config values.

    Because of this design, the interfaces below set many default values to make sure things still run when a config entry is not passed.
    '''

    # 2 logger
    if is_train:
        # When is_train is on, a copy of the config is saved as config.yml under save_model_dir,
        # and the log is written to train.log.
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:  # otherwise a logger object is still created, but nothing is written to a file
        log_file = None
    logger = get_logger(name='root', log_file=log_file)

    # 3 device
    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)  # when the gpu is used, check whether cuda is available

    # Check that the algorithm is among the supported components; after studying the framework you can extend it with your own.
    # If needed, look up and study the papers for these algorithms.
    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR'
    ]
    windows_not_support_list = ['PSE']
    if platform.system() == "Windows" and alg in windows_not_support_list:
        logger.warning('{} is not support in Windows now'.format(
            windows_not_support_list))
        sys.exit()

    # The purpose of dist.ParallelEnv().dev_id is not entirely clear; the docs also advise against using this interface directly.
    # In testing it still returned 0 by default even though card 0 was busy, so it is not some smart "find a free GPU" feature.
    # In short, it just sets the device; the details are not worth worrying about for now.
    # It is probably related to distributed training, where this would make a difference.
    # In the default single-card case, the dist.get_world_size() call below is also 1.
    # That call is used to automatically set a 'distributed' flag indicating whether distributed training is used.
    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    # 4 vdl_writer, if visualization is enabled
    # a vdl subdirectory is created under save_model_dir and a vdl_writer object is returned
    if config['Global']['use_visualdl'] and use_visualdl:
        from visualdl import LogWriter
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        os.makedirs(vdl_writer_path, exist_ok=True)
        vdl_writer = LogWriter(logdir=vdl_writer_path)
    else:
        vdl_writer = None

    # print the config content via the logger
    print_dict(config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, vdl_writer