pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
pyxlpr/ai/xlpaddle.py
DELETED
@@ -1,655 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
# @Author : 陈坤泽
|
4
|
-
# @Email : 877362867@qq.com
|
5
|
-
# @Date : 2021/11/05 09:01
|
6
|
-
|
7
|
-
"""
|
8
|
-
pp是paddlepaddle的缩写
|
9
|
-
"""
|
10
|
-
|
11
|
-
import os
|
12
|
-
import sys
|
13
|
-
import logging
|
14
|
-
import random
|
15
|
-
import shutil
|
16
|
-
import re
|
17
|
-
|
18
|
-
from tqdm import tqdm
|
19
|
-
import numpy as np
|
20
|
-
import pandas as pd
|
21
|
-
import humanfriendly
|
22
|
-
|
23
|
-
import paddle
|
24
|
-
import paddle.inference as paddle_infer
|
25
|
-
|
26
|
-
from pyxllib.algo.pupil import natural_sort
|
27
|
-
from pyxllib.xl import XlPath, browser
|
28
|
-
from pyxllib.xlcv import xlcv
|
29
|
-
from pyxlpr.ai.specialist import ClasEvaluater, show_feature_map
|
30
|
-
|
31
|
-
|
32
|
-
def __1_数据集():
    """Section divider (no-op) for the dataset helpers defined after it.

    Covers SequenceDataset, build_testdata_loader and ImageClasDataset.
    """
|
34
|
-
|
35
|
-
|
36
|
-
class SequenceDataset(paddle.io.Dataset):
    """Expose an indexable sequence of samples (plus optional labels) as a paddle Dataset.

    :param samples: any object supporting ``len()`` and integer indexing
    :param labels: optional object indexed in parallel with ``samples``; when given,
        ``__getitem__`` returns ``(sample, label)`` instead of just the sample
    :param transform: optional callable applied to every sample when it is fetched
    """

    def __init__(self, samples, labels=None, transform=None):
        super().__init__()
        self.samples = samples
        self.labels = labels
        # Deliberately no len(samples) == len(labels) check: labels may be any indexable object.
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        item = self.samples[index]
        item = self.transform(item) if self.transform else item
        if self.labels is None:
            return item
        return item, self.labels[index]
|
57
|
-
|
58
|
-
|
59
|
-
def build_testdata_loader(samples, *, labels=None, transform=None, **kwargs):
    """Build a paddle DataLoader with minimal setup, mainly to simplify inference at deployment.

    :param samples: one of
        - a ``paddle.io.DataLoader``: returned unchanged (other arguments are ignored)
        - a ``paddle.io.Dataset``: wrapped in a new DataLoader
        - a list (or other indexable sequence) of inputs: wrapped in ``SequenceDataset``
          together with ``labels`` and ``transform``
    :param labels: optional labels, only used for the plain-sequence case
    :param transform: optional per-sample callable, only used for the plain-sequence case
    :param kwargs: forwarded to ``paddle.io.DataLoader``
    :return: a ``paddle.io.DataLoader``
    """
    import importlib

    # Crude way to silence the dataloader fetcher warning: flip its module-level flag.
    # Newer paddle releases no longer ship paddle.fluid, where the fetcher presumably lives
    # under paddle.io.dataloader instead (TODO confirm per version). Try both locations and
    # carry on if neither imports, rather than failing the whole call on an ImportError.
    for module_name in ('paddle.fluid.dataloader.fetcher', 'paddle.io.dataloader.fetcher'):
        try:
            fetcher = importlib.import_module(module_name)
        except ImportError:
            continue
        fetcher._WARNING_TO_LOG = False
        break

    if isinstance(samples, paddle.io.DataLoader):
        return samples
    if isinstance(samples, paddle.io.Dataset):
        dataset = samples
    else:
        dataset = SequenceDataset(samples, labels, transform)

    return paddle.io.DataLoader(dataset, **kwargs)
