openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -50,7 +50,7 @@ class BaseRecLabelEncode(object):
50
50
  text_re = []
51
51
  c_current = ''
52
52
  for c in text:
53
- if not bool(re.search('[a-zA-Z0-9 :*./%+-١٢٣٤٥٦٧٨٩٠]', c)):
53
+ if not bool(re.search('[a-zA-Z0-9 :*./%+١٢٣٤٥٦٧٨٩٠-]', c)):
54
54
  if c_current != '':
55
55
  text_re.append(c_current)
56
56
  text_re.append(c)
@@ -1,157 +1,177 @@
1
- import re
2
- from abc import ABC, abstractmethod
3
- from itertools import groupby
4
- from typing import List, Optional, Tuple
5
- import numpy as np
6
- import torch
7
- from torch import Tensor
8
- from torch.nn.utils.rnn import pad_sequence
9
- import unicodedata
10
- from ..modeling.decoders.dptr_parseq_clip_b_decoder import tokenize
11
-
12
- class CharsetAdapter:
13
- """Transforms labels according to the target charset."""
14
-
15
- def __init__(self, target_charset) -> None:
16
- super().__init__()
17
- self.lowercase_only = target_charset == target_charset.lower()
18
- self.uppercase_only = target_charset == target_charset.upper()
19
- self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')
20
-
21
- def __call__(self, label):
22
- if self.lowercase_only:
23
- label = label.lower()
24
- elif self.uppercase_only:
25
- label = label.upper()
26
- # Remove unsupported characters
27
- label = self.unsupported.sub('', label)
28
- return label
29
-
30
-
31
- class BaseTokenizer(ABC):
32
- # eos=0, a=1, bos=37, pad=38
33
- def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
34
- self._itos = specials_first + tuple(charset) + specials_last
35
- self._stoi = {s: i for i, s in enumerate(self._itos)}
36
- # print("stoi:", self._stoi)
37
-
38
- def __len__(self):
39
- return len(self._itos)
40
-
41
- def _tok2ids(self, tokens: str) -> List[int]:
42
- # print("tokens", tokens)
43
- return [self._stoi[s] for s in tokens]
44
-
45
- def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
46
- tokens = [self._itos[i] for i in token_ids]
47
- return ''.join(tokens) if join else tokens
48
-
49
- @abstractmethod
50
- def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
51
- """Encode a batch of labels to a representation suitable for the model.
52
-
53
- Args:
54
- labels: List of labels. Each can be of arbitrary length.
55
- device: Create tensor on this device.
56
-
57
- Returns:
58
- Batched tensor representation padded to the max label length. Shape: N, L
59
- """
60
- raise NotImplementedError
61
-
62
- @abstractmethod
63
- def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
64
- """Internal method which performs the necessary filtering prior to decoding."""
65
- raise NotImplementedError
66
-
67
- def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
68
- """Decode a batch of token distributions.
69
-
70
- Args:
71
- token_dists: softmax probabilities over the token distribution. Shape: N, L, C
72
- raw: return unprocessed labels (will return list of list of strings)
73
-
74
- Returns:
75
- list of string labels (arbitrary length) and
76
- their corresponding sequence probabilities as a list of Tensors
77
- """
78
- batch_tokens = []
79
- batch_probs = []
80
- for dist in token_dists:
81
- probs, ids = dist.max(-1) # greedy selection
82
- if not raw:
83
- probs, ids = self._filter(probs, ids)
84
- tokens = self._ids2tok(ids, not raw)
85
- batch_tokens.append(tokens)
86
- batch_probs.append(probs)
87
- return batch_tokens, batch_probs
88
-
89
-
90
- class Tokenizer(BaseTokenizer):
91
- BOS = '[B]'
92
- EOS = '[E]'
93
- PAD = '[P]'
94
-
95
- def __init__(self, charset: str) -> None:
96
- specials_first = (self.EOS,)
97
- specials_last = (self.BOS, self.PAD)
98
- super().__init__(charset, specials_first, specials_last)
99
- self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last]
100
-
101
- def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
102
- batch = [self.bos_id] + self._tok2ids(labels) + [self.eos_id]
103
- return batch
104
- # return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
105
-
106
- def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
107
- ids = ids.tolist()
108
- try:
109
- eos_idx = ids.index(self.eos_id)
110
- except ValueError:
111
- eos_idx = len(ids) # Nothing to truncate.
112
- # Truncate after EOS
113
- ids = ids[:eos_idx]
114
- probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists)
115
- return probs, ids
116
-
117
- class DPTRLabelEncode(Tokenizer):
118
- """Convert between text-label and text-index."""
119
- def __init__(self, max_text_length=25, character_dict_path=None, **kwargs):
120
- self.max_length = max_text_length
121
- charset = get_alpha(character_dict_path)
122
- charset = ''.join(charset)
123
- # print(charset)
124
- super(DPTRLabelEncode, self).__init__(charset)
125
-
126
- def __call__(self, data, normalize_unicode=True):
127
- text = data['label']
128
-
129
- if normalize_unicode:
130
- text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode()
131
- text = ''.join(text.split())
132
- if len(text) == 0 or len(text) > self.max_length:
133
- return None
134
-
135
- text_ids = self.encode(text)
136
- clip_ids = tokenize(f"a photo of a '{text}'")
137
- text_ids = text_ids + [self.pad_id] * (self.max_length + 2 - len(text_ids))
138
- # print(text, len(text_ids), len(clip_ids[0]))
139
- data['clip_label'] = np.array(clip_ids[0])
140
- data['label'] = np.array(text_ids)
141
- return data
142
-
143
- def add_special_char(self, dict_character):
144
- dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
145
- return dict_character
146
-
147
- def get_alpha(alpha_path):
148
- character_str = []
149
- with open(alpha_path, 'rb') as fin:
150
- lines = fin.readlines()
151
- for line in lines:
152
- line = line.decode('utf-8').strip('\n').strip('\r\n')
153
- character_str.append(line)
154
- dict_character = list(character_str)
155
- if 'arabic' in alpha_path:
156
- reverse = True
157
- return dict_character
1
+ import re
2
+ from abc import ABC, abstractmethod
3
+ from itertools import groupby
4
+ from typing import List, Optional, Tuple
5
+ import numpy as np
6
+ import torch
7
+ from torch import Tensor
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ import unicodedata
10
+ from ..modeling.decoders.dptr_parseq_clip_b_decoder import tokenize
11
+
12
+
13
+ class CharsetAdapter:
14
+ """Transforms labels according to the target charset."""
15
+
16
+ def __init__(self, target_charset) -> None:
17
+ super().__init__()
18
+ self.lowercase_only = target_charset == target_charset.lower()
19
+ self.uppercase_only = target_charset == target_charset.upper()
20
+ self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')
21
+
22
+ def __call__(self, label):
23
+ if self.lowercase_only:
24
+ label = label.lower()
25
+ elif self.uppercase_only:
26
+ label = label.upper()
27
+ # Remove unsupported characters
28
+ label = self.unsupported.sub('', label)
29
+ return label
30
+
31
+
32
+ class BaseTokenizer(ABC):
33
+ # eos=0, a=1, bos=37, pad=38
34
+ def __init__(
35
+ self,
36
+ charset: str,
37
+ specials_first: tuple = (),
38
+ specials_last: tuple = ()
39
+ ) -> None:
40
+ self._itos = specials_first + tuple(charset) + specials_last
41
+ self._stoi = {s: i for i, s in enumerate(self._itos)}
42
+ # print("stoi:", self._stoi)
43
+
44
+ def __len__(self):
45
+ return len(self._itos)
46
+
47
+ def _tok2ids(self, tokens: str) -> List[int]:
48
+ # print("tokens", tokens)
49
+ return [self._stoi[s] for s in tokens]
50
+
51
+ def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
52
+ tokens = [self._itos[i] for i in token_ids]
53
+ return ''.join(tokens) if join else tokens
54
+
55
+ @abstractmethod
56
+ def encode(self,
57
+ labels: List[str],
58
+ device: Optional[torch.device] = None) -> Tensor:
59
+ """Encode a batch of labels to a representation suitable for the model.
60
+
61
+ Args:
62
+ labels: List of labels. Each can be of arbitrary length.
63
+ device: Create tensor on this device.
64
+
65
+ Returns:
66
+ Batched tensor representation padded to the max label length. Shape: N, L
67
+ """
68
+ raise NotImplementedError
69
+
70
+ @abstractmethod
71
+ def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
72
+ """Internal method which performs the necessary filtering prior to decoding."""
73
+ raise NotImplementedError
74
+
75
+ def decode(self,
76
+ token_dists: Tensor,
77
+ raw: bool = False) -> Tuple[List[str], List[Tensor]]:
78
+ """Decode a batch of token distributions.
79
+
80
+ Args:
81
+ token_dists: softmax probabilities over the token distribution. Shape: N, L, C
82
+ raw: return unprocessed labels (will return list of list of strings)
83
+
84
+ Returns:
85
+ list of string labels (arbitrary length) and
86
+ their corresponding sequence probabilities as a list of Tensors
87
+ """
88
+ batch_tokens = []
89
+ batch_probs = []
90
+ for dist in token_dists:
91
+ probs, ids = dist.max(-1) # greedy selection
92
+ if not raw:
93
+ probs, ids = self._filter(probs, ids)
94
+ tokens = self._ids2tok(ids, not raw)
95
+ batch_tokens.append(tokens)
96
+ batch_probs.append(probs)
97
+ return batch_tokens, batch_probs
98
+
99
+
100
+ class Tokenizer(BaseTokenizer):
101
+ BOS = '[B]'
102
+ EOS = '[E]'
103
+ PAD = '[P]'
104
+
105
+ def __init__(self, charset: str) -> None:
106
+ specials_first = (self.EOS, )
107
+ specials_last = (self.BOS, self.PAD)
108
+ super().__init__(charset, specials_first, specials_last)
109
+ self.eos_id, self.bos_id, self.pad_id = [
110
+ self._stoi[s] for s in specials_first + specials_last
111
+ ]
112
+
113
+ def encode(self,
114
+ labels: List[str],
115
+ device: Optional[torch.device] = None) -> Tensor:
116
+ batch = [self.bos_id] + self._tok2ids(labels) + [self.eos_id]
117
+ return batch
118
+ # return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
119
+
120
+ def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
121
+ ids = ids.tolist()
122
+ try:
123
+ eos_idx = ids.index(self.eos_id)
124
+ except ValueError:
125
+ eos_idx = len(ids) # Nothing to truncate.
126
+ # Truncate after EOS
127
+ ids = ids[:eos_idx]
128
+ probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists)
129
+ return probs, ids
130
+
131
+
132
+ class DPTRLabelEncode(Tokenizer):
133
+ """Convert between text-label and text-index."""
134
+
135
+ def __init__(self, max_text_length=25, character_dict_path=None, **kwargs):
136
+ self.max_length = max_text_length
137
+ charset = get_alpha(character_dict_path)
138
+ charset = ''.join(charset)
139
+ # print(charset)
140
+ super(DPTRLabelEncode, self).__init__(charset)
141
+
142
+ def __call__(self, data, normalize_unicode=True):
143
+ text = data['label']
144
+
145
+ if normalize_unicode:
146
+ text = unicodedata.normalize('NFKD',
147
+ text).encode('ascii',
148
+ 'ignore').decode()
149
+ text = ''.join(text.split())
150
+ if len(text) == 0 or len(text) > self.max_length:
151
+ return None
152
+
153
+ text_ids = self.encode(text)
154
+ clip_ids = tokenize(f"a photo of a '{text}'")
155
+ text_ids = text_ids + [self.pad_id
156
+ ] * (self.max_length + 2 - len(text_ids))
157
+ # print(text, len(text_ids), len(clip_ids[0]))
158
+ data['clip_label'] = np.array(clip_ids[0])
159
+ data['label'] = np.array(text_ids)
160
+ return data
161
+
162
+ def add_special_char(self, dict_character):
163
+ dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
164
+ return dict_character
165
+
166
+
167
+ def get_alpha(alpha_path):
168
+ character_str = []
169
+ with open(alpha_path, 'rb') as fin:
170
+ lines = fin.readlines()
171
+ for line in lines:
172
+ line = line.decode('utf-8').strip('\n').strip('\r\n')
173
+ character_str.append(line)
174
+ dict_character = list(character_str)
175
+ if 'arabic' in alpha_path:
176
+ reverse = True
177
+ return dict_character
@@ -121,11 +121,13 @@ class IGTRLabelEncode(BaseRecLabelEncode):
121
121
  ques2_answer = []
122
122
  for q_2, ques2_idx in enumerate(ques2_char_idx.tolist()):
123
123
 
124
- if (train_step == 2 or train_step == 3) and q_2 == ques_len - 1:
124
+ if (train_step == 2
125
+ or train_step == 3) and q_2 == ques_len - 1:
125
126
  new_ques2_char_idx.append(ques2_idx)
126
127
  ques2_answer.append(1)
127
128
  continue
128
- if ques2_idx[1] != self.dict['<pad>'] and random.random() > 0.5:
129
+ if ques2_idx[1] != self.dict['<pad>'] and random.random(
130
+ ) > 0.5:
129
131
  select_idx = random.randint(0, self.num_character - 3)
130
132
  new_ques2_char_idx.append([ques2_idx[0], select_idx])
131
133
  if select_idx == ques2_idx[1]: