pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
@@ -1,789 +0,0 @@
|
|
1
|
-
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
from __future__ import absolute_import
|
16
|
-
from __future__ import division
|
17
|
-
from __future__ import print_function
|
18
|
-
from __future__ import unicode_literals
|
19
|
-
|
20
|
-
import numpy as np
|
21
|
-
import string
|
22
|
-
from shapely.geometry import LineString, Point, Polygon
|
23
|
-
import json
|
24
|
-
|
25
|
-
from pyxlpr.ppocr.utils.logging import get_logger
|
26
|
-
|
27
|
-
|
28
|
-
class ClsLabelEncode(object):
    """Map a classification label string to its index in ``label_list``."""

    def __init__(self, label_list, **kwargs):
        # Any additional keyword arguments are accepted and ignored.
        self.label_list = label_list

    def __call__(self, data):
        """Replace ``data['label']`` by its position in ``label_list``.

        Returns the same (mutated) ``data`` dict, or None when the label is
        not one of the known labels.
        """
        raw = data['label']
        if raw in self.label_list:
            data['label'] = self.label_list.index(raw)
            return data
        return None
|
39
|
-
|
40
|
-
|
41
|
-
class DetLabelEncode(object):
    """Decode a JSON text-detection label into polygon, text and ignore arrays."""

    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        """Parse ``data['label']`` (a JSON list of annotations) into arrays.

        Each annotation carries ``points`` (polygon vertices) and a
        ``transcription``. A transcription of ``'*'`` or ``'###'`` marks the
        region as invalid (ignored).

        Sets on ``data``:
            polys: float32 array of polygons, padded to a common vertex count
                by ``expand_points_num``.
            texts: list of transcription strings.
            ignore_tags: bool array, True for invalid annotations.

        Returns:
            ``data``, or None when the label contains no annotations.
        """
        # 1. Parse the JSON label
        label = json.loads(data['label'])
        boxes, txts, txt_tags = [], [], []
        for ann in label:
            box = ann['points']
            txt = ann['transcription']
            boxes.append(box)
            txts.append(txt)
            # 1.1 '*' or '###' means this annotation is invalid
            txt_tags.append(txt in ['*', '###'])
        if len(boxes) == 0:
            return None
        boxes = self.expand_points_num(boxes)
        boxes = np.array(boxes, dtype=np.float32)
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        txt_tags = np.array(txt_tags, dtype=bool)

        # 2. Store texts, boxes and ignore flags
        data['polys'] = boxes
        data['texts'] = txts
        data['ignore_tags'] = txt_tags
        return data

    def order_points_clockwise(self, pts):
        """Reorder four (x, y) points by the x+y / y-x extreme heuristic.

        Returns a float32 (4, 2) array ``[argmin(x+y), argmin(y-x),
        argmax(x+y), argmax(y-x)]``. With y pointing down (image coordinates)
        this is top-left, top-right, bottom-right, bottom-left.
        Assumes ``pts`` is an (N, 2) ndarray of (x, y) — TODO confirm callers.
        """
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]
        diff = np.diff(pts, axis=1)
        rect[1] = pts[np.argmin(diff)]
        rect[3] = pts[np.argmax(diff)]
        return rect

    def expand_points_num(self, boxes):
        """Pad every polygon to the largest vertex count among ``boxes``.

        Shorter polygons are extended by repeating their last vertex, so the
        result can be stacked into one rectangular array. Each box must be a
        list (list concatenation is used).
        """
        max_points_num = max((len(box) for box in boxes), default=0)
        return [box + [box[-1]] * (max_points_num - len(box)) for box in boxes]