|
76
|
-
|
77
|
-
|
78
|
-
class ImageClasDataset(paddle.io.Dataset):
    """Common image-classification dataset: a list of ``[image_path, class_id]`` samples."""

    def __init__(self, num_classes, samples, ratio=None, *,
                 use_img_augment=False, seed=4101, class_names=None):
        """Build the dataset from an explicit sample list.

        If train and val come from separate folders, their directory structures must match,
        otherwise the automatically assigned class ids will not line up.

        :param num_classes: number of classes; when ``ratio`` is given every label must lie in
            ``range(num_classes)`` because labels index a per-class list
        :param list samples: sample list, each entry is ``[image_path, class_id]``
        :param int|float|list|tuple ratio: portion of each class to keep; default keeps all.
            A positive number ``r`` keeps the interval ``[0, r)``, a negative number keeps
            ``[1 + r, 1)``, and a pair ``(left, right)`` keeps that interval explicitly.
            Each class is shuffled with a generator seeded by ``seed`` first, so the same
            arguments always select the same subset.
        :param use_img_augment: apply ``img_augment`` in ``__getitem__``
        :param seed: seed of the per-class shuffle used together with ``ratio``
        :param list class_names: display names for class ids 0, 1, 2, ...
        """
        super().__init__()

        if ratio is not None:
            if isinstance(ratio, (int, float)):
                # Interval of each class's samples to keep
                left, right = (0, ratio) if ratio > 0 else (1 + ratio, 1)
            else:
                left, right = ratio

            # A private generator yields the same shuffle as seeding the global `random`
            # module, but without clobbering the caller's global random state.
            rng = random.Random(seed)

            # Group sample files by class id
            groups = [[] for _ in range(num_classes)]
            for file, label in samples:
                groups[label].append(file)

            # Keep the selected portion of each class
            samples = []
            for label, files in enumerate(groups):
                n = len(files)
                rng.shuffle(files)
                samples += [[f, label] for f in files[int(left * n):int(right * n)]]

        self.samples = samples
        self.num_classes = num_classes
        self.class_names = class_names
        self.use_img_augment = use_img_augment

    @classmethod
    def from_folder(cls, root, ratio=None, *, class_mode=1, **kwargs):
        """Build an image-classification dataset from a class-per-directory layout.

        :param root: data root directory
        :param ratio: portion of each class to keep, see ``__init__``
        :param class_mode: how classes are derived from directories
            0: typically an unlabeled test set; every image under root gets placeholder label 0
            1: each direct subdirectory of root is a class; images in folders nested inside a
               class directory also belong to that class
            2: every directory returned by ``rglob_dirs`` is its own class (the nesting is
               flattened into a linear list of classes); each class takes the images returned
               by that directory's ``glob_images()``

        Note: in modes 1 and 2 every directory becomes a class, including empty ones, so an
        empty directory produces a class id that simply has no samples.
        """
        root = XlPath(root)

        def run_mode0():
            # Keep the [path, label] sample shape; all images share placeholder label 0.
            return [[f, 0] for f in root.glob_images('**/*')], []

        def run_mode1():
            samples, class_names = [], []
            for i, d in enumerate(sorted(root.glob_dirs())):
                class_names.append(d.name)
                for f in d.glob_images('**/*'):
                    samples.append([f, i])
            return samples, class_names

        def run_mode2():
            samples, class_names = [], []
            for i, d in enumerate(sorted(root.rglob_dirs())):
                class_names.append(d.name)
                for f in d.glob_images():
                    samples.append([f, i])
            return samples, class_names

        func = {0: run_mode0, 1: run_mode1, 2: run_mode2}[class_mode]
        samples, class_names = func()
        # Mode 0 has no class directories but still labels everything 0, i.e. one class;
        # using 0 classes there would make `ratio` fail with an IndexError.
        num_classes = 1 if class_mode == 0 else len(class_names)
        return cls(num_classes, samples, ratio, class_names=class_names, **kwargs)

    @classmethod
    def from_label(cls, label_file, root=None, ratio=None, **kwargs):
        """Build the dataset from a tab-separated label file.

        Each non-empty line is ``<image path><TAB><integer class id>``. Paths are resolved
        relative to ``root``, which defaults to the label file's directory. The class names
        are the distinct class-id strings, naturally sorted.

        NOTE(review): num_classes is the number of distinct ids, so ids are assumed to be
        contiguous from 0; otherwise ``__init__`` raises IndexError when ``ratio`` is used.
        """
        label_file = XlPath(label_file)
        lines = label_file.read_text().splitlines()
        root = label_file.parent if root is None else XlPath(root)

        samples, class_names = [], set()
        for line in lines:
            if not line:
                continue
            path, label = line.split('\t')
            class_names.add(label)
            samples.append([root / path, int(label)])

        class_names = natural_sort(list(class_names))

        return cls(len(class_names), samples, ratio=ratio, class_names=class_names, **kwargs)

    def __len__(self):
        return len(self.samples)

    def save_class_names(self, outfile):
        """Write one class name per line to ``outfile``, creating its directory if needed.

        Without explicit class names, the ids ``0 .. num_classes-1`` are written instead.
        """
        class_names = self.class_names
        if not class_names:
            class_names = list(map(str, range(self.num_classes)))
        outfile = XlPath(outfile)
        os.makedirs(outfile.parent, exist_ok=True)
        outfile.write_text('\n'.join(class_names))

    @classmethod
    def img_augment(cls, img):
        """Built-in default augmentation pipeline; real tasks will usually need their own.

        Assumes ``img`` is an H x W x C array (its shape is unpacked into three values).
        """
        import albumentations as A
        h, w, c = img.shape
        # Random crop size between 70% and 100% of each side
        h = random.randint(int(h * 0.7), h)
        w = random.randint(int(w * 0.7), w)
        transform = A.Compose([
            A.RandomCrop(width=w, height=h, p=0.8),
            A.CoarseDropout(),  # random occluding patches
            A.RandomSunFlare(p=0.1),  # random strong light
            A.RandomShadow(p=0.1),  # random shadows
            A.RGBShift(p=0.1),  # RGB channel jitter
            A.Blur(p=0.1),  # blur
            A.RandomBrightnessContrast(p=0.2),  # random brightness / contrast
        ])
        return transform(image=img)['image']

    @classmethod
    def transform(cls, x):
        """Built-in default preprocessing; real tasks will usually need their own.

        Reads the image, resizes it to 256x256, converts to float32 divided by 255,
        and transposes axes (2, 0, 1), i.e. HWC -> CHW.
        """
        import paddle.vision.transforms.functional as F
        img = xlcv.read(x)
        # Unify the size so samples can be batched (resnet itself does not require a fixed size)
        img = F.resize(img, (256, 256))
        img = np.array(img, dtype='float32') / 255.
        img = img.transpose([2, 0, 1])
        return img

    def __getitem__(self, index):
        file, label = self.samples[index]
        img = xlcv.read(file)
        if self.use_img_augment:
            img = self.img_augment(img)
        img = self.transform(img)
        return img, np.array(label, dtype='int64')
