pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
@@ -1,372 +0,0 @@
|
|
1
|
-
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
import numpy as np
|
15
|
-
import os
|
16
|
-
import random
|
17
|
-
import traceback
|
18
|
-
from paddle.io import Dataset
|
19
|
-
from .imaug import transform, create_operators
|
20
|
-
|
21
|
-
import json
|
22
|
-
|
23
|
-
from pyxllib.xl import run_once, XlPath
|
24
|
-
|
25
|
-
# Public names of this module; SimpleDataSetExt is an intermediate base (its get_image_info
# raises NotImplementedError) and is not exported.
__all__ = ['SimpleDataSet', 'XlSimpleDataSet']
|
26
|
-
|
27
|
-
|
28
|
-
class SimpleDataSet(Dataset):
    """ PaddleOCR's native base dataset format.

    All per-image annotations are packed into label txt file(s); each line holds the image
    path (relative to ``data_dir``), the configured delimiter, and the json label.
    """

    def __init__(self, config, mode, logger, seed=None):
        """
        :param config: full config dict; reads ``config['Global']`` plus
            ``config[mode]['dataset']`` and ``config[mode]['loader']``. Not modified.
        :param mode: name of the config section to use; ``mode.lower()`` is stored as
            ``self.mode`` ("train" enables sampling and shuffling)
        :param logger: logger used for progress and parse-error messages
        :param seed: seed passed to ``random.seed`` before sampling and shuffling
        """
        super(SimpleDataSet, self).__init__()
        self.logger = logger
        self.mode = mode.lower()

        # the global config section is forwarded to the operator factory
        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']

        # separator between the image path and the json label in a label-file line
        self.delimiter = dataset_config.get('delimiter', '\t')

        # Read instead of pop(): popping mutated the caller's config, so building a
        # second dataset from the same config raised KeyError.
        label_file_list = dataset_config['label_file_list']
        # A single path string is allowed; normalise it so len() counts files rather
        # than characters (otherwise the ratio_list length check below fails).
        if isinstance(label_file_list, str):
            label_file_list = [label_file_list]
        data_source_num = len(label_file_list)
        ratio_list = dataset_config.get("ratio_list", [1.0])
        if isinstance(ratio_list, (float, int)):
            ratio_list = [float(ratio_list)] * int(data_source_num)

        assert len(
            ratio_list
        ) == data_source_num, "The length of ratio_list should be the same as the file_list."
        self.data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        self.seed = seed
        logger.info("Initialize indexs of datasets:%s" % label_file_list)
        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
        self.data_idx_order_list = list(range(len(self.data_lines)))
        # the DataLoader may shuffle again; shuffling twice is harmless
        if self.mode == "train" and self.do_shuffle:
            self.shuffle_data_random()
        # operators that decode the image bytes and apply augmentation
        self.ops = create_operators(dataset_config['transforms'], global_config)

    def get_image_info_list(self, file_list, ratio_list):
        """ Read label lines from several files, sampling each file by its ratio.

        :param file_list: one path or a list of paths
        :param ratio_list: per-file fraction of lines to keep
        :return: list of raw (bytes) lines
        """
        if isinstance(file_list, str):
            file_list = [file_list]
        data_lines = []
        for idx, file in enumerate(file_list):
            with open(file, "rb") as f:
                lines = f.readlines()
                # in train mode lines are always sampled (which also shuffles them);
                # otherwise only when the ratio is below 1
                if self.mode == "train" or ratio_list[idx] < 1.0:
                    random.seed(self.seed)
                    lines = random.sample(lines,
                                          round(len(lines) * ratio_list[idx]))
                data_lines.extend(lines)
        return data_lines

    def shuffle_data_random(self):
        """ Shuffle ``self.data_lines`` in place after seeding the global RNG with self.seed. """
        random.seed(self.seed)
        random.shuffle(self.data_lines)
        return

    def _parse_line(self, data_line):
        """ Split one decoded label line into ``{'img_path': ..., 'label': ...}``. """
        # strip both CR and LF so CRLF label files also yield a clean label string
        substr = data_line.rstrip("\r\n").split(self.delimiter)
        file_name = substr[0]
        label = substr[1]
        return {'img_path': os.path.join(self.data_dir, file_name), 'label': label}

    def get_ext_data(self):
        """ Randomly load extra samples for ops that combine several images.

        The count comes from the first op exposing ``ext_data_num`` (0 if none). Extra
        samples only go through the first two ops (presumably image decoding and label
        parsing — depends on the configured transforms). Samples whose image is missing,
        whose loading fails, or whose ``polys`` are not 4-point are skipped.

        NOTE(review): loops until enough valid samples are found; if none exist this
        never terminates.
        """
        ext_data_num = 0
        for op in self.ops:
            if hasattr(op, 'ext_data_num'):
                ext_data_num = getattr(op, 'ext_data_num')
                break
        load_data_ops = self.ops[:2]
        ext_data = []

        while len(ext_data) < ext_data_num:
            # draw a random sample
            file_idx = self.data_idx_order_list[np.random.randint(len(self))]
            data = self._parse_line(self.data_lines[file_idx].decode('utf-8'))
            if not os.path.exists(data['img_path']):
                continue
            with open(data['img_path'], 'rb') as f:
                data['image'] = f.read()
            data = transform(data, load_data_ops)

            if data is None or data['polys'].shape[1] != 4:
                continue
            ext_data.append(data)
        return ext_data

    def _load_sample(self, idx):
        """ Load and transform sample ``idx``; log and return None if anything fails. """
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        try:
            data_line = data_line.decode('utf-8')
            data = self._parse_line(data_line)
            img_path = data['img_path']
            if not os.path.exists(img_path):
                raise Exception("{} does not exist!".format(img_path))
            # Every file is read as raw bytes; decoding (e.g. DecodeImage) is left to
            # the configured transforms.
            with open(img_path, 'rb') as f:
                data['image'] = f.read()
            data['ext_data'] = self.get_ext_data()
            if data['label'] == '[]':  # image without any text
                data['label'] = '[{"transcription": "###", "points": [[0, 0], [1, 0], [1, 1], [0, 1]]}]'
            return transform(data, self.ops)
        except Exception as e:
            self.logger.error(
                "When parsing line {}, error happened with msg: {}".format(
                    data_line, e))
            return None

    def __getitem__(self, idx):
        """ Return transformed sample ``idx``, substituting another sample on failure.

        Train mode substitutes a random index; other modes use the next index so that
        repeated evaluations see identical data.

        :raises RuntimeError: if ``len(self)`` consecutive attempts all fail
        """
        # Iterate rather than recurse, so a long run of broken samples cannot overflow
        # the Python stack.
        max_attempts = max(len(self), 1)
        for _ in range(max_attempts):
            outs = self._load_sample(idx)
            if outs is not None:
                return outs
            if self.mode == "train":
                idx = np.random.randint(len(self))
            else:
                # during evaluation, fix the idx to get the same results across evaluations
                idx = (idx + 1) % len(self)
        raise RuntimeError(
            "No valid sample found after {} attempts".format(max_attempts))

    def __len__(self):
        return len(self.data_idx_order_list)
