pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
@@ -1,73 +0,0 @@
|
|
1
|
-
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
import importlib
|
16
|
-
import copy
|
17
|
-
|
18
|
-
from .rec_metric import RecMetric
|
19
|
-
from .det_metric import DetMetric
|
20
|
-
from .e2e_metric import E2EMetric
|
21
|
-
from .cls_metric import ClsMetric
|
22
|
-
|
23
|
-
|
24
|
-
class DistillationMetric(object):
    """Apply one metric type to several named sub-model outputs.

    ``preds`` passed to ``__call__`` is a dict mapping a sub-model name to its
    predictions. On the first call, one instance of the metric class named
    ``base_metric_name`` (looked up by attribute name in this module) is
    created per key. ``get_metric`` reports the metrics of ``self.key`` under
    their plain names and prefixes every other key's metrics with ``"<key>_"``.
    """

    def __init__(self,
                 key=None,
                 base_metric_name=None,
                 main_indicator=None,
                 **kwargs):
        self.main_indicator = main_indicator
        self.key = key
        self.base_metric_name = base_metric_name
        # Extra keyword arguments forwarded to every sub-metric constructor.
        self.kwargs = kwargs
        # Sub-metrics are created lazily, once the prediction keys are known.
        self.metrics = None

    def _init_metrcis(self, preds):
        """Create and reset one base metric per key of ``preds``.

        (The misspelled name is kept for backward compatibility.)
        """
        self.metrics = dict()
        mod = importlib.import_module(__name__)
        for key in preds:
            self.metrics[key] = getattr(mod, self.base_metric_name)(
                main_indicator=self.main_indicator, **self.kwargs)
            self.metrics[key].reset()

    def __call__(self, preds, batch, **kwargs):
        """Feed each sub-model's predictions to its own metric instance."""
        assert isinstance(preds, dict)
        if self.metrics is None:
            self._init_metrcis(preds)
        for key in preds:
            self.metrics[key](preds[key], batch, **kwargs)

    def get_metric(self):
        """Collect the metrics of every sub-model into one flat dict.

        Returns an empty dict if nothing has been evaluated yet. Example of
        the merged result::

            {'acc': 0, 'norm_edit_dis': 0, 'teacher_acc': 0, ...}
        """
        output = dict()
        if self.metrics is None:
            return output
        for key, sub_metric in self.metrics.items():
            metric = sub_metric.get_metric()
            if key == self.key:
                # the main sub-model is reported without a prefix
                output.update(metric)
            else:
                for sub_key in metric:
                    output["{}_{}".format(key, sub_key)] = metric[sub_key]
        return output

    def reset(self):
        """Reset every sub-metric; a no-op before the first ``__call__``."""
        if self.metrics is None:
            return
        for sub_metric in self.metrics.values():
            sub_metric.reset()