|
235
|
-
|
236
|
-
|
237
|
-
def __2_模型结构():
    """Section divider (no-op) for the model-structure helpers defined after it.

    Covers check_network and model_state_dict_df.
    """
|
239
|
-
|
240
|
-
|
241
|
-
def check_network(x):
    """Print a one-line parameter summary of network ``x``.

    The line contains the total of all ``p.size`` values, followed by ``name=size``
    for every parameter returned by ``x.parameters()``.
    """
    params = list(x.parameters())
    total = sum(p.size for p in params)
    detail = ', '.join(f'{p.name}={p.size}' for p in params)
    print('总参数量:' + str(total) + ' | ' + detail)
|
247
|
-
|
248
|
-
|
249
|
-
def model_state_dict_df(model, *, browser=False):
    """Tabulate the parameters found in ``model.state_dict()``.

    :param model: object exposing ``state_dict()`` (e.g. a paddle Layer)
    :param browser: besides returning the table, open an HTML report in a browser with the
        model repr, the table, the total parameter count and the serialized state_dict size
    :return: pandas DataFrame with one row per distinct parameter object

    See the w211206 weekly report for details.
    """
    # A subset of the ParamBase attributes; the first column holds the state_dict key instead
    columns = ['var_name', 'name', 'shape', 'size', 'dtype', 'trainable', 'stop_gradient']

    state_dict = model.state_dict()  # may list the same tensor under several keys

    rows, seen = [], set()
    for key, value in state_dict.items():
        # With self.b = self.a, both keys appear in state_dict; count each tensor once.
        # The serialized size computed below still uses the full, redundant state_dict.
        if id(value) in seen:
            continue
        seen.add(id(value))
        rows.append([key] + [getattr(value, col, None) for col in columns[1:]])
    df = pd.DataFrame.from_records(rows, columns=columns)

    def html_content(df):
        import html
        import io

        # Escape the repr so characters like '<' or '&' cannot break the page markup
        content = f'<pre>{html.escape(str(model))}' + '</pre><br/>'
        content += df.to_html()
        total_params = sum(df['size'])
        content += f'<br/>总参数量:{total_params}'

        f = io.BytesIO()
        paddle.save(state_dict, f)
        content += f'<br/>文件大小:{humanfriendly.format_size(len(f.getvalue()))}'
        return content

    if browser:
        # The `browser` flag shadows the module-level `browser` helper imported from
        # pyxllib.xl, so `browser.html(...)` would be called on a bool. Import it under
        # another name instead.
        from pyxllib.xl import browser as xl_browser
        xl_browser.html(html_content(df))

    return df
