openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -50,7 +50,7 @@ class BaseRecLabelEncode(object):
|
|
|
50
50
|
text_re = []
|
|
51
51
|
c_current = ''
|
|
52
52
|
for c in text:
|
|
53
|
-
if not bool(re.search('[a-zA-Z0-9
|
|
53
|
+
if not bool(re.search('[a-zA-Z0-9 :*./%+١٢٣٤٥٦٧٨٩٠-]', c)):
|
|
54
54
|
if c_current != '':
|
|
55
55
|
text_re.append(c_current)
|
|
56
56
|
text_re.append(c)
|
|
@@ -1,157 +1,177 @@
|
|
|
1
|
-
import re
|
|
2
|
-
from abc import ABC, abstractmethod
|
|
3
|
-
from itertools import groupby
|
|
4
|
-
from typing import List, Optional, Tuple
|
|
5
|
-
import numpy as np
|
|
6
|
-
import torch
|
|
7
|
-
from torch import Tensor
|
|
8
|
-
from torch.nn.utils.rnn import pad_sequence
|
|
9
|
-
import unicodedata
|
|
10
|
-
from ..modeling.decoders.dptr_parseq_clip_b_decoder import tokenize
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
self.
|
|
19
|
-
self.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
self
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
# print("
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
"""
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
1
|
+
import re
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from itertools import groupby
|
|
4
|
+
from typing import List, Optional, Tuple
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
9
|
+
import unicodedata
|
|
10
|
+
from ..modeling.decoders.dptr_parseq_clip_b_decoder import tokenize
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class CharsetAdapter:
    """Transforms labels according to the target charset.

    A label is case-folded when the charset contains only one case, and every
    character that is not part of the charset is removed.
    """

    def __init__(self, target_charset) -> None:
        """Precompute the case policy and the "unsupported character" pattern.

        Args:
            target_charset: String holding every character the model supports.
                An empty string means no character is supported.
        """
        super().__init__()
        # A charset that is unchanged by lower()/upper() has no characters of
        # the opposite case, so labels can be folded to that case safely.
        self.lowercase_only = target_charset == target_charset.lower()
        self.uppercase_only = target_charset == target_charset.upper()
        if target_charset:
            self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')
        else:
            # '[^]' is not a valid pattern. With an empty charset every
            # character is unsupported, so match any character at all
            # (newlines included, like a negated class would).
            self.unsupported = re.compile(r'[\s\S]')

    def __call__(self, label):
        """Return ``label`` case-folded (if applicable) and stripped of
        characters outside the target charset."""
        if self.lowercase_only:
            label = label.lower()
        elif self.uppercase_only:
            label = label.upper()
        # Remove unsupported characters
        return self.unsupported.sub('', label)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class BaseTokenizer(ABC):
    """Abstract character-level tokenizer.

    The vocabulary is ``specials_first + charset + specials_last`` and a
    token's id is its position in that sequence. Subclasses provide
    ``encode`` and ``_filter``.
    """

    # Example layout with one leading special, a 36-symbol charset and two
    # trailing specials: eos=0, a=1, bos=37, pad=38
    def __init__(
            self,
            charset: str,
            specials_first: tuple = (),
            specials_last: tuple = (),
    ) -> None:
        """Build the id -> token and token -> id lookup tables.

        Args:
            charset: Every character of this string becomes one token.
            specials_first: Special tokens placed before the charset.
            specials_last: Special tokens placed after the charset.
        """
        # Coerce to tuples so that specials given as lists (e.g. coming from
        # a parsed config) concatenate instead of raising TypeError.
        self._itos = (tuple(specials_first) + tuple(charset) +
                      tuple(specials_last))
        self._stoi = {s: i for i, s in enumerate(self._itos)}

    def __len__(self):
        """Return the vocabulary size, special tokens included."""
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> List[int]:
        """Map each token to its id. Raises KeyError for unknown tokens."""
        return [self._stoi[s] for s in tokens]

    def _ids2tok(self,
                 token_ids: List[int],
                 join: bool = True) -> 'str | List[str]':
        """Map ids back to tokens.

        Returns a single concatenated string when ``join`` is true, otherwise
        the list of individual tokens.
        """
        tokens = [self._itos[i] for i in token_ids]
        return ''.join(tokens) if join else tokens

    @abstractmethod
    def encode(self,
               labels: List[str],
               device: Optional[torch.device] = None) -> Tensor:
        """Encode a batch of labels to a representation suitable for the model.

        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensor on this device.

        Returns:
            Batched tensor representation padded to the max label length. Shape: N, L
        """
        raise NotImplementedError

    @abstractmethod
    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Internal method which performs the necessary filtering prior to decoding."""
        raise NotImplementedError

    def decode(self,
               token_dists: Tensor,
               raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        """Decode a batch of token distributions.

        Args:
            token_dists: softmax probabilities over the token distribution. Shape: N, L, C
            raw: return unprocessed labels (will return list of list of strings)

        Returns:
            list of string labels (arbitrary length) and
            their corresponding sequence probabilities as a list of Tensors
        """
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # greedy selection
            if not raw:
                probs, ids = self._filter(probs, ids)
            # Raw output keeps the tokens separate; filtered output is joined.
            tokens = self._ids2tok(ids, not raw)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class Tokenizer(BaseTokenizer):
    """Tokenizer with ``[E]`` ahead of the charset and ``[B]``, ``[P]`` after it."""

    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, charset: str) -> None:
        """Register the special tokens around ``charset`` and cache their ids."""
        leading = (self.EOS, )
        trailing = (self.BOS, self.PAD)
        super().__init__(charset, leading, trailing)
        self.eos_id = self._stoi[self.EOS]
        self.bos_id = self._stoi[self.BOS]
        self.pad_id = self._stoi[self.PAD]

    def encode(self,
               labels: List[str],
               device: Optional[torch.device] = None) -> Tensor:
        """Wrap the ids of ``labels`` between the BOS and EOS ids.

        Despite the annotations, ``labels`` is looked up token by token as one
        sequence and a plain Python list of ids is returned; ``device`` is not
        used and no padding is applied here.
        """
        encoded = [self.bos_id]
        encoded.extend(self._tok2ids(labels))
        encoded.append(self.eos_id)
        return encoded

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Cut the prediction at the first EOS.

        The returned ids exclude EOS, while the returned probabilities keep
        the EOS entry when one was predicted.
        """
        id_list = ids.tolist()
        if self.eos_id in id_list:
            cut = id_list.index(self.eos_id)
        else:
            # No EOS predicted: nothing to truncate.
            cut = len(id_list)
        return probs[:cut + 1], id_list[:cut]
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class DPTRLabelEncode(Tokenizer):
    """Convert a text label into padded token ids plus a tokenized prompt."""

    def __init__(self, max_text_length=25, character_dict_path=None, **kwargs):
        """Load the character dictionary and build the vocabulary.

        Args:
            max_text_length: Longest label (in characters) that is accepted.
            character_dict_path: Dictionary file read by ``get_alpha``.
            **kwargs: Accepted and ignored.
        """
        self.max_length = max_text_length
        entries = get_alpha(character_dict_path)
        super(DPTRLabelEncode, self).__init__(''.join(entries))

    def __call__(self, data, normalize_unicode=True):
        """Encode ``data['label']`` in place.

        Returns ``None`` when the cleaned label is empty or longer than
        ``max_length``; otherwise returns ``data`` with ``'label'`` replaced by
        an id array of length ``max_length + 2`` and ``'clip_label'`` added.
        """
        label = data['label']

        if normalize_unicode:
            # Decompose, then drop everything that has no ASCII representation.
            decomposed = unicodedata.normalize('NFKD', label)
            label = decomposed.encode('ascii', 'ignore').decode()
        # Remove every whitespace character.
        label = ''.join(label.split())
        if not 0 < len(label) <= self.max_length:
            return None

        # NOTE(review): characters missing from the dictionary raise KeyError
        # in encode(); confirm upstream filtering guarantees this cannot happen.
        token_ids = self.encode(label)
        prompt_ids = tokenize(f"a photo of a '{label}'")
        # +2 leaves room for the BOS and EOS ids added by encode().
        padding = [self.pad_id] * (self.max_length + 2 - len(token_ids))
        # Only the first entry of the tokenizer output is kept
        # (presumably a batch of size one; verify against `tokenize`).
        data['clip_label'] = np.array(prompt_ids[0])
        data['label'] = np.array(token_ids + padding)
        return data

    def add_special_char(self, dict_character):
        """Return ``dict_character`` with EOS prepended and BOS, PAD appended."""
        return [self.EOS] + dict_character + [self.BOS, self.PAD]
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def get_alpha(alpha_path):
    """Read a character dictionary file, one entry per line.

    Args:
        alpha_path: Path of a UTF-8 encoded file. Anything accepted by
            ``open`` works (``str``, ``bytes`` or ``os.PathLike``).

    Returns:
        List with one string per line, in file order. Only carriage returns
        and newlines are stripped, so an entry made of spaces is preserved.
    """
    # The file is opened in binary mode and decoded explicitly so the result
    # does not depend on the platform's default encoding or newline handling.
    with open(alpha_path, 'rb') as fin:
        return [raw.decode('utf-8').strip('\r\n') for raw in fin]
|
|
@@ -121,11 +121,13 @@ class IGTRLabelEncode(BaseRecLabelEncode):
|
|
|
121
121
|
ques2_answer = []
|
|
122
122
|
for q_2, ques2_idx in enumerate(ques2_char_idx.tolist()):
|
|
123
123
|
|
|
124
|
-
if (train_step == 2
|
|
124
|
+
if (train_step == 2
|
|
125
|
+
or train_step == 3) and q_2 == ques_len - 1:
|
|
125
126
|
new_ques2_char_idx.append(ques2_idx)
|
|
126
127
|
ques2_answer.append(1)
|
|
127
128
|
continue
|
|
128
|
-
if ques2_idx[1] != self.dict['<pad>'] and random.random(
|
|
129
|
+
if ques2_idx[1] != self.dict['<pad>'] and random.random(
|
|
130
|
+
) > 0.5:
|
|
129
131
|
select_idx = random.randint(0, self.num_character - 3)
|
|
130
132
|
new_ques2_char_idx.append([ques2_idx[0], select_idx])
|
|
131
133
|
if select_idx == ques2_idx[1]:
|