|
163
|
-
|
164
|
-
|
165
|
-
class SimpleDataSetExt(SimpleDataSet):
    """ Custom dataset base that is initialised from annotation locations, not label files.

    ``__init__`` and ``get_image_info_list`` follow their own config convention: the
    dataset section provides ``data_list``, and each entry is turned into label lines by
    ``get_image_info``, which subclasses must implement.
    """

    def __init__(self, config, mode, logger, seed=None):
        # SimpleDataSet.__init__ is deliberately not called: it requires label_file_list.
        self.logger = logger
        self.mode = mode.lower()

        # the global config section is forwarded to the operator factory
        global_config = config['Global']
        mode_config = config[mode]
        dataset_config = mode_config['dataset']
        loader_config = mode_config['loader']

        # separator between the image path and the json label in a label line
        self.delimiter = dataset_config.get('delimiter', '\t')
        self.data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']
        self.seed = seed

        data_list = dataset_config.get('data_list', [])
        logger.info("Initialize indexs of datasets:%s" % data_list)
        self.data_lines = self.get_image_info_list(data_list)
        self.data_idx_order_list = list(range(len(self.data_lines)))

        # the DataLoader may shuffle again; doing it twice is harmless
        should_shuffle = self.mode == "train" and self.do_shuffle
        if should_shuffle:
            self.shuffle_data_random()
        # operators that decode the image bytes and apply augmentation
        self.ops = create_operators(dataset_config['transforms'], global_config)

    def get_image_info_list(self, data_list):
        """ Collect the label lines of every ``data_list`` entry.

        :param data_list: a single config dict or a list of them
        :return: list of utf-8 encoded lines; each line holds an image path relative to
            data_dir and its json label, as produced by ``get_image_info``
        """
        entries = [data_list] if isinstance(data_list, dict) else data_list
        return [line.encode('utf8')
                for cfg in entries
                for line in self.get_image_info(cfg)]

    def get_image_info(self, cfg):
        """ Convert one ``data_list`` entry into str label lines; subclasses implement this. """
        raise NotImplementedError