|
295
|
-
|
296
|
-
|
297
|
-
def __3_损失():
    """Section divider (no-op) for loss-related code; this section is currently empty."""
|
299
|
-
|
300
|
-
|
301
|
-
def __4_优化器():
    """Section divider (no-op) for optimizer-related code; this section is currently empty."""
|
303
|
-
|
304
|
-
|
305
|
-
def __5_评价指标():
    """Section divider (no-op) for the evaluation-metric code defined after it.

    Covers ClasAccuracy and VisualAcc.
    """
|
307
|
-
|
308
|
-
|
309
|
-
class ClasAccuracy(paddle.metric.Metric):
    """Classification accuracy that also keeps every prediction for error analysis."""

    def __init__(self, num_classes=None, *, print_mode=0):
        """
        :param num_classes: stored only; the computation does not use it, so it may be omitted
        :param print_mode:
            0: silent
            1: on reset, print the F1 report of the data collected so far
            2: on reset, additionally print the crosstab
        """
        super().__init__()
        self.num_classes = num_classes
        self.total = 0
        self.count = 0
        self.gt = []
        self.pred = []
        self.print_mode = print_mode

    def name(self):
        return 'acc'

    def update(self, x, y):
        """Accumulate one batch.

        :param x: per-class scores; the predicted class is the argmax over axis 1
        :param y: ground-truth class ids, flattened before comparison
        """
        x = x.argmax(axis=1)
        y = y.reshape(-1)
        cmp = (x == y)
        self.count += cmp.sum()
        self.total += len(cmp)
        self.gt += y.tolist()
        self.pred += x.tolist()

    def accumulate(self):
        """Return accuracy over everything since the last reset (0.0 if nothing was collected)."""
        if not self.total:
            # Avoid ZeroDivisionError when accumulate() runs before any update()
            return 0.0
        return self.count / self.total

    def reset(self):
        """Optionally print a report on the collected data, then clear all accumulated state."""
        # reset() may run before any update() (e.g. at the start of an epoch); there is
        # nothing to report then, so skip the evaluator instead of feeding it empty lists.
        if self.print_mode and self.gt:
            a = ClasEvaluater(self.gt, self.pred)
            print(a.f1_score('all'))
            if self.print_mode > 1:
                print(a.crosstab())
        self.count = 0
        self.total = 0
        self.gt = []
        self.pred = []
|
353
|
-
|
354
|
-
|
355
|
-
class VisualAcc(paddle.callbacks.Callback):
    """Callback that logs the ``acc`` value to VisualDL, using separate train and eval runs."""

    def __init__(self, logdir, experimental_name, *, reset=False, save_model_with_input=None):
        """
        :param logdir: root directory of the logs
        :param experimental_name: experiment name; the writers use the subdirectories
            ``<name>_train`` and ``<name>_val``
        :param reset: delete this experiment's existing log directories first
        :param save_model_with_input: stored only; the model structure is not saved by default
        """
        from pyxllib.prog.pupil import check_install_package
        check_install_package('visualdl')
        from visualdl import LogWriter

        super().__init__()

        def open_writer(suffix):
            # The odd-looking suffixes make each experiment's train run sort before its eval run
            run_dir = XlPath(logdir) / (experimental_name + suffix)
            if reset and run_dir.exists():
                shutil.rmtree(run_dir)
            return LogWriter(logdir=str(run_dir))

        self.write = open_writer('_train')  # train-side writer (attribute name kept as-is)
        self.eval_writer = open_writer('_val')
        self.eval_times = 0

        self.save_model_with_input = save_model_with_input

    def on_epoch_end(self, epoch, logs=None):
        writer = self.write
        writer.add_scalar('acc', step=epoch, value=logs['acc'])
        writer.flush()

    def on_eval_end(self, logs=None):
        step = self.eval_times
        self.eval_writer.add_scalar('acc', step=step, value=logs['acc'])
        self.eval_writer.flush()
        self.eval_times = step + 1