|
@@ -1,86 +0,0 @@
|
|
1
|
-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
from __future__ import absolute_import
|
16
|
-
from __future__ import division
|
17
|
-
from __future__ import print_function
|
18
|
-
|
19
|
-
__all__ = ['E2EMetric']
|
20
|
-
|
21
|
-
from pyxlpr.ppocr.utils.e2e_metric.Deteval import get_socre_A, get_socre_B, combine_results
|
22
|
-
from pyxlpr.ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
|
23
|
-
|
24
|
-
|
25
|
-
class E2EMetric(object):
    """End-to-end (detection + recognition) evaluation metric.

    Two modes:

    * ``mode == 'A'``: ground truth (polygons, index-encoded transcriptions,
      ignore tags) is taken from the batch and scored with ``get_socre_A``.
    * any other mode: ground truth is looked up in ``gt_mat_dir`` by image id
      and scored with ``get_socre_B``.

    Per-call results are accumulated in ``self.results`` and merged by
    ``combine_results`` when ``get_metric`` is called.
    """

    def __init__(self,
                 mode,
                 gt_mat_dir,
                 character_dict_path,
                 main_indicator='f_score_e2e',
                 **kwargs):
        self.mode = mode
        self.gt_mat_dir = gt_mat_dir
        # Character table used to decode ground-truth index sequences.
        self.label_list = get_dict(character_dict_path)
        self.max_index = len(self.label_list)
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        if self.mode != 'A':
            image_id = batch[5][0]
            detections = [{'points': poly, 'texts': text}
                          for poly, text in zip(preds['points'], preds['texts'])]
            self.results.append(get_socre_B(self.gt_mat_dir, image_id, detections))
            return

        polygons_batch = batch[2]
        encoded_texts = batch[3][0]
        ignore_batch = batch[4]
        # Decode each index sequence; indices outside the character table
        # (>= max_index) are skipped -- presumably padding.
        decoded_texts = [
            ''.join(self.label_list[idx] for idx in seq if idx < self.max_index)
            for seq in encoded_texts
        ]

        # preds and decoded_texts are wrapped in one-element lists, so this
        # loop scores at most one image per call.
        for pred, polygons, texts, ignores in zip(
                [preds], polygons_batch, [decoded_texts], ignore_batch):
            gt_info_list = [{'points': poly, 'text': text, 'ignore': flag}
                            for poly, text, flag in zip(polygons, texts, ignores)]
            e2e_info_list = [{'points': poly, 'texts': text}
                             for poly, text in zip(pred['points'], pred['texts'])]
            self.results.append(get_socre_A(gt_info_list, e2e_info_list))

    def get_metric(self):
        """Merge the accumulated results, clear them, and return the merged metrics."""
        merged = combine_results(self.results)
        self.reset()
        return merged

    def reset(self):
        """Drop all accumulated per-call results."""
        self.results = []