|
215
|
-
|
216
|
-
|
217
|
-
class XlSimpleDataSet(SimpleDataSetExt):
    """ Dataset that can be configured directly with raw annotation formats such as ICDAR 2015.

    Each ``data_list`` entry is a dict whose ``type`` selects a ``from_<type>`` loader; the
    remaining keys are passed to that loader as keyword arguments.
    """

    def get_image_info(self, cfg):
        """ Dispatch one ``data_list`` entry to its ``from_<type>`` loader.

        :param cfg: dict with a ``type`` key plus the loader's keyword arguments. It is
            copied, not modified, so the same config can build the dataset again.
        :raises TypeError: if no ``from_<type>`` loader exists
        """
        # For XlSimpleDataSet, 'type' is a required field
        kwargs = dict(cfg)
        t = kwargs.pop('type', '')
        func = getattr(self, 'from_' + t, None)
        if func:
            return func(**kwargs)
        else:
            raise TypeError('指定数据集格式不存在')

    def __repr__(self):
        """ Used by run_once: this string identifies repeated calls with the same arguments. """
        args = [self.data_dir, self.seed]
        return 'XlSimpleDataSet(' + ','.join(map(str, args)) + ')'

    def _select_ratio(self, data, ratio):
        """ Shuffle ``data`` with a fixed seed and return the part selected by ``ratio``.

        Convenient for data that has no physical train/validation split.

        :param data: list; shuffled in place
        :param ratio:
            float > 0: keep the first ``ratio`` fraction
            float < 0: keep the last ``-ratio`` fraction
            [left, right] (list or tuple): keep the fraction range [left, right)
            anything else (e.g. 0.0 or an int): keep everything
        :return: the selected items
        """
        # Fixed seed (not self.seed) so the split is stable. A private Random instance
        # gives the same order as random.seed(4101) + random.shuffle without resetting
        # the global random state.
        random.Random(4101).shuffle(data)
        n = len(data)
        if isinstance(ratio, float):
            if ratio > 0:
                data = data[:int(ratio * n)]
            elif ratio < 0:
                data = data[int(ratio * n):]
        elif isinstance(ratio, (list, tuple)):
            left, right = ratio  # this ratio is applied to each category's templates separately
            data = data[int(left * n):int(right * n)]
        return data

    @run_once('str')  # the annotation format is fixed, so run_once avoids regenerating it
    def from_icdar2015(self, subdir, label_dir, ratio=None):
        """ Load an ICDAR-2015 style detection set.

        :param subdir: image directory, relative to data_dir
        :param label_dir: directory of ``gt_<image stem>.txt`` files, relative to data_dir
        :param ratio: optional split selector, see ``_select_ratio``
        :return: list of label lines (image path and json label, tab-separated)
        """
        data_dir = XlPath(self.data_dir)
        subdir = data_dir / subdir
        label_dir = data_dir / label_dir

        data_lines = []

        txt_files = list(label_dir.glob('*.txt'))
        if ratio is not None:
            txt_files = self._select_ratio(txt_files, ratio)

        def label2json(content):
            """ Convert one image's annotation text into the paddle json label list.

            Each non-empty line is ``x1,y1,...,x4,y4,transcription``. Only the first 8
            commas are split on, because the transcription itself may contain commas.
            """
            label = []
            for line in content.splitlines():
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                tmp = line.split(',', 8)
                coords = [int(t) for t in tmp[:8]]
                points = [coords[i:i + 2] for i in range(0, len(coords), 2)]
                label.append({"transcription": tmp[8], "points": points})
            return label

        for f in txt_files:
            # stem[3:] drops the extra 'gt_' prefix of the annotation file name
            impath = (subdir / (f.stem[3:] + '.jpg')).relative_to(data_dir).as_posix()
            # ICDAR annotation files are partly utf8 and partly utf-8-sig; encoding=None
            # uses XlPath's automatic encoding detection
            json_label = label2json(f.read_text(encoding=None))
            label = json.dumps(json_label, ensure_ascii=False)
            data_lines.append(('\t'.join([impath, label])))

        return data_lines

    @run_once('str')
    def from_refineAgree(self, subdir, json_dir, label_file):
        """ Load samples listed in ``label_file`` from labelme jsons; only needs the root paths.

        :param subdir: image directory, relative to data_dir (images are ``<name>.jpg``)
        :param json_dir: labelme json directory, relative to data_dir (``<name>.json``)
        :param label_file: text file with one sample name per line, relative to data_dir
        :return: list of label lines (image path and json label, tab-separated)
        """
        from pyxlpr.data.labelme import LabelmeDict

        data_dir = XlPath(self.data_dir)
        subdir = data_dir / subdir
        json_dir = data_dir / json_dir
        label_file = data_dir / label_file

        def labelme2json(d):
            """ Convert a labelme json into the paddle json label, keeping only printed text. """
            label = []
            shapes = d['shapes']
            for sp in shapes:
                msg = json.loads(sp['label'])
                if msg['type'] != '印刷体':
                    continue
                result = {"transcription": msg['text'],
                          "points": LabelmeDict.to_quad_pts(sp)}
                label.append(result)
            return label

        data_lines = []
        sample_list = label_file.read_text().splitlines()
        for x in sample_list:
            if not x: continue  # skip empty lines
            impath = (subdir / (x + '.jpg')).relative_to(data_dir).as_posix()
            f = json_dir / (x + '.json')
            json_label = labelme2json(f.read_json())
            label = json.dumps(json_label, ensure_ascii=False)
            data_lines.append(('\t'.join([impath, label])))

        return data_lines

    @run_once('str')
    def from_labelme_det(self, subdir='.', ratio=None, transcription_field='text'):
        """ Use every json file under ``subdir`` (recursively) as a labelme annotation.

        :param subdir: directory relative to data_dir
        :param ratio: optional split selector, see ``_select_ratio``
        :param transcription_field: mostly irrelevant for detection, except for special data
            where text marked "#" is treated as a hard sample and skipped.
            None: use sp['label'] as the text
            otherwise: parse sp['label'] as json and take this key

        json1: the labelme json annotation file
        json2: the json format expected by paddle's SimpleDataSet
        """
        from pyxlpr.data.labelme import LabelmeDict

        data_dir = XlPath(self.data_dir)
        subdir = data_dir / subdir
        data_lines = []

        def json1_to_json2(d):
            res = []
            shapes = d['shapes']
            for sp in shapes:
                label = sp['label']
                if transcription_field:
                    msg = json.loads(sp['label'])
                    label = msg[transcription_field]
                result = {"transcription": label,
                          "points": LabelmeDict.to_quad_pts(sp)}
                res.append(result)
            return res

        json1_files = list(subdir.rglob('*.json'))
        if ratio is not None:
            json1_files = self._select_ratio(json1_files, ratio)

        for json1_file in json1_files:
            data = json1_file.read_json()
            # a simple check that this is a valid labelme file
            if 'imagePath' not in data:
                continue
            # the image format is unknown, so take the file name from the labelme data
            img_file = json1_file.parent / data['imagePath']
            json2_data = json1_to_json2(data)
            json2_str = json.dumps(json2_data, ensure_ascii=False)
            data_lines.append(('\t'.join([img_file.relative_to(data_dir).as_posix(), json2_str])))
        return data_lines

    @run_once('str')
    def from_simple_rec(self, subdir='.', ratio=None):
        """ Placeholder for a recognition-data loader; not implemented yet.

        :raises NotImplementedError: always
        """
        # Raise explicitly: returning None would fail later in get_image_info_list with
        # an opaque "'NoneType' object is not iterable" TypeError.
        raise NotImplementedError('from_simple_rec is not implemented yet')