|
95
|
-
|
96
|
-
|
97
|
-
class BaseRecLabelEncode(object):
    """Base encoder mapping recognition text labels to character indices.

    Subclasses customise the alphabet via ``add_special_char`` and build the
    final padded label in ``__call__``.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False):
        self.max_text_len = max_text_length
        self.beg_str = "sos"
        self.end_str = "eos"
        self.lower = False

        if character_dict_path is None:
            get_logger().warning(
                "The character_dict_path is None, model can only recognize number and lower letters"
            )
            # Default alphabet: digits + lowercase letters; input is lower-cased.
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            self.lower = True
        else:
            # One character (entry) per line; line terminators are stripped.
            with open(character_dict_path, "rb") as fin:
                entries = [raw.decode('utf-8').strip("\n").strip("\r\n")
                           for raw in fin.readlines()]
            self.character_str = "".join(entries)
            if use_space_char:
                self.character_str += " "

        alphabet = self.add_special_char(list(self.character_str))
        # Later duplicates overwrite earlier ones, mirroring sequential insertion.
        self.dict = {ch: pos for pos, ch in enumerate(alphabet)}
        self.character = alphabet

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special tokens; the base adds none."""
        return dict_character

    def encode(self, text):
        """Convert a text label into a list of character indices.

        Characters missing from the dictionary are silently skipped.

        Returns:
            The index list, or None when ``text`` is empty, longer than
            ``max_text_len``, or contains no known characters.
        """
        if not 0 < len(text) <= self.max_text_len:
            return None
        if self.lower:
            text = text.lower()
        indices = [self.dict[ch] for ch in text if ch in self.dict]
        return indices or None