|
@@ -1,274 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
from collections import namedtuple
|
4
|
-
import numpy as np
|
5
|
-
from shapely.geometry import Polygon
|
6
|
-
|
7
|
-
"""
|
8
|
-
reference from :
|
9
|
-
https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
|
10
|
-
"""
|
11
|
-
|
12
|
-
|
13
|
-
class DetectionIoUEvaluator(object):
    """IoU-based evaluator for text detection.

    Text detection differs somewhat from generic object detection, and this
    IoU criterion is not the most precise one: a GT polygon and a detected
    polygon count as a match as soon as their IoU exceeds ``iou_constraint``
    (0.5 by default), with no further penalty for imperfect overlap.
    Matching is greedy, in GT order and then detection order.
    Polygons that shapely reports as invalid or non-simple are skipped.
    """

    def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
        # Minimum IoU (strictly greater) for a GT/detection pair to match.
        self.iou_constraint = iou_constraint
        # A detection whose area lies more than this fraction inside a
        # "don't care" GT polygon is itself marked "don't care".
        self.area_precision_constraint = area_precision_constraint

    def evaluate_image(self, gt, pred):
        """Evaluate all detection polygons of a single image.

        :param gt: list of dicts with keys ``'points'`` (polygon vertices) and
            ``'ignore'`` (truthy marks a "don't care" region). A ``'text'``
            key may be present but is not used.
        :param pred: list of dicts with a ``'points'`` key.
        :return: dict with keys ``precision``, ``recall``, ``hmean``,
            ``pairs``, ``iouMat``, ``gtPolPoints``, ``detPolPoints``,
            ``gtCare``, ``detCare``, ``gtDontCare``, ``detDontCare``,
            ``detMatched`` and ``evaluationLog``.
        """
        evaluationLog = ""

        # 1. Collect valid GT polygons and remember which are "don't care".
        gtPols, gtShapes, gtDontCarePolsNum = [], [], []
        for item in gt:
            points = item['points']
            dontCare = item['ignore']
            shape = Polygon(points)
            if not shape.is_valid or not shape.is_simple:
                continue
            gtPols.append(points)
            gtShapes.append(shape)
            if dontCare:
                gtDontCarePolsNum.append(len(gtPols) - 1)

        evaluationLog += "GT polygons: " + str(len(gtPols)) + (
            " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
            if len(gtDontCarePolsNum) > 0 else "\n")

        # 2. Collect valid detections; a detection lying mostly inside a
        #    "don't care" GT polygon is marked "don't care" too.
        detPols, detShapes, detDontCarePolsNum = [], [], []
        for item in pred:
            points = item['points']
            shape = Polygon(points)
            if not shape.is_valid or not shape.is_simple:
                continue
            detPols.append(points)
            detShapes.append(shape)
            detArea = shape.area
            for gtIdx in gtDontCarePolsNum:
                intersected_area = gtShapes[gtIdx].intersection(shape).area
                ratio = 0 if detArea == 0 else intersected_area / detArea
                if ratio > self.area_precision_constraint:
                    detDontCarePolsNum.append(len(detPols) - 1)
                    break

        evaluationLog += "DET polygons: " + str(len(detPols)) + (
            " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
            if len(detDontCarePolsNum) > 0 else "\n")

        numGt, numDet = len(gtPols), len(detPols)
        gtDontCareSet = set(gtDontCarePolsNum)
        detDontCareSet = set(detDontCarePolsNum)

        pairs = []
        detMatched = 0
        # Deterministic placeholder when no IoU matrix can be computed
        # (np.empty here would leak uninitialized memory into the result).
        iouMat = np.zeros([1, 1])

        if numGt > 0 and numDet > 0:
            # 3. IoU matrix of shape [numGt, numDet].
            iouMat = np.empty([numGt, numDet])
            for gtNum, gShape in enumerate(gtShapes):
                for detNum, dShape in enumerate(detShapes):
                    union = dShape.union(gShape).area
                    iouMat[gtNum, detNum] = (
                        0 if union == 0 else dShape.intersection(gShape).area / union)

            # 4. Greedy one-to-one matching of "care" polygons with IoU above the threshold.
            gtTaken = [False] * numGt
            detTaken = [False] * numDet
            for gtNum in range(numGt):
                if gtNum in gtDontCareSet:
                    continue
                for detNum in range(numDet):
                    if (detTaken[detNum] or detNum in detDontCareSet
                            or not iouMat[gtNum, detNum] > self.iou_constraint):
                        continue
                    gtTaken[gtNum] = True
                    detTaken[detNum] = True
                    detMatched += 1
                    pairs.append({'gt': gtNum, 'det': detNum})
                    evaluationLog += "Match GT #" + \
                        str(gtNum) + " with Det #" + str(detNum) + "\n"
                    break  # this GT is matched; move on to the next one

        # 5. recall = matches / care GTs, precision = matches / care detections.
        numGtCare = numGt - len(gtDontCarePolsNum)
        numDetCare = numDet - len(detDontCarePolsNum)
        if numGtCare == 0:
            recall = float(1)
            precision = float(0) if numDetCare > 0 else float(1)
        else:
            recall = float(detMatched) / numGtCare
            precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare

        # 6. hmean = harmonic mean of precision and recall.
        hmean = 0 if (precision + recall) == 0 else 2.0 * \
            precision * recall / (precision + recall)

        return {
            'precision': precision,
            'recall': recall,
            'hmean': hmean,
            'pairs': pairs,
            # the matrix is omitted for images with many detections
            'iouMat': [] if numDet > 100 else iouMat.tolist(),
            'gtPolPoints': list(gtPols),
            'detPolPoints': list(detPols),
            'gtCare': numGtCare,
            'detCare': numDetCare,
            'gtDontCare': gtDontCarePolsNum,
            'detDontCare': detDontCarePolsNum,
            'detMatched': detMatched,
            'evaluationLog': evaluationLog
        }

    def combine_results(self, results):
        """Aggregate per-image results into dataset-level precision/recall/hmean.

        Totals of matched, care-GT and care-detection counts are summed over
        all images before dividing (micro-averaging).
        """
        numGlobalCareGt = sum(result['gtCare'] for result in results)
        numGlobalCareDet = sum(result['detCare'] for result in results)
        matchedSum = sum(result['detMatched'] for result in results)

        methodRecall = 0 if numGlobalCareGt == 0 else float(
            matchedSum) / numGlobalCareGt
        methodPrecision = 0 if numGlobalCareDet == 0 else float(
            matchedSum) / numGlobalCareDet
        methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
            methodRecall * methodPrecision / (
                methodRecall + methodPrecision)

        return {
            'precision': methodPrecision,
            'recall': methodRecall,
            'hmean': methodHmean
        }

    def evaluate_image_simple(self, gt, pred):
        """Evaluate one image from plain lists of polygons.

        The original interface expects dicts; this wrapper turns every polygon
        into ``{'points': polygon, 'text': '', 'ignore': False}`` first.
        """
        gt = [{'points': x, 'text': '', 'ignore': False} for x in gt]
        pred = [{'points': x, 'text': '', 'ignore': False} for x in pred]
        return self.evaluate_image(gt, pred)

    @classmethod
    def eval(cls, gts, preds, **kwargs):
        """IoU detection scores over N images.

        :param gts: list of length N; each element is the list of GT polygons
            (vertex lists, arbitrary polygons) of one image.
        :param preds: list of length N in the same format for detections.
        :param kwargs: optional constructor arguments (``iou_constraint``,
            ``area_precision_constraint``).
        :return: dict with ``precision``, ``recall`` and ``hmean``.
        """
        evaluator = cls(**kwargs)
        results = [evaluator.evaluate_image_simple(gt, pred)
                   for gt, pred in zip(gts, preds)]
        return evaluator.combine_results(results)