|
387
|
-
|
388
|
-
|
389
|
-
def __6_集成():
    """Section divider (no-op) for the integration code (XlModel) defined after it."""
|
391
|
-
|
392
|
-
|
393
|
-
class XlModel(paddle.Model):
    """paddle.Model extended with a save directory, stored datasets and default callbacks."""

    def __init__(self, network, **kwargs):
        """
        :param network: the network to wrap
        :param kwargs: forwarded to ``paddle.Model.__init__``
        """
        super().__init__(network, **kwargs)
        self.save_dir = None
        self.train_data = None
        self.eval_data = None
        self.test_data = None
        self.callbacks = []

    def get_save_dir(self):
        """Return a concrete output directory.

        ``self.save_dir`` and ``self.get_save_dir()`` serve different purposes:
        ``self.save_dir`` is the raw setting and may be None (unset), in which case some
        code paths skip writing files by default. ``get_save_dir()`` is for code paths that
        were explicitly asked to write files; it falls back to the current directory.
        """
        return XlPath('.') if self.save_dir is None else self.save_dir

    def set_save_dir(self, save_dir):
        """Set and create the directory for saved models and other outputs.

        :param save_dir: output directory; may be left unset when nothing should be saved.
            File-writing helpers still fall back to the current directory when it is unset.
        """
        self.save_dir = XlPath(save_dir)
        os.makedirs(self.save_dir, exist_ok=True)

    def set_dataset(self, train_data=None, eval_data=None, test_data=None):
        """Store datasets for later use; arguments left as None keep the stored value.

        If the train set is an ImageClasDataset, its class names are written to
        ``class_names.txt`` in ``get_save_dir()``.
        """
        # Compare with None instead of relying on truthiness: bool(dataset) calls __len__,
        # so a dataset that happens to be empty would otherwise be silently ignored.
        if train_data is not None:
            self.train_data = train_data
            if isinstance(train_data, ImageClasDataset):
                self.train_data.save_class_names(self.get_save_dir() / 'class_names.txt')
        if eval_data is not None:
            self.eval_data = eval_data
        if test_data is not None:
            self.test_data = test_data
            # TODO: automatic test-set handling could be added, but it depends on the task
            #  (classification vs detection, ...), which would mean re-implementing
            #  frameworks like PaddleDet; probably not worth it.

    def try_load_params(self, relpath='final.pdparams'):
        """Load network weights from ``get_save_dir() / relpath`` if that file exists.

        TODO: better training resume (e.g. restoring learning-rate state); reloading the
        weights and retraining works acceptably for now.
        """
        pretrained_model = self.get_save_dir() / relpath
        if pretrained_model.is_file():
            self.network.load_dict(paddle.load(str(pretrained_model)))

    def prepare_clas_task(self, optimizer=None, loss=None, metrics=None, amp_configs=None,
                          use_visualdl=None):
        """Prepare the model with default optimizer, loss and metric for classification.

        TODO: these defaults are not necessarily the most general; more study needed.

        :param use_visualdl: whether to log to visualdl.
            A str is used as the experiment name. Otherwise each experiment gets the next
            number: e0001, e0002, e0003, ...
            Defaults to enabled when ``save_dir`` is set.
        """
        from paddle.optimizer import Momentum
        from paddle.regularizer import L2Decay

        if optimizer is None:
            optimizer = Momentum(learning_rate=0.01,
                                 momentum=0.9,
                                 weight_decay=L2Decay(1e-4),
                                 parameters=self.network.parameters())
        if loss is None:
            loss = paddle.nn.CrossEntropyLoss()
        if metrics is None:
            metrics = ClasAccuracy(print_mode=2)  # accuracy metric that can also print a crosstab

        self.prepare(optimizer, loss, metrics, amp_configs)

        # Visualization defaults to on when save_dir is set
        if use_visualdl is None and self.save_dir is not None:
            use_visualdl = True

        if use_visualdl:
            log_root = self.get_save_dir() / 'visualdl'
            if not isinstance(use_visualdl, str):
                # VisualAcc creates `<name>_train` / `<name>_val` under log_root, so existing
                # experiment numbers must be looked up there. Scanning the save dir itself
                # never matched anything and every run reused e0001.
                existing = log_root.glob_dirs() if log_root.is_dir() else []
                num = max([int(re.search(r'\d+', x.stem).group())
                           for x in existing
                           if re.match(r'e\d+_', x.stem)], default=0) + 1
                use_visualdl = f'e{num:04}'
            self.callbacks.append(VisualAcc(log_root, use_visualdl))

    def train(self,
              epochs=1,
              batch_size=1,
              eval_freq=1000,  # evaluate every N epochs; a large value effectively disables it (no best-model saving by metric anyway)
              log_freq=1000,  # within an epoch, log every N steps
              save_freq=1000,  # save every N epochs; a large value means only `final` is saved
              verbose=2,
              drop_last=False,
              shuffle=True,
              num_workers=0,
              callbacks=None,
              accumulate_grad_batches=1,
              num_iters=None,
              ):
        """Wrapper around ``paddle.Model.fit`` using the stored datasets and save_dir.

        Simplifies the surrounding setup and changes some defaults to match typical use.
        The callbacks passed in are combined with ``self.callbacks``; the caller's list
        is not modified.
        """
        train_data = self.train_data
        eval_data = self.eval_data

        # Build a new list; extending the caller's list in place would keep appending
        # self.callbacks to it on every call.
        callbacks = list(callbacks or []) + self.callbacks

        super().fit(train_data, eval_data, batch_size, epochs, eval_freq, log_freq,
                    self.save_dir, save_freq, verbose, drop_last, shuffle, num_workers,
                    callbacks, accumulate_grad_batches, num_iters)

        # Run one more eval when: there is eval data, evaluation is not every epoch, and the
        # epoch count is a multiple of eval_freq (the last epoch closes an eval cycle).
        # Per the original author, paddle.Model.fit does not evaluate in that case.
        # Use the same batch size / workers as training instead of evaluate()'s defaults.
        if eval_data and eval_freq != 1 and (epochs % eval_freq == 0):
            self.evaluate(eval_data, batch_size=batch_size, num_workers=num_workers)

        # TODO: a separate metric report? evaluate() already covers it.

    def save_static_network(self, *, data_shape=None):
        """Export a static-graph model for deployment under ``get_save_dir()/infer``.

        :param data_shape: example input shape used to trace the network,
            default ``[1, 3, 256, 256]`` (float32)
        """
        if data_shape is None:
            data_shape = [1, 3, 256, 256]
            # TODO: the shape could be derived from train_data / eval_data
        data = paddle.zeros(data_shape, dtype='float32')
        infer_prefix = self.get_save_dir() / 'infer/inference'
        # paddle.jit.save treats this path as a file-name prefix, so only its parent has to
        # exist. Creating the prefix itself as a directory just left an empty folder behind.
        infer_prefix.parent.mkdir(parents=True, exist_ok=True)
        paddle.jit.save(paddle.jit.to_static(self.network), infer_prefix.as_posix(), [data])