|
161
|
-
|
162
|
-
|
163
|
-
class NRTRLabelEncode(BaseRecLabelEncode):
    """Encode labels for NRTR as ``<s> + ids + </s>`` zero-padded to max length."""

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(NRTRLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def __call__(self, data):
        """Set ``data['length']`` and the padded ``data['label']``; None if unusable."""
        ids = self.encode(data['label'])
        # Reject unencodable text and text that leaves no room for <s> and </s>.
        if ids is None or len(ids) >= self.max_text_len - 1:
            return None
        data['length'] = np.array(len(ids))
        # Indices 2 and 3 are '<s>' and '</s>' (see add_special_char);
        # 0 ('blank') is used as padding.
        seq = [2] + ids + [3]
        seq += [0] * (self.max_text_len - len(seq))
        data['label'] = np.array(seq)
        return data

    def add_special_char(self, dict_character):
        """Prepend the blank/unknown/start/end tokens at indices 0..3."""
        return ['blank', '<unk>', '<s>', '</s>'] + dict_character
|
192
|
-
|
193
|
-
|
194
|
-
class CTCLabelEncode(BaseRecLabelEncode):
    """Encode labels for CTC: index list zero-padded to ``max_text_len``."""

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(CTCLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def __call__(self, data):
        """Set ``length``, padded ``label`` and histogram ``label_ace`` on ``data``.

        Returns ``data``, or None when the text cannot be encoded.
        """
        ids = self.encode(data['label'])
        if ids is None:
            return None
        data['length'] = np.array(len(ids))
        padded = ids + [0] * (self.max_text_len - len(ids))
        data['label'] = np.array(padded)

        # Per-class occurrence counts over the padded sequence
        # (padding entries are counted under index 0, the blank).
        counts = [0] * len(self.character)
        for idx in padded:
            counts[idx] += 1
        data['label_ace'] = np.array(counts)
        return data

    def add_special_char(self, dict_character):
        """Reserve index 0 for the CTC blank token."""
        return ['blank'] + dict_character
|
223
|
-
|
224
|
-
|
225
|
-
class E2ELabelEncodeTest(BaseRecLabelEncode):
    """End-to-end (eval) label encoder: polygons, ignore tags and padded text ids."""

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(E2ELabelEncodeTest, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def __call__(self, data):
        """Parse the JSON label in ``data['label']`` and encode it.

        Sets ``polys`` (float32), ``ignore_tags`` (bool; True for '*' / '###')
        and ``texts`` (lower-cased index lists padded with ``len(self.dict)``).
        Returns ``data``, or None if any transcription cannot be encoded.
        """
        # Pad value is one past the last dictionary index.
        padnum = len(self.dict)
        label = json.loads(data['label'])
        boxes, txts, txt_tags = [], [], []
        for ann in label:
            box = ann['points']
            txt = ann['transcription']
            boxes.append(box)
            txts.append(txt)
            txt_tags.append(txt in ['*', '###'])
        boxes = np.array(boxes, dtype=np.float32)
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        txt_tags = np.array(txt_tags, dtype=bool)
        data['polys'] = boxes
        data['ignore_tags'] = txt_tags
        temp_texts = []
        for text in txts:
            text = self.encode(text.lower())
            if text is None:
                return None
            temp_texts.append(text + [padnum] * (self.max_text_len - len(text)))
        data['texts'] = np.array(temp_texts)
        return data
|
265
|
-
|
266
|
-
|
267
|
-
class E2ELabelEncodeTrain(object):
    """End-to-end (train) label decoder: polygons, raw texts and ignore tags."""

    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        """Parse the JSON label in ``data['label']``.

        Sets ``polys`` (float32 array), ``texts`` (list of str) and
        ``ignore_tags`` (bool array; True for '*' / '###'). Returns ``data``.
        """
        label = json.loads(data['label'])
        boxes, txts, txt_tags = [], [], []
        for ann in label:
            box = ann['points']
            txt = ann['transcription']
            boxes.append(box)
            txts.append(txt)
            txt_tags.append(txt in ['*', '###'])
        boxes = np.array(boxes, dtype=np.float32)
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        txt_tags = np.array(txt_tags, dtype=bool)

        data['polys'] = boxes
        data['texts'] = txts
        data['ignore_tags'] = txt_tags
        return data
|
293
|
-
|
294
|
-
|
295
|
-
class KieLabelEncode(object):
    """Encode key-information-extraction (KIE) annotations into padded arrays.

    The JSON label is a list of annotations, each with 4-vertex ``points``,
    a ``transcription``, a ``label`` and an optional ``edge`` id. Outputs are
    zero-padded to a fixed 300 boxes (see ``list_to_numpy``).
    """

    def __init__(self, character_dict_path, norm=10, directed=False, **kwargs):
        super(KieLabelEncode, self).__init__()
        # Index 0 is the empty string; dictionary entries are numbered from 1.
        # NOTE(review): line.strip() turns blank/space-only lines into '',
        # which overwrites the index-0 entry — confirm dict files have none.
        self.dict = {'': 0}
        with open(character_dict_path, 'r', encoding='utf-8') as fr:
            for idx, line in enumerate(fr, start=1):
                char = line.strip()
                self.dict[char] = idx
        self.norm = norm
        self.directed = directed

    def compute_relation(self, boxes):
        """Compute pairwise relation features between every two boxes.

        ``boxes`` is an (N, 8) array of 4 (x, y) vertices; vertex 0 and
        vertex 2 (columns 4:6) are used as the box corners.

        Returns:
            relations: (N, N, 5) array of [dx, dy, w/h, h_j/h_i, w_j/h_i].
            bboxes: (N, 4) float32 array of [x1, y1, x2, y2].
        """
        x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
        x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
        ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
        dxs = (x1s[:, 0][None] - x1s) / self.norm
        dys = (y1s[:, 0][None] - y1s) / self.norm
        xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
        whs = ws / hs + np.zeros_like(xhhs)
        relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
        bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
        return relations, bboxes

    def pad_text_indices(self, text_inds):
        """Pad text index lists to 300 columns with -1.

        Returns the (N, 300) int32 array and the longest original length.
        """
        max_len = 300
        recoder_len = max([len(text_ind) for text_ind in text_inds])
        padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
        for idx, text_ind in enumerate(text_inds):
            padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
        return padded_text_inds, recoder_len

    def list_to_numpy(self, ann_infos):
        """Convert bboxes, relations, texts and labels to fixed-size ndarrays.

        Returns a dict with ``image``, ``points`` (300, 4), ``relations``
        (300, 300, 5), ``texts`` (300, 300), ``labels`` (300, 300) and
        ``tag`` = [number of boxes, longest text length].
        """
        boxes, text_inds = ann_infos['points'], ann_infos['text_inds']
        boxes = np.array(boxes, np.int32)
        relations, bboxes = self.compute_relation(boxes)

        labels = ann_infos.get('labels', None)
        if labels is not None:
            labels = np.array(labels, np.int32)
            edges = ann_infos.get('edges', None)
            if edges is not None:
                labels = labels[:, None]
                edges = np.array(edges)
                # edges[i, j] == 1 when boxes i and j share the same edge id
                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
                if self.directed:
                    edges = (edges & labels == 1).astype(np.int32)
                np.fill_diagonal(edges, -1)
                labels = np.concatenate([labels, edges], -1)
        padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
        max_num = 300
        temp_bboxes = np.zeros([max_num, 4])
        h, _ = bboxes.shape
        # Fill all 4 coordinate columns; slicing columns by ``h`` broke
        # samples with fewer than 4 boxes (shape mismatch).
        temp_bboxes[:h, :] = bboxes

        temp_relations = np.zeros([max_num, max_num, 5])
        temp_relations[:h, :h, :] = relations

        temp_padded_text_inds = np.zeros([max_num, max_num])
        temp_padded_text_inds[:h, :] = padded_text_inds

        temp_labels = np.zeros([max_num, max_num])
        temp_labels[:h, :h + 1] = labels

        tag = np.array([h, recoder_len])
        return dict(
            image=ann_infos['image'],
            points=temp_bboxes,
            relations=temp_relations,
            texts=temp_padded_text_inds,
            labels=temp_labels,
            tag=tag)

    def convert_canonical(self, points_x, points_y):
        """Rotate the 4-point order so it starts at the point nearest the
        bounding-box minimum corner (min_x, min_y)."""

        assert len(points_x) == 4
        assert len(points_y) == 4

        points = [Point(points_x[i], points_y[i]) for i in range(4)]

        polygon = Polygon([(p.x, p.y) for p in points])
        min_x, min_y, _, _ = polygon.bounds
        points_to_lefttop = [
            LineString([points[i], Point(min_x, min_y)]) for i in range(4)
        ]
        distances = np.array([line.length for line in points_to_lefttop])
        sort_dist_idx = np.argsort(distances)
        lefttop_idx = sort_dist_idx[0]

        if lefttop_idx == 0:
            point_orders = [0, 1, 2, 3]
        elif lefttop_idx == 1:
            point_orders = [1, 2, 3, 0]
        elif lefttop_idx == 2:
            point_orders = [2, 3, 0, 1]
        else:
            point_orders = [3, 0, 1, 2]

        sorted_points_x = [points_x[i] for i in point_orders]
        sorted_points_y = [points_y[j] for j in point_orders]

        return sorted_points_x, sorted_points_y

    def sort_vertex(self, points_x, points_y):
        """Sort 4 vertices by angle around their centroid, then canonicalise
        the starting vertex with ``convert_canonical``."""

        assert len(points_x) == 4
        assert len(points_y) == 4

        x = np.array(points_x)
        y = np.array(points_y)
        center_x = np.sum(x) * 0.25
        center_y = np.sum(y) * 0.25

        x_arr = np.array(x - center_x)
        y_arr = np.array(y - center_y)

        angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
        sort_idx = np.argsort(angle)

        sorted_points_x, sorted_points_y = [], []
        for i in range(4):
            sorted_points_x.append(points_x[sort_idx[i]])
            sorted_points_y.append(points_y[sort_idx[i]])

        return self.convert_canonical(sorted_points_x, sorted_points_y)

    def __call__(self, data):
        """Parse the JSON KIE label in ``data['label']`` and return padded arrays
        (see ``list_to_numpy``)."""
        annotations = json.loads(data['label'])
        boxes, texts, text_inds, labels, edges = [], [], [], [], []
        for ann in annotations:
            box = ann['points']
            x_list = [box[i][0] for i in range(4)]
            y_list = [box[i][1] for i in range(4)]
            sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
            sorted_box = []
            for x, y in zip(sorted_x_list, sorted_y_list):
                sorted_box.append(x)
                sorted_box.append(y)
            boxes.append(sorted_box)
            text = ann['transcription']
            texts.append(ann['transcription'])
            # Characters not in the dictionary are dropped.
            text_ind = [self.dict[c] for c in text if c in self.dict]
            text_inds.append(text_ind)
            labels.append(ann['label'])
            edges.append(ann.get('edge', 0))
        ann_infos = dict(
            image=data['image'],
            points=boxes,
            texts=texts,
            text_inds=text_inds,
            edges=edges,
            labels=labels)

        return self.list_to_numpy(ann_infos)
|
455
|
-
|
456
|
-
|
457
|
-
class AttnLabelEncode(BaseRecLabelEncode):
    """Encode labels for attention decoders as ``sos + ids + eos + zero padding``."""

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(AttnLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        """Put 'sos' at index 0 and 'eos' at the last index."""
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + dict_character + [self.end_str]

    def __call__(self, data):
        """Set ``data['length']`` and the padded ``data['label']``; None if unusable."""
        ids = self.encode(data['label'])
        if ids is None or len(ids) >= self.max_text_len:
            return None
        data['length'] = np.array(len(ids))
        eos_idx = len(self.character) - 1
        # NOTE(review): when len(ids) == max_text_len - 1 the padding count is
        # negative (no padding) and the label has max_text_len + 1 entries,
        # unlike the max_text_len of other samples — confirm batching copes.
        tail = [0] * (self.max_text_len - len(ids) - 2)
        data['label'] = np.array([0] + ids + [eos_idx] + tail)
        return data

    def get_ignored_tokens(self):
        """Return the [sos, eos] indices."""
        return [self.get_beg_end_flag_idx("beg"),
                self.get_beg_end_flag_idx("end")]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of 'sos' (``"beg"``) or 'eos' (``"end"``) as an ndarray."""
        if beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        elif beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        else:
            assert False, "Unsupport type %s in get_beg_end_flag_idx" \
                          % beg_or_end
        return idx
|
501
|
-
|
502
|
-
|
503
|
-
class SEEDLabelEncode(BaseRecLabelEncode):
|
504
|
-
""" Convert between text-label and text-index """
|
505
|
-
|
506
|
-
def __init__(self,
|
507
|
-
max_text_length,
|
508
|
-
character_dict_path=None,
|
509
|
-
use_space_char=False,
|
510
|
-
**kwargs):
|
511
|
-
super(SEEDLabelEncode, self).__init__(
|
512
|
-
max_text_length, character_dict_path, use_space_char)
|
513
|
-
|
514
|
-
def add_special_char(self, dict_character):
|
515
|
-
self.padding = "padding"
|
516
|
-
self.end_str = "eos"
|
517
|
-
self.unknown = "unknown"
|
518
|
-
dict_character = dict_character + [
|
519
|
-
self.end_str, self.padding, self.unknown
|
520
|
-
]
|
521
|
-
return dict_character
|
522
|
-
|
523
|
-
def __call__(self, data):
|
524
|
-
text = data['label']
|
525
|
-
text = self.encode(text)
|
526
|
-
if text is None:
|
527
|
-
return None
|
528
|
-
if len(text) >= self.max_text_len:
|
529
|
-
return None
|
530
|
-
data['length'] = np.array(len(text)) + 1 # conclude eos
|
531
|
-
text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
|
532
|
-
self.max_text_len - len(text) - 1)
|
533
|
-
data['label'] = np.array(text)
|
534
|
-
return data
|
535
|
-
|
536
|
-
|
537
|
-
class SRNLabelEncode(BaseRecLabelEncode):
|
538
|
-
""" Convert between text-label and text-index """
|
539
|
-
|
540
|
-
def __init__(self,
|
541
|
-
max_text_length=25,
|
542
|
-
character_dict_path=None,
|
543
|
-
use_space_char=False,
|
544
|
-
**kwargs):
|
545
|
-
super(SRNLabelEncode, self).__init__(
|
546
|
-
max_text_length, character_dict_path, use_space_char)
|
547
|
-
|
548
|
-
def add_special_char(self, dict_character):
|
549
|
-
dict_character = dict_character + [self.beg_str, self.end_str]
|
550
|
-
return dict_character
|
551
|
-
|
552
|
-
def __call__(self, data):
|
553
|
-
text = data['label']
|
554
|
-
text = self.encode(text)
|
555
|
-
char_num = len(self.character)
|
556
|
-
if text is None:
|
557
|
-
return None
|
558
|
-
if len(text) > self.max_text_len:
|
559
|
-
return None
|
560
|
-
data['length'] = np.array(len(text))
|
561
|
-
text = text + [char_num - 1] * (self.max_text_len - len(text))
|
562
|
-
data['label'] = np.array(text)
|
563
|
-
return data
|
564
|
-
|
565
|
-
def get_ignored_tokens(self):
|
566
|
-
beg_idx = self.get_beg_end_flag_idx("beg")
|
567
|
-
end_idx = self.get_beg_end_flag_idx("end")
|
568
|
-
return [beg_idx, end_idx]
|
569
|
-
|
570
|
-
def get_beg_end_flag_idx(self, beg_or_end):
|
571
|
-
if beg_or_end == "beg":
|
572
|
-
idx = np.array(self.dict[self.beg_str])
|
573
|
-
elif beg_or_end == "end":
|
574
|
-
idx = np.array(self.dict[self.end_str])
|
575
|
-
else:
|
576
|
-
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
577
|
-
% beg_or_end
|
578
|
-
return idx
|
579
|
-
|
580
|
-
|
581
|
-
class TableLabelEncode(object):
|
582
|
-
""" Convert between text-label and text-index """
|
583
|
-
|
584
|
-
def __init__(self,
|
585
|
-
max_text_length,
|
586
|
-
max_elem_length,
|
587
|
-
max_cell_num,
|
588
|
-
character_dict_path,
|
589
|
-
span_weight=1.0,
|
590
|
-
**kwargs):
|
591
|
-
self.max_text_length = max_text_length
|
592
|
-
self.max_elem_length = max_elem_length
|
593
|
-
self.max_cell_num = max_cell_num
|
594
|
-
list_character, list_elem = self.load_char_elem_dict(
|
595
|
-
character_dict_path)
|
596
|
-
list_character = self.add_special_char(list_character)
|
597
|
-
list_elem = self.add_special_char(list_elem)
|
598
|
-
self.dict_character = {}
|
599
|
-
for i, char in enumerate(list_character):
|
600
|
-
self.dict_character[char] = i
|
601
|
-
self.dict_elem = {}
|
602
|
-
for i, elem in enumerate(list_elem):
|
603
|
-
self.dict_elem[elem] = i
|
604
|
-
self.span_weight = span_weight
|
605
|
-
|
606
|
-
def load_char_elem_dict(self, character_dict_path):
|
607
|
-
list_character = []
|
608
|
-
list_elem = []
|
609
|
-
with open(character_dict_path, "rb") as fin:
|
610
|
-
lines = fin.readlines()
|
611
|
-
substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
|
612
|
-
character_num = int(substr[0])
|
613
|
-
elem_num = int(substr[1])
|
614
|
-
for cno in range(1, 1 + character_num):
|
615
|
-
character = lines[cno].decode('utf-8').strip("\r\n")
|
616
|
-
list_character.append(character)
|
617
|
-
for eno in range(1 + character_num, 1 + character_num + elem_num):
|
618
|
-
elem = lines[eno].decode('utf-8').strip("\r\n")
|
619
|
-
list_elem.append(elem)
|
620
|
-
return list_character, list_elem
|
621
|
-
|
622
|
-
def add_special_char(self, list_character):
|
623
|
-
self.beg_str = "sos"
|
624
|
-
self.end_str = "eos"
|
625
|
-
list_character = [self.beg_str] + list_character + [self.end_str]
|
626
|
-
return list_character
|
627
|
-
|
628
|
-
def get_span_idx_list(self):
|
629
|
-
span_idx_list = []
|
630
|
-
for elem in self.dict_elem:
|
631
|
-
if 'span' in elem:
|
632
|
-
span_idx_list.append(self.dict_elem[elem])
|
633
|
-
return span_idx_list
|
634
|
-
|
635
|
-
def __call__(self, data):
|
636
|
-
cells = data['cells']
|
637
|
-
structure = data['structure']['tokens']
|
638
|
-
structure = self.encode(structure, 'elem')
|
639
|
-
if structure is None:
|
640
|
-
return None
|
641
|
-
elem_num = len(structure)
|
642
|
-
structure = [0] + structure + [len(self.dict_elem) - 1]
|
643
|
-
structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
|
644
|
-
)
|
645
|
-
structure = np.array(structure)
|
646
|
-
data['structure'] = structure
|
647
|
-
elem_char_idx1 = self.dict_elem['<td>']
|
648
|
-
elem_char_idx2 = self.dict_elem['<td']
|
649
|
-
span_idx_list = self.get_span_idx_list()
|
650
|
-
td_idx_list = np.logical_or(structure == elem_char_idx1,
|
651
|
-
structure == elem_char_idx2)
|
652
|
-
td_idx_list = np.where(td_idx_list)[0]
|
653
|
-
|
654
|
-
structure_mask = np.ones(
|
655
|
-
(self.max_elem_length + 2, 1), dtype=np.float32)
|
656
|
-
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
|
657
|
-
bbox_list_mask = np.zeros(
|
658
|
-
(self.max_elem_length + 2, 1), dtype=np.float32)
|
659
|
-
img_height, img_width, img_ch = data['image'].shape
|
660
|
-
if len(span_idx_list) > 0:
|
661
|
-
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
|
662
|
-
span_weight = min(max(span_weight, 1.0), self.span_weight)
|
663
|
-
for cno in range(len(cells)):
|
664
|
-
if 'bbox' in cells[cno]:
|
665
|
-
bbox = cells[cno]['bbox'].copy()
|
666
|
-
bbox[0] = bbox[0] * 1.0 / img_width
|
667
|
-
bbox[1] = bbox[1] * 1.0 / img_height
|
668
|
-
bbox[2] = bbox[2] * 1.0 / img_width
|
669
|
-
bbox[3] = bbox[3] * 1.0 / img_height
|
670
|
-
td_idx = td_idx_list[cno]
|
671
|
-
bbox_list[td_idx] = bbox
|
672
|
-
bbox_list_mask[td_idx] = 1.0
|
673
|
-
cand_span_idx = td_idx + 1
|
674
|
-
if cand_span_idx < (self.max_elem_length + 2):
|
675
|
-
if structure[cand_span_idx] in span_idx_list:
|
676
|
-
structure_mask[cand_span_idx] = span_weight
|
677
|
-
|
678
|
-
data['bbox_list'] = bbox_list
|
679
|
-
data['bbox_list_mask'] = bbox_list_mask
|
680
|
-
data['structure_mask'] = structure_mask
|
681
|
-
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
|
682
|
-
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
|
683
|
-
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
|
684
|
-
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
|
685
|
-
data['sp_tokens'] = np.array([
|
686
|
-
char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
|
687
|
-
elem_char_idx1, elem_char_idx2, self.max_text_length,
|
688
|
-
self.max_elem_length, self.max_cell_num, elem_num
|
689
|
-
])
|
690
|
-
return data
|
691
|
-
|
692
|
-
def encode(self, text, char_or_elem):
|
693
|
-
"""convert text-label into text-index.
|
694
|
-
"""
|
695
|
-
if char_or_elem == "char":
|
696
|
-
max_len = self.max_text_length
|
697
|
-
current_dict = self.dict_character
|
698
|
-
else:
|
699
|
-
max_len = self.max_elem_length
|
700
|
-
current_dict = self.dict_elem
|
701
|
-
if len(text) > max_len:
|
702
|
-
return None
|
703
|
-
if len(text) == 0:
|
704
|
-
if char_or_elem == "char":
|
705
|
-
return [self.dict_character['space']]
|
706
|
-
else:
|
707
|
-
return None
|
708
|
-
text_list = []
|
709
|
-
for char in text:
|
710
|
-
if char not in current_dict:
|
711
|
-
return None
|
712
|
-
text_list.append(current_dict[char])
|
713
|
-
if len(text_list) == 0:
|
714
|
-
if char_or_elem == "char":
|
715
|
-
return [self.dict_character['space']]
|
716
|
-
else:
|
717
|
-
return None
|
718
|
-
return text_list
|
719
|
-
|
720
|
-
def get_ignored_tokens(self, char_or_elem):
|
721
|
-
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
|
722
|
-
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
|
723
|
-
return [beg_idx, end_idx]
|
724
|
-
|
725
|
-
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
|
726
|
-
if char_or_elem == "char":
|
727
|
-
if beg_or_end == "beg":
|
728
|
-
idx = np.array(self.dict_character[self.beg_str])
|
729
|
-
elif beg_or_end == "end":
|
730
|
-
idx = np.array(self.dict_character[self.end_str])
|
731
|
-
else:
|
732
|
-
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
|
733
|
-
% beg_or_end
|
734
|
-
elif char_or_elem == "elem":
|
735
|
-
if beg_or_end == "beg":
|
736
|
-
idx = np.array(self.dict_elem[self.beg_str])
|
737
|
-
elif beg_or_end == "end":
|
738
|
-
idx = np.array(self.dict_elem[self.end_str])
|
739
|
-
else:
|
740
|
-
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
|
741
|
-
% beg_or_end
|
742
|
-
else:
|
743
|
-
assert False, "Unsupport type %s in char_or_elem" \
|
744
|
-
% char_or_elem
|
745
|
-
return idx
|
746
|
-
|
747
|
-
|
748
|
-
class SARLabelEncode(BaseRecLabelEncode):
|
749
|
-
""" Convert between text-label and text-index """
|
750
|
-
|
751
|
-
def __init__(self,
|
752
|
-
max_text_length,
|
753
|
-
character_dict_path=None,
|
754
|
-
use_space_char=False,
|
755
|
-
**kwargs):
|
756
|
-
super(SARLabelEncode, self).__init__(
|
757
|
-
max_text_length, character_dict_path, use_space_char)
|
758
|
-
|
759
|
-
def add_special_char(self, dict_character):
|
760
|
-
beg_end_str = "<BOS/EOS>"
|
761
|
-
unknown_str = "<UKN>"
|
762
|
-
padding_str = "<PAD>"
|
763
|
-
dict_character = dict_character + [unknown_str]
|
764
|
-
self.unknown_idx = len(dict_character) - 1
|
765
|
-
dict_character = dict_character + [beg_end_str]
|
766
|
-
self.start_idx = len(dict_character) - 1
|
767
|
-
self.end_idx = len(dict_character) - 1
|
768
|
-
dict_character = dict_character + [padding_str]
|
769
|
-
self.padding_idx = len(dict_character) - 1
|
770
|
-
|
771
|
-
return dict_character
|
772
|
-
|
773
|
-
def __call__(self, data):
|
774
|
-
text = data['label']
|
775
|
-
text = self.encode(text)
|
776
|
-
if text is None:
|
777
|
-
return None
|
778
|
-
if len(text) >= self.max_text_len - 1:
|
779
|
-
return None
|
780
|
-
data['length'] = np.array(len(text))
|
781
|
-
target = [self.start_idx] + text + [self.end_idx]
|
782
|
-
padded_text = [self.padding_idx for _ in range(self.max_text_len)]
|
783
|
-
|
784
|
-
padded_text[:len(target)] = target
|
785
|
-
data['label'] = np.array(padded_text)
|
786
|
-
return data
|
787
|
-
|
788
|
-
def get_ignored_tokens(self):
|
789
|
-
return [self.padding_idx]
|