|
252
|
-
|
253
|
-
|
254
|
-
if __name__ == '__main__':
    # Small self-check: one image with two GT polygons and a single detection
    # that roughly covers the first of them.
    evaluator = DetectionIoUEvaluator()
    first_gt = {'points': [(0, 0), (1, 0), (1, 1), (0.5, 1), (0, 1)], 'text': 1234, 'ignore': False}
    second_gt = {'points': [(2, 2), (3, 2), (3, 3), (2, 3)], 'text': 5678, 'ignore': False}
    only_pred = {'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)], 'text': 123, 'ignore': False}
    gts = [[first_gt, second_gt]]
    preds = [[only_pred]]
    per_image = [evaluator.evaluate_image(g, p) for g, p in zip(gts, preds)]
    print(evaluator.combine_results(per_image))
|
@@ -1,70 +0,0 @@
|
|
1
|
-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
from __future__ import absolute_import
|
16
|
-
from __future__ import division
|
17
|
-
from __future__ import print_function
|
18
|
-
|
19
|
-
import numpy as np
|
20
|
-
import paddle
|
21
|
-
|
22
|
-
__all__ = ['KIEMetric']
|
23
|
-
|
24
|
-
|
25
|
-
class KIEMetric(object):
    """Node-classification metric for key information extraction.

    Node predictions and ground-truth labels are accumulated across calls.
    ``get_metric`` returns ``{'hmean': mean F1}``, where the mean runs over
    every class id in ``range(C)`` except those listed in ``ignores``.
    """

    # Class ids excluded from the averaged F1 (the original hard-coded list).
    DEFAULT_IGNORES = (0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25)

    def __init__(self, main_indicator='hmean', *, ignores=None, **kwargs):
        self.main_indicator = main_indicator
        # Optional override of the class ids left out of the F1 average.
        self.ignores = tuple(self.DEFAULT_IGNORES if ignores is None else ignores)
        self.reset()  # initialises self.results, self.node and self.gt

    def __call__(self, preds, batch, **kwargs):
        """Accumulate one batch of node predictions and labels."""
        nodes, _ = preds  # the second element of preds is not used here
        gts, tag = batch[4].squeeze(0), batch[5].tolist()[0]
        # Keep the first tag[0] rows and only column 0 as the labels.
        # NOTE(review): assumes batch[4] holds per-node labels in column 0 and
        # tag[0] is the number of valid nodes -- confirm against the loader.
        gts = gts[:tag[0], :1].reshape([-1])
        self.node.append(nodes.numpy())
        self.gt.append(gts)

    def compute_f1_score(self, preds, gts, ignores=None):
        """Per-class F1 of ``preds.argmax(1)`` against ``gts``.

        :param preds: 2-D score array of shape (N, C).
        :param gts: length-N array of integer labels.
        :param ignores: class ids to drop from the result; defaults to
            ``self.ignores``.
        :return: 1-D array with the F1 of each non-ignored class in range(C).
        """
        if ignores is None:
            ignores = self.ignores
        C = preds.shape[1]
        # Explicit int dtype keeps this a valid index array even when every
        # class is ignored (an empty default array would be float64).
        classes = np.array(sorted(set(range(C)) - set(ignores)), dtype=np.int64)
        # Confusion matrix: rows = ground truth, columns = prediction.
        hist = np.bincount(
            (gts * C).astype('int64') + preds.argmax(1), minlength=C ** 2
        ).reshape([C, C]).astype('float32')
        diag = np.diag(hist)
        recalls = diag / hist.sum(1).clip(min=1)
        precisions = diag / hist.sum(0).clip(min=1)
        f1 = 2 * recalls * precisions / (recalls + precisions).clip(min=1e-8)
        return f1[classes]

    def combine_results(self, results):
        """Compute ``{'hmean': ...}`` from everything accumulated so far.

        ``results`` is accepted for interface compatibility but unused; the
        data comes from ``self.node`` / ``self.gt``. The hmean is 0.0 when
        nothing was accumulated or no class is left to score.
        """
        if not self.node:
            return {'hmean': 0.0}
        node = np.concatenate(self.node, 0)
        gts = np.concatenate(self.gt, 0)
        f1 = self.compute_f1_score(node, gts)
        return {'hmean': f1.mean() if f1.size else 0.0}

    def get_metric(self):
        """Return the combined metric and clear the accumulated state."""
        metrics = self.combine_results(self.results)
        self.reset()
        return metrics

    def reset(self):
        """Clear all accumulated predictions and labels."""
        self.results = []  # kept for interface compatibility; never filled
        self.node = []
        self.gt = []
|
@@ -1,75 +0,0 @@
|
|
1
|
-
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
import Levenshtein
|
16
|
-
import string
|
17
|
-
|
18
|
-
|
19
|
-
class RecMetric(object):
    """Text-recognition metric: exact-match accuracy and normalized edit distance.

    Spaces are removed from both prediction and target before comparison.
    With ``is_filter=True`` both strings are further reduced to lowercase
    ASCII letters and digits.
    """

    # Characters kept by ``_normalize_text``.
    _ALLOWED_CHARS = frozenset(string.digits + string.ascii_letters)

    def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.reset()

    def _normalize_text(self, text):
        """Keep only ASCII letters/digits of ``text`` and lowercase the result."""
        text = ''.join(ch for ch in text if ch in self._ALLOWED_CHARS)
        return text.lower()

    def __call__(self, pred_label, *args, **kwargs):
        """Score one batch and add it to the running totals.

        Args:
            pred_label: pair ``(preds, labels)``; each is a sequence of
                ``(text, confidence)`` pairs matched positionally. Confidences
                are ignored.

        Returns:
            dict with this batch's ``'acc'`` (exact-match ratio, ``0.0`` for
            an empty batch) and ``'norm_edit_dis'`` (1 minus the mean
            Levenshtein distance normalized by the longer string's length).
        """
        preds, labels = pred_label
        correct_num = 0
        all_num = 0
        norm_edit_dis = 0.0
        for (pred, _), (target, _) in zip(preds, labels):
            pred = pred.replace(" ", "")
            target = target.replace(" ", "")
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)
            norm_edit_dis += Levenshtein.distance(pred, target) / max(
                len(pred), len(target), 1)
            if pred == target:
                correct_num += 1
            all_num += 1
        self.correct_num += correct_num
        self.all_num += all_num
        self.norm_edit_dis += norm_edit_dis
        return {
            # An empty batch used to raise ZeroDivisionError here.
            'acc': correct_num / all_num if all_num else 0.0,
            'norm_edit_dis': 1 - norm_edit_dis / (all_num + 1e-3)
        }

    def get_metric(self):
        """
        return metrics {
                 'acc': 0,
                 'norm_edit_dis': 0,
            }
        Totals are accumulated since the last reset; the accumulator is
        reset afterwards.
        """
        acc = 1.0 * self.correct_num / (self.all_num + 1e-3)
        norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + 1e-3)
        self.reset()
        return {'acc': acc, 'norm_edit_dis': norm_edit_dis}

    def reset(self):
        """Zero all running totals."""
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0

    @classmethod
    def eval(cls, preds, labels):
        """One-shot scoring of two plain lists of strings (default options)."""
        preds = [(x, 1) for x in preds]
        labels = [(x, 1) for x in labels]
        return cls()([preds, labels])