|
528
|
-
|
529
|
-
|
530
|
-
def __7_部署():
    """ Section marker for the deployment ('部署') code that follows.

    It has no behavior: calling it does nothing and returns None.
    """
|
532
|
-
|
533
|
-
|
534
|
-
class ImageClasPredictor:
    """ Predictor for the image-classification framework.

    Wraps a callable ``model`` that maps a batch of inputs to a logits tensor.
    The callable is either a dynamic-graph network or a static-graph Paddle Inference
    predictor adapted to the same calling convention.
    """

    def __init__(self, model, *, transform=None, class_names=None):
        """
        :param model: callable mapping an input batch to logits
        :param transform: per-sample transform handed to the test data loader
        :param class_names: if given, predicted indices are mapped to these readable class names
        """
        self.model = model
        self.transform = transform
        self.class_names = class_names

    @classmethod
    def from_dynamic(cls, model, params_file=None, **kwargs):
        """ Build from a dynamic-graph network.

        Optionally loads parameters from ``params_file``, then switches the network to eval mode.
        """
        if params_file:
            model.load_dict(paddle.load(params_file))
        model.eval()
        return cls(model, **kwargs)

    @classmethod
    def from_static(cls, pdmodel, pdiparams, **kwargs):
        """ Build from an exported static-graph model (``.pdmodel`` + ``.pdiparams``) via Paddle Inference.

        The GPU is used when the current paddle device is a ``gpu:<id>`` device.
        """
        # Create the inference config and adjust it as needed
        config = paddle_infer.Config(pdmodel, pdiparams)
        device = paddle.get_device()

        if device.startswith('gpu'):
            config.enable_use_gpu(0, int(device.split(':')[1]))

        # Create the predictor from the config
        predictor = paddle_infer.create_predictor(config)
        # Tensor names are fixed for a given predictor, so look them up only once
        input_name = predictor.get_input_names()[0]
        output_name = predictor.get_output_names()[0]

        def model(x):
            """ Adapt the static predictor to the dynamic-graph calling convention.

            The input tensor ``x`` is converted to a numpy array for the inference engine,
            and the numpy output is converted back to a paddle tensor.

            TODO: there may be a better way to organize the dynamic/static deployment code.
            """
            x_handle = predictor.get_input_handle(input_name)
            x_handle.copy_from_cpu(x.numpy())
            predictor.run()
            output_handle = predictor.get_output_handle(output_name)
            output_data = output_handle.copy_to_cpu()  # numpy.ndarray
            return paddle.to_tensor(output_data)

        return cls(model, **kwargs)

    @classmethod
    def from_modeldir(cls, root, *, dynamic_net=None, **kwargs):
        """ Initialize a deployment model from a fixed training-directory layout.

        Expected files under ``root``:
        - ``class_names.txt``: one class name per line (required; it determines the class count);
        - ``final.pdparams``: used when ``dynamic_net`` is given;
        - ``infer/inference.pdmodel`` / ``infer/inference.pdiparams``: used otherwise.

        :param dynamic_net: dynamic-graph network class; when given, it is instantiated with
            ``num_classes=len(class_names)`` and loaded from ``final.pdparams``

        Without a ``class_names.txt``, use the lower-level from_dynamic / from_static instead.
        """
        root = XlPath(root)
        class_names_file = root / 'class_names.txt'
        assert class_names_file.is_file(), f'{class_names_file} 必须要有类别昵称配置文件,才知道类别数'
        class_names = class_names_file.read_text().splitlines()

        # Use `cls` (not a hard-coded class name) so subclasses get instances of themselves
        if dynamic_net:
            clas = cls.from_dynamic(dynamic_net(num_classes=len(class_names)),
                                    str(root / 'final.pdparams'),
                                    class_names=class_names,
                                    **kwargs)
        else:
            clas = cls.from_static(str(root / 'infer/inference.pdmodel'),
                                   str(root / 'infer/inference.pdiparams'),
                                   class_names=class_names,
                                   **kwargs)

        return clas

    def pred_batch(self, samples, batch_size=None, *, return_mode=0, print_mode=0):
        """ Batch prediction (use `pred` for a single sample).

        :param samples: data to predict; a list-like, or a Dataset / DataLoader
        :param batch_size: by default all of ``samples`` form a single batch; pass a value
            to split large inputs into several batches
        :param return_mode: 0 returns the predicted class per sample (its name when
            ``class_names`` is set, otherwise the index); 1 returns each sample's per-class
            softmax probabilities rounded to 4 decimals
        :param print_mode: 0 runs silently, 1 shows a progress bar
        :return: a list with one entry per sample; ``[]`` for empty input
        :raises ValueError: unsupported ``return_mode`` (checked before any inference runs)
        """
        if return_mode not in (0, 1):
            raise ValueError(f'unsupported return_mode: {return_mode!r} (expected 0 or 1)')

        if not batch_size:
            batch_size = len(samples)
        if not batch_size:
            # Empty input: nothing to predict (and a 0-sized batch would break the loader)
            return []

        data_loader = build_testdata_loader(samples, transform=self.transform, batch_size=batch_size)
        logits = [self.model(inputs) for inputs in tqdm(data_loader, desc='预测:', disable=not print_mode)]
        if not logits:  # the loader produced no batches
            return []
        logits = paddle.concat(logits, axis=0)

        if return_mode == 0:
            idx = logits.argmax(1).tolist()
            if self.class_names:
                idx = [self.class_names[i] for i in idx]
            return idx

        import paddle.nn.functional as F
        prob = F.softmax(logits, axis=1).tolist()
        return [[round(p, 4) for p in row] for row in prob]  # 4 decimals are enough

    def __call__(self, *args, **kwargs):
        """ Alias for `pred_batch`. """
        return self.pred_batch(*args, **kwargs)

    def pred(self, img, *args, **kwargs):
        """ Predict a single sample; extra arguments are forwarded to `pred_batch`. """
        return self.pred_batch([img], *args, **kwargs)[0]
|