|
pyxlpr/ppocr/losses/__init__.py
DELETED
@@ -1,61 +0,0 @@
|
|
1
|
-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
import copy
|
16
|
-
import paddle
|
17
|
-
import paddle.nn as nn
|
18
|
-
|
19
|
-
# det loss
|
20
|
-
from .det_db_loss import DBLoss
|
21
|
-
from .det_east_loss import EASTLoss
|
22
|
-
from .det_sast_loss import SASTLoss
|
23
|
-
from .det_pse_loss import PSELoss
|
24
|
-
|
25
|
-
# rec loss
|
26
|
-
from .rec_ctc_loss import CTCLoss
|
27
|
-
from .rec_att_loss import AttentionLoss
|
28
|
-
from .rec_srn_loss import SRNLoss
|
29
|
-
from .rec_nrtr_loss import NRTRLoss
|
30
|
-
from .rec_sar_loss import SARLoss
|
31
|
-
from .rec_aster_loss import AsterLoss
|
32
|
-
|
33
|
-
# cls loss
|
34
|
-
from .cls_loss import ClsLoss
|
35
|
-
|
36
|
-
# e2e loss
|
37
|
-
from .e2e_pg_loss import PGLoss
|
38
|
-
from .kie_sdmgr_loss import SDMGRLoss
|
39
|
-
|
40
|
-
# basic loss function
|
41
|
-
from .basic_loss import DistanceLoss
|
42
|
-
|
43
|
-
# combined loss function
|
44
|
-
from .combined_loss import CombinedLoss
|
45
|
-
|
46
|
-
# table loss
|
47
|
-
from .table_att_loss import TableAttentionLoss
|
48
|
-
|
49
|
-
|
50
|
-
def build_loss(config):
    """ Instantiate a loss layer from its config.

    :param config: dict whose ``name`` key is one of the supported loss class names; all
        other keys are passed to that class's constructor. The dict is not modified
        (a deep copy is used).
    :return: the constructed loss instance
    :raises ValueError: if ``name`` is not a supported loss
    """
    support_dict = [
        'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
        'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
        'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss'
    ]
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    # Explicit check instead of `assert`: asserts are stripped under `python -O`, which
    # would let an arbitrary config-supplied string through.
    if module_name not in support_dict:
        raise ValueError('loss only support {}'.format(support_dict))
    # Look the class up among this module's names instead of eval()-ing a config string.
    module_class = globals()[module_name](**config)
    return module_class