|
@@ -1,50 +0,0 @@
|
|
1
|
-
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
import numpy as np
|
15
|
-
class TableMetric(object):
    """Table-structure metric: fraction of samples whose predicted structure
    token sequence matches the label sequence exactly."""

    def __init__(self, main_indicator='acc', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, pred, batch, *args, **kwargs):
        """Score one batch and add it to the running totals.

        Args:
            pred: mapping whose ``'structure_probs'`` exposes ``.numpy()``;
                the argmax is taken over axis 2, so it is assumed to be
                ``[batch, seq, vocab]`` -- confirm with the model head.
            batch: sequence whose element 1 holds structure label ids per
                sample. Its first column is dropped before comparison
                (presumably a start token -- confirm with the data loader).

        Returns:
            ``{'acc': ...}`` for this batch; ``0.0`` for an empty batch.
        """
        structure_probs = np.argmax(pred['structure_probs'].numpy(), axis=2)
        structure_labels = batch[1][:, 1:]
        batch_size = structure_probs.shape[0]
        correct_num = 0
        for bno in range(batch_size):
            if (structure_probs[bno] == structure_labels[bno]).all():
                correct_num += 1
        self.correct_num += correct_num
        self.all_num += batch_size
        # Guard the empty batch, which used to raise ZeroDivisionError.
        acc = correct_num * 1.0 / batch_size if batch_size else 0.0
        return {'acc': acc}

    def get_metric(self):
        """
        return metrics {
                 'acc': 0,
            }
        Accuracy over everything accumulated since the last reset (``0.0``
        if nothing was accumulated); the accumulator is reset afterwards.
        """
        acc = 1.0 * self.correct_num / self.all_num if self.all_num else 0.0
        self.reset()
        return {'acc': acc}

    def reset(self):
        """Zero the running totals."""
        self.correct_num = 0
        self.all_num = 0
|
@@ -1,32 +0,0 @@
|
|
1
|
-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
import copy
|
16
|
-
import importlib
|
17
|
-
|
18
|
-
from .base_model import BaseModel
|
19
|
-
from .distillation_model import DistillationModel
|
20
|
-
|
21
|
-
__all__ = ['build_model']
|
22
|
-
|
23
|
-
|
24
|
-
def build_model(config):
    """Instantiate a model architecture from a configuration mapping.

    ``config`` is deep-copied first, so the caller's mapping is never
    modified. If it has a ``"name"`` key, that key is removed from the copy
    and used to look up a constructor attribute on this module, which is
    called with the remaining configuration. Otherwise the whole copy is
    passed to ``BaseModel``.
    """
    cfg = copy.deepcopy(config)
    if "name" in cfg:
        arch_name = cfg.pop("name")
        this_module = importlib.import_module(__name__)
        arch_cls = getattr(this_module, arch_name)
        return arch_cls(cfg)
    return BaseModel(cfg)
|