pyxllib 0.0.43__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyxllib/__init__.py +9 -2
- pyxllib/algo/__init__.py +8 -0
- pyxllib/algo/disjoint.py +54 -0
- pyxllib/algo/geo.py +541 -0
- pyxllib/{util/mathlib.py → algo/intervals.py} +172 -36
- pyxllib/algo/matcher.py +389 -0
- pyxllib/algo/newbie.py +166 -0
- pyxllib/algo/pupil.py +629 -0
- pyxllib/algo/shapelylib.py +67 -0
- pyxllib/algo/specialist.py +241 -0
- pyxllib/algo/stat.py +494 -0
- pyxllib/algo/treelib.py +149 -0
- pyxllib/algo/unitlib.py +66 -0
- pyxllib/autogui/__init__.py +5 -0
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/autogui/autogui.py +852 -0
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/virtualkey.py +102 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/__init__.py +1 -11
- pyxllib/cv/expert.py +267 -0
- pyxllib/cv/{imlib.py → imfile.py} +18 -83
- pyxllib/cv/imhash.py +39 -0
- pyxllib/cv/pupil.py +9 -0
- pyxllib/cv/rgbfmt.py +1525 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/cv/trackbartools.py +163 -49
- pyxllib/cv/xlcvlib.py +1040 -0
- pyxllib/cv/xlpillib.py +423 -0
- pyxllib/data/__init__.py +0 -0
- pyxllib/data/echarts.py +240 -0
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/{util/oss2_.py → data/oss.py} +11 -9
- pyxllib/data/pglib.py +1127 -0
- pyxllib/data/sqlite.py +568 -0
- pyxllib/{util → data}/sqllib.py +13 -31
- pyxllib/ext/JLineViewer.py +505 -0
- pyxllib/ext/__init__.py +6 -0
- pyxllib/{util → ext}/demolib.py +119 -35
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +12 -0
- pyxllib/{util/main.py → ext/old.py} +122 -284
- pyxllib/ext/qt.py +449 -0
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/seleniumlib.py +76 -0
- pyxllib/{util/tklib.py → ext/tk.py} +10 -11
- pyxllib/ext/unixlib.py +827 -0
- pyxllib/ext/utools.py +351 -0
- pyxllib/{util/webhooklib.py → ext/webhook.py} +45 -17
- pyxllib/ext/win32lib.py +40 -0
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1105 -0
- pyxllib/file/__init__.py +17 -0
- pyxllib/file/docxlib.py +761 -0
- pyxllib/{util → file}/gitlib.py +40 -27
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +148 -0
- pyxllib/file/newbie.py +10 -0
- pyxllib/file/onenotelib.py +1469 -0
- pyxllib/file/packlib/__init__.py +330 -0
- pyxllib/{util → file/packlib}/zipfile.py +598 -195
- pyxllib/file/pdflib.py +426 -0
- pyxllib/file/pupil.py +185 -0
- pyxllib/file/specialist/__init__.py +685 -0
- pyxllib/{basic/_5_dirlib.py → file/specialist/dirlib.py} +364 -93
- pyxllib/file/specialist/download.py +193 -0
- pyxllib/file/specialist/filelib.py +2829 -0
- pyxllib/file/xlsxlib.py +3131 -0
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/__init__.py +5 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/deprecatedlib.py +233 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/ipyexec.py +253 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +451 -0
- pyxllib/prog/pupil.py +1197 -0
- pyxllib/{sitepackages.py → prog/sitepackages.py} +5 -3
- pyxllib/prog/specialist/__init__.py +391 -0
- pyxllib/prog/specialist/bc.py +203 -0
- pyxllib/prog/specialist/browser.py +497 -0
- pyxllib/prog/specialist/common.py +347 -0
- pyxllib/prog/specialist/datetime.py +199 -0
- pyxllib/prog/specialist/tictoc.py +240 -0
- pyxllib/prog/specialist/xllog.py +180 -0
- pyxllib/prog/xlosenv.py +108 -0
- pyxllib/stdlib/__init__.py +17 -0
- pyxllib/{util → stdlib}/tablepyxl/__init__.py +1 -3
- pyxllib/{util → stdlib}/tablepyxl/style.py +1 -1
- pyxllib/{util → stdlib}/tablepyxl/tablepyxl.py +2 -4
- pyxllib/text/__init__.py +8 -0
- pyxllib/text/ahocorasick.py +39 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +121 -0
- pyxllib/text/jiebalib.py +267 -0
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +922 -0
- pyxllib/text/latex/__init__.py +158 -0
- pyxllib/text/levenshtein.py +303 -0
- pyxllib/text/nestenv.py +1215 -0
- pyxllib/text/newbie.py +300 -0
- pyxllib/text/pupil/__init__.py +8 -0
- pyxllib/text/pupil/common.py +1121 -0
- pyxllib/text/pupil/xlalign.py +326 -0
- pyxllib/text/pycode.py +47 -0
- pyxllib/text/specialist/__init__.py +8 -0
- pyxllib/text/specialist/common.py +112 -0
- pyxllib/text/specialist/ptag.py +186 -0
- pyxllib/text/spellchecker.py +172 -0
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/vbacode.py +17 -0
- pyxllib/text/xmllib.py +747 -0
- pyxllib/xl.py +39 -0
- pyxllib/xlcv.py +17 -0
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.0.43.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +4 -5
- pyxllib/basic/_1_strlib.py +0 -945
- pyxllib/basic/_2_timelib.py +0 -488
- pyxllib/basic/_3_pathlib.py +0 -916
- pyxllib/basic/_4_loglib.py +0 -419
- pyxllib/basic/__init__.py +0 -54
- pyxllib/basic/arrow_.py +0 -250
- pyxllib/basic/chardet_.py +0 -66
- pyxllib/basic/dirlib.py +0 -529
- pyxllib/basic/dprint.py +0 -202
- pyxllib/basic/extension.py +0 -12
- pyxllib/basic/judge.py +0 -31
- pyxllib/basic/log.py +0 -204
- pyxllib/basic/pathlib_.py +0 -705
- pyxllib/basic/pytictoc.py +0 -102
- pyxllib/basic/qiniu_.py +0 -61
- pyxllib/basic/strlib.py +0 -761
- pyxllib/basic/timer.py +0 -132
- pyxllib/cv/cv.py +0 -834
- pyxllib/cv/cvlib/_1_geo.py +0 -543
- pyxllib/cv/cvlib/_2_cvprcs.py +0 -309
- pyxllib/cv/cvlib/_2_imgproc.py +0 -594
- pyxllib/cv/cvlib/_3_pilprcs.py +0 -80
- pyxllib/cv/cvlib/_4_cvimg.py +0 -211
- pyxllib/cv/cvlib/__init__.py +0 -10
- pyxllib/cv/debugtools.py +0 -82
- pyxllib/cv/fitz_.py +0 -300
- pyxllib/cv/installer.py +0 -42
- pyxllib/debug/_0_installer.py +0 -38
- pyxllib/debug/_1_typelib.py +0 -277
- pyxllib/debug/_2_chrome.py +0 -198
- pyxllib/debug/_3_showdir.py +0 -161
- pyxllib/debug/_4_bcompare.py +0 -140
- pyxllib/debug/__init__.py +0 -49
- pyxllib/debug/bcompare.py +0 -132
- pyxllib/debug/chrome.py +0 -198
- pyxllib/debug/installer.py +0 -38
- pyxllib/debug/showdir.py +0 -158
- pyxllib/debug/typelib.py +0 -278
- pyxllib/image/__init__.py +0 -12
- pyxllib/torch/__init__.py +0 -20
- pyxllib/torch/modellib.py +0 -37
- pyxllib/torch/trainlib.py +0 -344
- pyxllib/util/__init__.py +0 -20
- pyxllib/util/aip_.py +0 -141
- pyxllib/util/casiadb.py +0 -59
- pyxllib/util/excellib.py +0 -495
- pyxllib/util/filelib.py +0 -612
- pyxllib/util/jsondata.py +0 -27
- pyxllib/util/jsondata2.py +0 -92
- pyxllib/util/labelmelib.py +0 -139
- pyxllib/util/onepy/__init__.py +0 -29
- pyxllib/util/onepy/onepy.py +0 -574
- pyxllib/util/onepy/onmanager.py +0 -170
- pyxllib/util/pyautogui_.py +0 -219
- pyxllib/util/textlib.py +0 -1305
- pyxllib/util/unorder.py +0 -22
- pyxllib/util/xmllib.py +0 -639
- pyxllib-0.0.43.dist-info/METADATA +0 -39
- pyxllib-0.0.43.dist-info/RECORD +0 -80
- pyxllib-0.0.43.dist-info/top_level.txt +0 -1
- {pyxllib-0.0.43.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
pyxllib/torch/trainlib.py
DELETED
@@ -1,344 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
# @Author : 陈坤泽
|
4
|
-
# @Email : 877362867@qq.com
|
5
|
-
# @Data : 2020/10/18 16:06
|
6
|
-
|
7
|
-
|
8
|
-
"""
|
9
|
-
|
10
|
-
常见的训练操作的代码封装
|
11
|
-
|
12
|
-
"""
|
13
|
-
from abc import ABC
|
14
|
-
|
15
|
-
from pyxllib.debug import *
|
16
|
-
|
17
|
-
import torch
|
18
|
-
from torch import nn, optim
|
19
|
-
import torch.utils.data
|
20
|
-
|
21
|
-
import torchvision
|
22
|
-
from torchvision import transforms
|
23
|
-
|
24
|
-
# 可视化工具
|
25
|
-
try:
    import visdom
except ModuleNotFoundError:
    import subprocess
    import sys

    # Install into the interpreter that is running this module: a bare `pip` found on PATH
    # may belong to a different Python installation. check=True surfaces a failed install
    # directly instead of as a confusing error on the re-import below.
    subprocess.run([sys.executable, '-m', 'pip', 'install', 'visdom'], check=True)
    import visdom
|
30
|
-
|
31
|
-
|
32
|
-
class Visdom(visdom.Visdom, metaclass=SingletonForEveryInitArgs):
    """ visdom.Visdom wrapper that degrades gracefully when no server is reachable.

    ``bool(instance)`` is False when ``{server}:{port}`` could not be reached. In that case the
    parent initializer is never called, so only use plotting methods behind ``if viz: ...``.
    The metaclass presumably makes one instance per distinct set of init args (judging by its
    name) — TODO confirm.

    visdom docs: https://www.yuque.com/code4101/pytorch/visdom
    """

    def __init__(
            self,
            server='http://localhost',
            endpoint='events',
            port=8097,
            base_url='/',
            ipv6=True,
            http_proxy_host=None,
            http_proxy_port=None,
            env='main',
            send=True,
            raise_exceptions=None,
            use_incoming_socket=True,
            log_to_filename=None):
        self.is_connection = is_url_connect(f'{server}:{port}')

        if self.is_connection:
            super().__init__(server, endpoint, port, base_url, ipv6,
                             http_proxy_host, http_proxy_port, env, send,
                             raise_exceptions, use_incoming_socket, log_to_filename)
        else:
            get_xllog().info('visdom server not support')

        # Windows already drawn by this process; used to choose between "create" and "append"
        self.plot_windows = set()

    def __bool__(self):
        return self.is_connection

    def one_batch_images(self, imgs, targets, title='one_batch_image', *, nrow=8, padding=2):
        """ Show a batch of images in window ``title``, with ``str(targets)`` as the caption. """
        self.images(imgs, nrow=nrow, padding=padding,
                    win=title, opts={'title': title, 'caption': str(targets)})

    def _check_plot_win(self, win, update=None):
        """ Resolve the ``update`` mode for window ``win`` and remember that it has been drawn.

        When ``update`` is None, the first plot of ``win`` in this process uses ``None``
        (create/replace) and later plots use ``'append'``. An explicit ``update`` is returned unchanged.
        """
        if update is None:
            if win in self.plot_windows:
                update = 'append'
            else:
                update = None
        self.plot_windows.add(win)
        return update

    def _refine_opts(self, opts=None, *, title=None, legend=None, **kwargs):
        """ Return a copy of ``opts`` with defaults filled in for keys it does not already set.

        Values already present in ``opts`` win. ``title``/``legend`` are only applied when truthy.
        The caller's dict is copied rather than modified in place, so a dict reused across calls
        is not polluted with defaults from an earlier call.
        """
        opts = dict(opts) if opts else {}
        if title and 'title' not in opts:
            opts['title'] = title
        if legend and 'legend' not in opts:
            opts['legend'] = legend
        for k, v in kwargs.items():
            opts.setdefault(k, v)
        return opts

    def loss_line(self, loss_values, epoch, win='loss', *, title=None, update=None):
        """ Plot the per-batch losses of one epoch.

        The ``len(loss_values)`` points are spread evenly over (epoch - 1, epoch] on the x axis,
        which is labelled 'epoch'.
        """
        # Nothing to draw. Return before the window is registered; otherwise the next call
        # would try to 'append' to a window that was never created.
        if len(loss_values) == 0:
            return

        # 1 Is this the first time this window is drawn in the current run?
        if title is None:
            title = win
        update = self._check_plot_win(win, update)

        # 2 Draw the line
        xs = np.linspace(epoch - 1, epoch, num=len(loss_values) + 1)
        self.line(loss_values, xs[1:], win=win, opts={'title': title, 'xlabel': 'epoch'},
                  update=update)

    def plot_line(self, y, x, win, *, opts=None,
                  title=None, legend=None, update=None):
        """ Plot (x, y) into window ``win``; creates the window on first use, then appends. """
        # 1 Is this the first time this window is drawn in the current run?
        if title is None:
            title = win
        update = self._check_plot_win(win, update)

        # 2 Draw the line
        self.line(y, x, win=win, update=update,
                  opts=self._refine_opts(opts, title=title, legend=legend, xlabel='epoch'))
|
113
|
-
|
114
|
-
|
115
|
-
def get_device():
    """Return the torch device to compute on: CUDA when available, otherwise CPU."""
    name = 'cpu'
    if torch.cuda.is_available():
        name = 'cuda'
    return torch.device(name)
|
117
|
-
|
118
|
-
|
119
|
-
class TinyDataset(torch.utils.data.Dataset):
    """ Ultra-light Dataset: one sample per line of a label file.

    Each line is turned into a sample lazily, on access, by ``label_transform``. The transform is
    normally supplied by an outer project-data class.
    """

    def __init__(self, labelfile, label_transform):
        """
        :param labelfile: object with a ``read()`` method returning the whole label text
        :param label_transform: callable mapping one line (str) to one sample
        """
        content = labelfile.read()
        self.labels = content.splitlines()
        self.label_transform = label_transform

    def __len__(self):
        count = len(self.labels)
        return count

    def __getitem__(self, idx):
        line = self.labels[idx]
        return self.label_transform(line)
|
130
|
-
|
131
|
-
|
132
|
-
class TrainerBase:
    """ Shared plumbing for training loops: logging, device, checkpoint I/O and data loaders.

    The default loaders call ``self.datasets(self.data_dir, mode='train' | 'val')``. Subclasses
    that do not pass a ``datasets`` factory must override ``get_train_data``/``get_val_data``.
    """

    def __init__(self, model=None, datasets=None, *,
                 save_dir=None,
                 batch_size=None,
                 optimizer=None, loss_func=None,
                 data_dir=None):
        """
        :param model: the model to train
        :param datasets: dataset factory, called as ``datasets(data_dir, mode=...)``
        :param save_dir: checkpoint directory; defaults to the current working directory
        :param batch_size: DataLoader batch size
        :param optimizer: optimizer instance (stored as-is)
        :param loss_func: loss function (stored as-is)
        :param data_dir: directory handed to ``datasets``; defaults to the current working directory

        ``model`` and ``datasets`` default to None so subclasses can fill them in themselves.
        Previously they were required, and ``ClassificationTrainer``'s
        ``super().__init__(save_dir=...)`` raised a TypeError.
        """
        self.log = get_xllog()
        self.device = get_device()
        # Without an explicit path, fall back to the current working directory
        self.save_dir = Path(save_dir) if save_dir else Path()
        self.data_dir = Path(data_dir) if data_dir else Path()
        self.model = model
        self.datasets = datasets
        # These arguments used to be accepted but silently discarded
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.loss_func = loss_func

    @classmethod
    def loss_values_stat(cls, loss_vales):
        """ Summarize a group of loss values as a one-line message.

        :param loss_vales: losses produced by several batches within one logging interval
            (any non-empty sequence of numbers, including a numpy array)
        :return: 'total_loss=..., mean±std=...±...(max->min)'
        :raises ValueError: if ``loss_vales`` is empty
        """
        # len() rather than truthiness: `not ndarray` raises for arrays with more than one element
        if len(loss_vales) == 0:
            raise ValueError('loss_values_stat: no loss values given')

        data = np.array(loss_vales, dtype=float)
        sum_, mean, std = data.sum(), data.mean(), data.std()
        return f'total_loss={sum_:.3f}, mean±std={mean:.3f}±{std:.3f}({data.max():.3f}->{data.min():.3f})'

    @classmethod
    def sample_size(cls, data):
        """ Size in bytes of a single sample, as measured by ``getasizeof`` on sample 0 of ``data.dataset``. """
        x, label = data.dataset[0]  # the first sample serves as the reference
        return getasizeof(x.numpy()) + getasizeof(label)

    def save_model_state(self, file):
        """ Save the model parameters to ``file`` (resolved relative to ``self.save_dir``).

        Stores ``model.state_dict()`` rather than the model object itself, to keep loading flexible.

        # TODO integrate with Path and add an if_exists parameter
        """
        p = Path(file, root=self.save_dir)
        p.ensure_dir(pathtype='file')
        torch.save(self.model.state_dict(), str(p))

    def load_model_state(self, file):
        """ Load model parameters from ``file`` (resolved relative to ``self.save_dir``) onto ``self.device``. """
        p = Path(file, root=self.save_dir)
        self.model.load_state_dict(torch.load(str(p), map_location=self.device))

    def _make_loader(self, mode):
        """ Build the DataLoader for ``mode`` ('train' or 'val') from the ``datasets`` factory. """
        if self.datasets is None:
            raise NotImplementedError(f'{type(self).__name__}: pass a `datasets` factory '
                                      f'or override get_train_data/get_val_data')
        return torch.utils.data.DataLoader(
            self.datasets(self.data_dir, mode=mode),
            batch_size=self.batch_size, shuffle=True, num_workers=8)

    def get_train_data(self):
        # Previously referenced an undefined `ImageDirectionDataset`; now consistent with get_val_data
        return self._make_loader('train')

    def get_val_data(self):
        return self._make_loader('val')
|
190
|
-
|
191
|
-
|
192
|
-
class ClassificationTrainer(TrainerBase):
    """ Wraps training, evaluation and checkpointing of a PyTorch classification model.

    # TODO make logging optional so it can be switched off
    """

    def __init__(self, model, *, data_dir=None, save_dir=None,
                 batch_size=None, optimizer=None, loss_func=None, datasets=None):
        """
        :param model: the nn.Module to train; it is moved to ``self.device``
        :param data_dir: data directory; defaults to the current working directory
        :param save_dir: checkpoint directory; defaults to the current working directory
        :param batch_size: defaults to 500
        :param optimizer: defaults to ``Adam(lr=0.01)`` over the model parameters
        :param loss_func: defaults to ``CrossEntropyLoss``
        :param datasets: optional dataset factory forwarded to TrainerBase (stored as ``self.datasets``)
        """
        # TrainerBase takes model and datasets positionally; the previous call
        # `super().__init__(save_dir=save_dir)` omitted both and raised a TypeError.
        super().__init__(model, datasets, save_dir=save_dir)
        self.log.info(f'initialize. use_device={self.device}.')

        self.model = model.to(self.device)
        self.optimizer = optimizer if optimizer else optim.Adam(self.model.parameters(), lr=0.01)
        self.loss_func = loss_func if loss_func else nn.CrossEntropyLoss().to(self.device)
        n_params = sum(p.numel() for p in self.model.parameters())
        self.log.info('model parameters size: ' + str(n_params))

        # Without an explicit data path, fall back to the current working directory
        self.data_dir = Path(data_dir) if data_dir else Path()
        self.log.info(f'data_dir={self.data_dir}, save_dir={self.save_dir}')

        self.batch_size = batch_size if batch_size else 500
        self.train_data = self.get_train_data()
        self.val_data = self.get_val_data()
        self.train_data_number, self.test_data_number = len(self.train_data.dataset), len(self.val_data.dataset)
        self.log.info(f'get data, train_data_number={self.train_data_number}(batch={len(self.train_data)}), '
                      f'test_data_number={self.test_data_number}(batch={len(self.val_data)}), batch_size={self.batch_size}')

    def viz_data(self):
        """ Show one batch of training data and one batch of validation data in visdom.

        Does nothing when no visdom server is reachable.

        TODO add some custom formatting parameters
        TODO line breaks via \\n, \\r\\n or <br/> do not work in the caption; could derive a nicer
             layout automatically from nrow and the image width
        """
        viz = Visdom()
        if not viz:
            return

        x, label = next(iter(self.train_data))
        viz.one_batch_images(x, label, 'train data')

        x, label = next(iter(self.val_data))
        viz.one_batch_images(x, label, 'val data')

    def training_one_epoch(self):
        """ Train for one pass over ``self.train_data``.

        :return: list of per-batch loss values (Python floats)
        """
        # 1 Make sure the model is in training mode
        if not self.model.training:
            self.model.train(True)

        # 2 One epoch
        loss_values = []
        for x, label in self.train_data:
            # Batches can be large, so move them to the device one at a time rather than all at once
            x, label = x.to(self.device), label.to(self.device)

            logits = self.model(x)
            if isinstance(logits, tuple):
                # Models such as RNNs may return extra outputs; only the first one is the prediction
                logits = logits[0]
            loss = self.loss_func(logits, label)
            loss_values.append(float(loss))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # 3 Training only reports losses; accuracy is measured separately by calculate_accuracy
        return loss_values

    def calculate_accuracy(self, data, prefix=''):
        """ Measure classification accuracy (and mean batch loss) of the model on ``data``.

        :param data: a DataLoader yielding (x, label) batches
        :param prefix: text put at the start of the log line
        :return: fraction of correctly classified samples (0.0 for an empty dataset)
        """
        # 1 Evaluation mode
        if self.model.training:
            self.model.train(False)

        # 2 No gradients needed: saves memory and time
        with torch.no_grad():
            tt = TicToc()

            losses, correct, number = [], 0, len(data.dataset)
            for x, label in data:
                x, label = x.to(self.device), label.to(self.device)
                logits = self.model(x)
                if isinstance(logits, tuple):
                    logits = logits[0]
                # Store Python floats: np.mean cannot convert CUDA tensors, so a list of
                # loss tensors broke this method whenever the model ran on GPU
                losses.append(float(self.loss_func(logits, label)))
                correct += logits.argmax(dim=1).eq(label).sum().item()  # number of correct predictions
            elapsed_time = tt.tocvalue()
            mean_loss = float(np.mean(losses)) if losses else float('nan')
            accuracy = correct / number if number else 0.0
            info = f'{prefix} accuracy={correct}/{number} ({accuracy:.2%})\t' \
                   f'mean_loss={mean_loss:.3f}\telapsed_time={elapsed_time:.0f}s'
            self.log.info(info)
            return accuracy

    def training(self, epochs=20, *, start_epoch=0,
                 log_interval=1,
                 test_interval=0, save_interval=0):
        """ Main training entry point.

        :param epochs: number of epochs; epochs are numbered from 1 in the output
        :param start_epoch: resume from the saved model of this epoch.
            Requires a matching '{ModelClass} epoch={start_epoch}.pth' file in self.save_dir
        :param log_interval: log the loss statistics every this many epochs
        :param test_interval: measure accuracy every this many epochs (training itself only yields
            the per-batch losses). When 0 and save_interval is set, save_interval is used instead
        :param save_interval: save the model every this many epochs
        """
        # 1 Parameters
        tag = self.model.__class__.__name__
        epoch_time_tag = 'elapsed_time' if log_interval == 1 else f'{log_interval}*epoch_time'
        viz = Visdom()
        if test_interval == 0 and save_interval:
            test_interval = save_interval

        # 2 Optionally resume from a previously saved model
        if start_epoch:
            self.load_model_state(f'{tag} epoch={start_epoch}.pth')

        # 3 Training
        tt = TicToc()
        for epoch in range(start_epoch + 1, epochs + 1):
            loss_values = self.training_one_epoch()
            if viz:
                viz.loss_line(loss_values, epoch, 'train_loss')
            if log_interval and epoch % log_interval == 0:
                msg = self.loss_values_stat(loss_values)
                elapsed_time = tt.tocvalue(restart=True)
                self.log.info(f'epoch={epoch}, {epoch_time_tag}={elapsed_time:.0f}s\t{msg}')
            if test_interval and epoch % test_interval == 0:
                accuracy1 = self.calculate_accuracy(self.train_data, 'train_data')
                accuracy2 = self.calculate_accuracy(self.val_data, '  val_data')
                if viz:
                    viz.plot_line([[accuracy1, accuracy2]], [epoch], 'accuracy', legend=['train', 'val'])
            if save_interval and epoch % save_interval == 0:
                self.save_model_state(f'{tag} epoch={epoch}.pth')
|
322
|
-
|
323
|
-
|
324
|
-
def get_classification_func(model, state_file, func=None):
    """ Factory that loads ``state_file`` into ``model`` once and returns a classifier callable.

    An important reason for going through this factory is to avoid reloading the model on
    every call.

    :param model: the model structure
    :param state_file: file holding the saved parameters (state_dict)
    :param func: post-processor for the model output (not used yet; the original docstring
        implies a default, hence ``None``)
    :return: ``cls_func``, described below
    """
    model.load_state_dict(torch.load(str(state_file), map_location=get_device()))

    def cls_func(x):
        """
        :param x: intended to accept a path, np.ndarray, PIL image, etc., all converted to a batched tensor
        :return: intended to return one result for a single image, and a list of results for a batch
        """
        # Not implemented yet. The former `pass` stub silently returned None, which hid the gap from callers.
        raise NotImplementedError('get_classification_func: cls_func is not implemented yet')

    return cls_func
|
pyxllib/util/__init__.py
DELETED
@@ -1,20 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
# @Author : 陈坤泽
|
4
|
-
# @Email : 877362867@qq.com
|
5
|
-
# @Data : 2018/09/19 19:41
|
6
|
-
|
7
|
-
"""
|
8
|
-
一、一些通用的扩展组件
|
9
|
-
|
10
|
-
二、对标准库或一些第三方库,进行的功能扩展
|
11
|
-
也有可能对一些bug进行了修改
|
12
|
-
|
13
|
-
有些是小的库,直接把源码搬过来了
|
14
|
-
有些是较大的库,仍然要(会自动在需要使用时 pip install)安装
|
15
|
-
|
16
|
-
zipfile: py3.6在windows处理zip,解压中文文件会乱码,要改一个编码
|
17
|
-
这个在py3.8中也没有修复,但是py3.8的zipfile更新了不少内容,有时间我要重新整理过来
|
18
|
-
onepy: 做了些中文注解,其他修改了啥我也忘了~~可能是有改源码功能的
|
19
|
-
pyautogui: 封装扩展了自己的一个 AutoGui 类
|
20
|
-
"""
|
pyxllib/util/aip_.py
DELETED
@@ -1,141 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
# @Author : 陈坤泽
|
4
|
-
# @Email : 877362867@qq.com
|
5
|
-
# @Data : 2020/05/30 21:14
|
6
|
-
|
7
|
-
"""
|
8
|
-
百度人工智能API接口
|
9
|
-
"""
|
10
|
-
|
11
|
-
import subprocess
|
12
|
-
import pandas as pd
|
13
|
-
|
14
|
-
try:
    import aip
except ModuleNotFoundError:
    import sys

    # Install into the interpreter that is running this module: a bare `pip` found on PATH
    # may belong to a different Python installation. check=True surfaces a failed install
    # directly instead of as a confusing error on the re-import below.
    subprocess.run([sys.executable, '-m', 'pip', 'install', 'baidu-aip'], check=True)
    import aip

from pyxllib.basic import Path
from pyxllib.image import get_img_content

# Default location of the pickled Baidu account table, next to this module
AIP_OCR_ACCOUNT_FILE = Path(__file__).parent / 'aipocraccount.pkl'
|
25
|
-
|
26
|
-
|
27
|
-
def create_account_df(file='aipocraccount.pkl'):
    """Build the Baidu OCR account table and write it to ``file``.

    Put your own credentials in the records below, run this once, then delete the plaintext
    values again. The default ``file`` is relative to the current working directory, whereas
    ``AipOcr.init`` reads ``AIP_OCR_ACCOUNT_FILE`` (next to this module) by default.
    """
    columns = ['user', 'APP_ID', 'API_KEY', 'SECRET_KEY']
    records = [
        ['坤泽小号', '16936214', 'aaaaaa', '123456'],
        ['陈坤泽', '16913345', 'bbbbbb', '123456'],
        ['欧龙', '16933485', 'cccccc', '123456'],
        ['韩锦锦', '16933339', 'dddddd', '123456'],
    ]
    table = pd.DataFrame.from_records(records, columns=columns)
    Path(file).write(table)
|
36
|
-
|
37
|
-
|
38
|
-
class AipOcr:
    """
    Wrapper around the Baidu AIP OCR client.

    Goal 1: handle local image files and URLs through one interface.
    Goal 2: the Baidu API cannot read PNGs with a transparent background, so images must be
        converted to RGB first (presumably done by ``get_img_content`` — TODO confirm).

    The client is shared at class level. When the current account's daily quota is used up
    (error 17), ``accurate_text`` switches to the next row of ``account_df``.
    """
    client = None  # shared aip.AipOcr client
    client_id = 0  # row position in account_df of the account currently in use
    account_df = None  # DataFrame with columns user / APP_ID / API_KEY / SECRET_KEY

    @classmethod
    def init(cls, next_client=False, account_file_path=None):
        """ Return the shared client, creating it (or moving to the next account) when needed.

        :param next_client: switch to the next account row, e.g. after the daily quota ran out
        :param account_file_path: account table loaded on first use; defaults to AIP_OCR_ACCOUNT_FILE
        :raises ValueError: when there is no account row left to switch to
        """
        # 1 Account table, loaded once. An explicit account_file_path used to be ignored,
        #   which left account_df as None.
        if cls.account_df is None:
            cls.account_df = Path(account_file_path or AIP_OCR_ACCOUNT_FILE).read()

        # 2 (Re)create the client
        if cls.client is None or next_client:
            t = cls.client_id + int(next_client)
            # Valid row positions are 0 .. len-1. The previous `>` check let t == len through,
            # which raised KeyError instead of this ValueError.
            if t >= len(cls.account_df):
                raise ValueError('今天账号份额都用完啦!Open api daily request limit reached')
            row = cls.account_df.iloc[t]
            cls.client = aip.AipOcr(row.APP_ID, row.API_KEY, row.SECRET_KEY)
            cls.client_id = t
        return cls.client

    @classmethod
    def text(cls, in_, options=None):
        """ Baidu general (standard-accuracy) text recognition.

        The original note says its daily quota (50,000 calls) is more than enough.

        :param in_: image path, web URL, or Image object
        :param options: optional request parameters,
            see https://cloud.baidu.com/doc/OCR/s/pjwvxzmtc
        :return: the recognition result dict

        >> AipOcr.text('0.png')
        >> AipOcr.text(r'http://ksrc2.gaosiedu.com//...',
                       {'language_type': 'ENG'})
        """
        client = cls.init()
        content = get_img_content(in_)
        return client.basicGeneral(content, options)

    @classmethod
    def accurate_text(cls, in_, options=None, *, qps_retry_delay=0.5):
        """ Baidu high-accuracy text recognition.

        :param in_: image path or URL
        :param options: optional request parameters,
            see https://cloud.baidu.com/doc/OCR/s/pjwvxzmtc
        :param qps_retry_delay: seconds to wait before retrying after a QPS-limit error (18)
        :return: the recognition result dict
        :raises ValueError: when every account's daily quota is exhausted

        >> AipOcr.accurate_text('0.png')
        >> AipOcr.accurate_text(r'http://ksrc2.gaosiedu.com//...',
                                {'language_type': 'ENG'})
        """
        import time

        client = cls.init()
        content = get_img_content(in_)
        while True:
            # per the original note, the SDK base64-encodes `content` itself
            t = client.basicAccurate(content, options)
            error_code = t.get('error_code', None)
            if error_code == 17:
                # Daily quota of this account used up: move on to the next account
                client = cls.init(next_client=True)
            elif error_code == 18:
                # {'error_code': 18, 'error_msg': 'Open api qps request limit reached'}: back off
                # briefly instead of immediately hammering the API again
                time.sleep(qps_retry_delay)
            else:
                return t
|
110
|
-
|
111
|
-
|
112
|
-
def demo_aipocr():
    """Manual demo: run standard OCR (English) on an image URL and print the raw result dict."""
    result = AipOcr().text("http://i1.fuimg.com/582188/7b0f9cb22c1770a0.png", {'language_type': 'ENG'})
    print(result)
    # Example output (abridged):
    # {'log_id': 8013455108426397566, 'words_result_num': 64,
    #  'words_result': [{'words': '1 . 4 . cre ated'}, {'words': 'B . shook'},
    #                   {'words': 'C . entered'}, {'words': 'D'}, ...,
    #                   {'words': 'D . happen'}]}
|
137
|
-
|
138
|
-
|
139
|
-
if __name__ == '__main__':
    # Call create_account_df() once beforehand if the account pickle does not exist yet.
    demo_aipocr()
|
pyxllib/util/casiadb.py
DELETED
@@ -1,59 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
# @Author : 陈坤泽
|
4
|
-
# @Email : 877362867@qq.com
|
5
|
-
# @Data : 2020/11/08 09:30
|
6
|
-
|
7
|
-
"""
|
8
|
-
|
9
|
-
Database Home: http://www.nlpr.ia.ac.cn/databases/handwriting/Home.html
|
10
|
-
|
11
|
-
CASIA 在线和离线中文手写数据库的一些数据读取功能
|
12
|
-
|
13
|
-
"""
|
14
|
-
|
15
|
-
from pyxllib.basic import *
|
16
|
-
import numpy as np
|
17
|
-
|
18
|
-
|
19
|
-
def read_from_dgrl(dgrl):
    """ Parse the CASIA (Chinese Academy of Sciences) DGRL format.

    Reference code: https://blog.csdn.net/DaGongJiGuoMaLu09/article/details/107050519
    (heavily simplified and wrapped here).

    TODO consider also offering a labelme-style return format, which would be more general:
        sometimes the whole original page image is needed, and with it each text line could
        simply be marked as a shape instead of being cut out as a sub-image.

    :param dgrl: a dgrl file, or its raw binary content (anything ``XlBytesIO`` accepts)
    :return: [(img0, label0), (img1, label1), ...]. Each img is a writable 2-D ``np.uint8``
        array of shape (h, w), one byte per pixel. Each label is the line text read via
        ``XlBytesIO.readtext`` with NUL characters removed.
    """
    # The input may be bytes or a file
    f = XlBytesIO(dgrl)
    # Header size; the value includes the 4 bytes of this field itself
    header_size = f.unpack('I')
    # Rest of the header; code_length sits in bytes [-4:-2]
    header = f.read(header_size - 4)
    code_length = struct_unpack(header[-4:-2], 'H')  # bytes per character code, usually 2 (GBK)
    # Page size (not used below, but must be consumed) and number of text lines
    height, width, line_num = f.unpack('I' * 3)

    res = []
    for _ in range(line_num):
        # Number of characters in this line, then the label text
        char_num = f.unpack('I')
        label = f.readtext(char_num, code_length=code_length)
        # Remove invisible \x00 characters; otherwise saved text contains hidden characters
        label = label.replace('\x00', '')

        # Position and size of this line
        y, x, h, w = f.unpack('I' * 4)

        # Line bitmap: read all h*w bytes at once instead of unpacking them one by one into a
        # tuple, which also yields uint8 instead of int64. copy() keeps the array writable,
        # since np.frombuffer returns a read-only view.
        bitmap = np.frombuffer(f.read(h * w), dtype=np.uint8).reshape(h, w).copy()

        res.append((bitmap, label))

    return res
|