|
pyxlpr/ppocr/losses/ace_loss.py
DELETED
@@ -1,52 +0,0 @@
|
|
1
|
-
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
# This code is refer from: https://github.com/viig99/LS-ACELoss
|
16
|
-
|
17
|
-
from __future__ import absolute_import
|
18
|
-
from __future__ import division
|
19
|
-
from __future__ import print_function
|
20
|
-
|
21
|
-
import paddle
|
22
|
-
import paddle.nn as nn
|
23
|
-
|
24
|
-
|
25
|
-
class ACELoss(nn.Layer):
    """ Aggregation Cross-Entropy loss (code refers to https://github.com/viig99/LS-ACELoss).

    Returns ``{"loss_ace": <per-sample loss>}``; the cross entropy uses soft labels with
    reduction='none' and ignore_index=0.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.loss_func = nn.CrossEntropyLoss(
            weight=None, ignore_index=0, reduction='none', soft_label=True, axis=-1)

    def __call__(self, predicts, batch):
        # when the model returns several outputs, only the last one is used
        if isinstance(predicts, (list, tuple)):
            predicts = predicts[-1]

        _, seq_len = predicts.shape[:2]
        div = paddle.to_tensor([seq_len]).astype('float32')

        # softmax over the last axis, summed over axis 1 and divided by its length
        probs = nn.functional.softmax(predicts, axis=-1)
        agg_probs = paddle.divide(paddle.sum(probs, axis=1), div)

        lengths = batch[2].astype("float32")
        target = batch[3].astype("float32")
        # column 0 becomes (axis-1 length - batch[2]); presumably the blank-class count
        # (TODO confirm against the label encoder)
        target[:, 0] = paddle.subtract(div, lengths)
        target = paddle.divide(target, div)

        return {"loss_ace": self.loss_func(agg_probs, target)}
|
@@ -1,135 +0,0 @@
|
|
1
|
-
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
#Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
#you may not use this file except in compliance with the License.
|
5
|
-
#You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
#Unless required by applicable law or agreed to in writing, software
|
10
|
-
#distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
#See the License for the specific language governing permissions and
|
13
|
-
#limitations under the License.
|
14
|
-
|
15
|
-
import paddle
|
16
|
-
import paddle.nn as nn
|
17
|
-
import paddle.nn.functional as F
|
18
|
-
|
19
|
-
from paddle.nn import L1Loss
|
20
|
-
from paddle.nn import MSELoss as L2Loss
|
21
|
-
from paddle.nn import SmoothL1Loss
|
22
|
-
|
23
|
-
|
24
|
-
class CELoss(nn.Layer):
    """Cross-entropy loss with optional label smoothing.

    Args:
        epsilon: label-smoothing factor. Values outside the open interval
            (0, 1) are treated as ``None``, which disables smoothing.
    """

    def __init__(self, epsilon=None):
        super().__init__()
        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
            epsilon = None
        self.epsilon = epsilon

    def _labelsmoothing(self, target, class_num):
        """Return label-smoothed soft targets reshaped to (-1, class_num).

        ``target`` is one-hot encoded first, unless its last dimension already
        equals ``class_num``, in which case it is used as-is.
        """
        if target.shape[-1] != class_num:
            one_hot_target = F.one_hot(target, class_num)
        else:
            one_hot_target = target
        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
        return soft_target

    def forward(self, x, label):
        """Compute the cross-entropy between logits ``x`` and ``label``.

        Args:
            x: logits whose last axis is the class axis.
            label: hard class indices, or a tensor whose last dimension equals
                the class count.

        Returns:
            With smoothing enabled: a per-row loss summed over the class axis,
            with no further reduction. Without smoothing: the result of
            ``F.cross_entropy`` with its default reduction.
            NOTE(review): the two branches therefore reduce differently —
            confirm callers expect this.
        """
        if self.epsilon is not None:
            class_num = x.shape[-1]
            # The smoothed label is reshaped to (-1, class_num). The product
            # below therefore assumes x is 2-D (N, class_num) — TODO confirm.
            label = self._labelsmoothing(label, class_num)
            x = -F.log_softmax(x, axis=-1)
            loss = paddle.sum(x * label, axis=-1)
        else:
            if label.shape[-1] == x.shape[-1]:
                # A label with a class-sized last axis is treated as logits:
                # it is softmax-normalised and used as a soft target.
                label = F.softmax(label, axis=-1)
                soft_label = True
            else:
                soft_label = False
            loss = F.cross_entropy(x, label=label, soft_label=soft_label)
        return loss
|
55
|
-
|
56
|
-
|
57
|
-
class KLJSLoss(object):
    """Element-wise KL divergence between two probability tensors.

    In ``'kl'`` mode the loss is ``p2 * log(p2 / p1)``. In ``'js'`` mode it is
    the average of that term and the reverse term ``p1 * log(p1 / p2)``. That
    is a symmetrised KL, not the textbook Jensen-Shannon divergence built on
    the mixture distribution. A 1e-5 smoothing constant guards both the
    division and the log.

    Args:
        mode: one of ``'kl'``, ``'js'``, ``'KL'``, ``'JS'``.
    """

    def __init__(self, mode='kl'):
        assert mode in ['kl', 'js', 'KL', 'JS'
                        ], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
        self.mode = mode

    def __call__(self, p1, p2, reduction="mean"):
        """Compute the loss.

        Args:
            p1, p2: probability tensors of the same shape. The reductions
                below use axes [1, 2], so at least 3-D inputs are assumed.
            reduction: ``"mean"`` averages over axes [1, 2]. ``"none"`` or
                ``None`` returns the element-wise loss. Any other value sums
                over axes [1, 2].
        """
        eps = 1e-5

        def directed_term(a, b):
            # a * log(a / b), smoothed to avoid log(0) and division by zero.
            return paddle.multiply(a, paddle.log((a + eps) / (b + eps) + eps))

        loss = directed_term(p2, p1)
        if self.mode.lower() == "js":
            loss = (loss + directed_term(p1, p2)) * 0.5

        if reduction is None or reduction == "none":
            return loss
        if reduction == "mean":
            return paddle.mean(loss, axis=[1, 2])
        return paddle.sum(loss, axis=[1, 2])
|
79
|
-
|
80
|
-
|
81
|
-
class DMLLoss(nn.Layer):
    """Symmetric divergence loss between two model outputs (DML, presumably
    deep mutual learning).

    Args:
        act: optional activation applied to both outputs first, either
            ``"softmax"`` (over the last axis) or ``"sigmoid"``.
        use_log: if True, average the two directed ``F.kl_div`` terms computed
            on log-transformed inputs (used for recognition distillation).
            Otherwise use the ``KLJSLoss`` in ``"js"`` mode (used for
            detection distillation).
    """

    def __init__(self, act=None, use_log=False):
        super().__init__()
        if act is not None:
            assert act in ["softmax", "sigmoid"]

        if act == "softmax":
            activation = nn.Softmax(axis=-1)
        elif act == "sigmoid":
            activation = nn.Sigmoid()
        else:
            activation = None
        self.act = activation

        self.use_log = use_log

        self.jskl_loss = KLJSLoss(mode="js")

    def forward(self, out1, out2):
        """Return the symmetric loss between ``out1`` and ``out2``."""
        if self.act is not None:
            out1, out2 = self.act(out1), self.act(out2)

        if not self.use_log:
            # detection distillation: compare the maps directly, no log needed
            return self.jskl_loss(out1, out2)

        # recognition distillation: each F.kl_div call gets the log of one
        # output as its first argument and the other output as its second
        log_out1, log_out2 = paddle.log(out1), paddle.log(out2)
        kl_1_vs_2 = F.kl_div(log_out1, out2, reduction='batchmean')
        kl_2_vs_1 = F.kl_div(log_out2, out1, reduction='batchmean')
        return (kl_1_vs_2 + kl_2_vs_1) / 2.0
|
116
|
-
|
117
|
-
|
118
|
-
class DistanceLoss(nn.Layer):
    """Distance-based loss between two tensors.

    Args:
        mode: ``"l1"`` (``nn.L1Loss``), ``"l2"`` (``nn.MSELoss``) or
            ``"smooth_l1"`` (``nn.SmoothL1Loss``).
        **kargs: forwarded unchanged to the selected loss constructor.
    """

    def __init__(self, mode="l2", **kargs):
        super().__init__()
        assert mode in ["l1", "l2", "smooth_l1"]
        candidates = (
            ("l1", nn.L1Loss),
            ("l2", nn.MSELoss),
            ("smooth_l1", nn.SmoothL1Loss),
        )
        for mode_name, loss_factory in candidates:
            if mode == mode_name:
                self.loss_func = loss_factory(**kargs)
                break

    def forward(self, x, y):
        """Apply the configured distance loss to ``x`` and ``y``."""
        return self.loss_func(x, y)
|