openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,433 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Optional, Tuple, Union
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
from torch.nn import CrossEntropyLoss
|
|
6
|
+
|
|
7
|
+
# transformers == 4.45.1
|
|
8
|
+
from .configuration_unirec import UniRecConfig
|
|
9
|
+
from transformers import M2M100PreTrainedModel
|
|
10
|
+
from transformers.models.m2m_100.modeling_m2m_100 import M2M100ScaledWordEmbedding, M2M100Decoder
|
|
11
|
+
from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput, BaseModelOutput
|
|
12
|
+
from transformers.generation import GenerationMixin
|
|
13
|
+
from openrec.modeling.encoders.focalsvtr import FocalSVTR
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int,
                       decoder_start_token_id: int):
    """Shift input ids one token to the right.

    The last column of ``input_ids`` is dropped, every remaining column is
    moved one position to the right and ``decoder_start_token_id`` is written
    into column 0. Any ``-100`` value left in the shifted tensor is replaced
    by ``pad_token_id``.

    Args:
        input_ids: 2-D integer tensor (it is indexed as ``[:, k]``).
        pad_token_id: Replacement for ``-100`` entries. Must not be ``None``.
        decoder_start_token_id: Value stored in the first column.

    Returns:
        A new tensor with the same shape and dtype as ``input_ids``; the
        input tensor is not modified.

    Raises:
        ValueError: If ``pad_token_id`` is ``None``.
    """
    # Validate before allocating anything so a misconfigured model fails fast.
    if pad_token_id is None:
        raise ValueError('self.model.config.pad_token_id has to be defined.')

    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class UniRecEncoder(M2M100PreTrainedModel):
    """Vision encoder used by UniRec.

    Runs a ``FocalSVTR`` backbone over the input image and passes the result
    through a ``d_model -> d_model`` linear projection. No text embedding or
    transformer encoder layer is involved; the class only derives from
    ``M2M100PreTrainedModel`` to reuse its config / weight handling.

    Args:
        config: UniRecConfig
    """

    def __init__(self, config: UniRecConfig):
        super().__init__(config)

        self.config = config

        # Backbone hyper-parameters are fixed here rather than read from
        # ``config``.
        self.vision_encoder = FocalSVTR(img_size=[1408, 960],
                                        depths=[2, 2, 9, 2],
                                        embed_dim=96,
                                        sub_k=[[2, 2], [2, 2], [2, 2],
                                               [-1, -1]],
                                        focal_levels=[3, 3, 3, 3],
                                        max_khs=[7, 3, 3, 3],
                                        focal_windows=[3, 3, 3, 3],
                                        last_stage=False,
                                        feat2d=False)

        # Applied to the last dimension of the backbone output, which must
        # therefore have size ``config.d_model``.
        self.vision_fc = nn.Linear(config.d_model, config.d_model)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Encode an image batch into a feature sequence.

        Args:
            pixel_values (`torch.Tensor`):
                Image batch forwarded unchanged to the FocalSVTR backbone.
                Required despite the `None` default.
                NOTE(review): the expected layout is defined by FocalSVTR,
                presumably `(batch, channels, height, width)` - confirm.
            input_ids, attention_mask, head_mask, inputs_embeds:
                Accepted only for signature compatibility with the M2M100
                encoder; none of them is read by this method.
            output_attentions (`bool`, *optional*):
                When true, `attentions` is returned as an empty tuple (the
                backbone exposes no attention maps); otherwise it is `None`.
                Defaults to `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                When true, `hidden_states` is a 1-tuple holding the projected
                features. Defaults to `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether to return a `BaseModelOutput` instead of a plain
                tuple. Defaults to `config.use_return_dict`.

        Returns:
            `BaseModelOutput`, or a tuple with the non-`None` members of
            `(last_hidden_state, hidden_states, attentions)`.

        Raises:
            ValueError: If `pixel_values` is `None`.
        """
        if pixel_values is None:
            raise ValueError('pixel_values has to be provided.')

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = self.vision_encoder(pixel_values)
        hidden_states = self.vision_fc(hidden_states)

        if output_hidden_states:
            encoder_states = (hidden_states, )

        if not return_dict:
            return tuple(
                v for v in [hidden_states, encoder_states, all_attentions]
                if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_states,
                               hidden_states=encoder_states,
                               attentions=all_attentions)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class UniRecModel(M2M100PreTrainedModel):
    """Image-to-sequence backbone: a `UniRecEncoder` followed by an
    `M2M100Decoder` that shares one scaled token-embedding table
    (`self.shared`).
    """

    _tied_weights_keys = [
        'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'
    ]

    def __init__(self, config: UniRecConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0
        self.shared = M2M100ScaledWordEmbedding(vocab_size,
                                                config.d_model,
                                                padding_idx,
                                                embed_scale=embed_scale)

        self.encoder = UniRecEncoder(config)
        self.decoder = M2M100Decoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token-embedding module shared with the decoder."""
        return self.shared

    def set_input_embeddings(self, value):
        """Replace the shared embedding and re-point the decoder at it."""
        self.shared = value
        self.decoder.embed_tokens = self.shared

    def _tie_weights(self):
        """Tie the decoder embedding to `self.shared` when the config
        requests tied word embeddings."""
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def get_encoder(self):
        """Return the vision encoder."""
        return self.encoder

    def get_decoder(self):
        """Return the text decoder."""
        return self.decoder

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
        """Encode the image (unless `encoder_outputs` is supplied) and run
        the decoder with cross-attention over the encoder features.

        `input_ids`, `head_mask` and `inputs_embeds` are accepted for
        signature compatibility but are not read. A caller-supplied
        `attention_mask` is discarded: the encoder attention mask is always
        rebuilt as all ones over the first two dimensions of the encoder
        output.

        Returns:
            `Seq2SeqModelOutput` when `return_dict` is true, otherwise the
            decoder output tuple concatenated with the encoder output tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            # Forward the resolved flags so the encoder's return type matches
            # what the rest of this method expects. Without `return_dict`
            # the encoder falls back to the config default, and the tuple
            # concatenation in the `not return_dict` branch below fails on a
            # BaseModelOutput.
            encoder_outputs = self.encoder(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        # Every encoder position is attended to; no padding mask is applied
        # to the visual features.
        attention_mask = torch.ones(encoder_outputs[0].shape[:2],
                                    dtype=torch.long,
                                    device=encoder_outputs[0].device)

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class UniRecForConditionalGenerationNew(M2M100PreTrainedModel,
                                        GenerationMixin):
    """`UniRecModel` with a linear language-model head and a cross-entropy
    training loss, usable with the `GenerationMixin` decoding utilities.
    """

    base_model_prefix = 'model'
    _tied_weights_keys = [
        'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight',
        'lm_head.weight'
    ]

    def __init__(self, config: UniRecConfig):
        super().__init__(config)
        self.model = UniRecModel(config)
        self.lm_head = nn.Linear(config.d_model,
                                 self.model.shared.num_embeddings,
                                 bias=False)
        # Padding positions (not -100) are the ones excluded from the loss.
        self.loss_fct = CrossEntropyLoss(
            ignore_index=config.pad_token_id,
            label_smoothing=config.label_smoothing)
        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        """Return the vision encoder of the wrapped model."""
        return self.model.get_encoder()

    def get_decoder(self):
        """Return the text decoder of the wrapped model."""
        return self.model.get_decoder()

    def get_output_embeddings(self):
        """Return the language-model head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-model head."""
        self.lm_head = new_embeddings

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        length: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Full target sequences. When `decoder_input_ids` is not given they
            are split here: the decoder input is `labels` without its last
            column and the loss target is `labels` without its first column.
            Positions equal to `config.pad_token_id` are ignored by the loss.
        length (`torch.LongTensor`, *optional*):
            When given together with `labels`, both slices above are trimmed
            to `length.max() + 1` columns instead of the full width.
            NOTE(review): presumably the per-sample target lengths excluding
            the leading start token - confirm against the label encoder.

        The loss is only computed when the module is in training mode and
        `labels` is provided; `input_ids` is never forwarded to the model.

        Returns:
            `Seq2SeqLMOutput` when `return_dict` is true, otherwise a tuple
            `(loss?, logits, *model_outputs[1:])`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and decoder_input_ids is None:
            # Teacher forcing: input is the target shifted by one position.
            if length is not None:
                max_len = length.max()
                decoder_input_ids = labels[:, :1 + max_len]
                labels = labels[:, 1:2 + max_len]
            else:
                decoder_input_ids = labels[:, :-1]
                labels = labels[:, 1:]

        # The model call is the same for training and inference; only the
        # loss computation below differs.
        outputs = self.model(
            pixel_values=pixel_values,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0])

        masked_lm_loss = None
        if self.training and labels is not None:
            masked_lm_loss = self.loss_fct(
                lm_logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1))

        if not return_dict:
            output = (lm_logits, ) + outputs[1:]
            return ((masked_lm_loss, ) +
                    output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        pixel_values=None,
        **kwargs,
    ):
        """Build the keyword arguments for one decoding step.

        When a cache is present, `decoder_input_ids` is trimmed to the tokens
        that are not yet covered by it. The cached length is read from
        dimension 2 of the first cached tensor of the first layer.
        """
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        return {
            'input_ids':
            None,  # encoder_outputs is defined. input_ids not needed
            'encoder_outputs': encoder_outputs,
            'past_key_values': past_key_values,
            'decoder_input_ids': decoder_input_ids,
            'attention_mask': attention_mask,
            'head_mask': head_mask,
            'decoder_head_mask': decoder_head_mask,
            'cross_attn_head_mask': cross_attn_head_mask,
            'use_cache':
            use_cache,  # change this to avoid caching (presumably for debugging)
            'pixel_values': pixel_values,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Re-index every cached tensor along dim 0 with `beam_idx`,
        returning a new tuple-of-tuples with the same layout."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past), )
        return reordered_past
|
|
@@ -2,7 +2,6 @@ import copy
|
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
from torch import nn
|
|
5
|
-
|
|
6
5
|
__all__ = ['build_optimizer']
|
|
7
6
|
|
|
8
7
|
|
|
@@ -63,11 +62,13 @@ def build_optimizer(optim_config, lr_scheduler_config, epochs, step_each_epoch,
|
|
|
63
62
|
**config)
|
|
64
63
|
|
|
65
64
|
lr_config = copy.deepcopy(lr_scheduler_config)
|
|
65
|
+
scheduler_name = lr_config.pop('name')
|
|
66
|
+
|
|
66
67
|
lr_config.update({
|
|
67
68
|
'epochs': epochs,
|
|
68
69
|
'step_each_epoch': step_each_epoch,
|
|
69
70
|
'lr': config['lr']
|
|
70
71
|
})
|
|
71
|
-
lr_scheduler = getattr(lr,
|
|
72
|
-
|
|
72
|
+
lr_scheduler = getattr(lr, scheduler_name)(**lr_config)(optimizer=optim)
|
|
73
|
+
|
|
73
74
|
return optim, lr_scheduler
|
openocr/openrec/optimizer/lr.py
CHANGED
|
@@ -225,3 +225,52 @@ class CdistNetLR(object):
|
|
|
225
225
|
np.power(self.n_warmup_steps, -1.5) * current_step,
|
|
226
226
|
])
|
|
227
227
|
return self.step2_lr / self.init_lr
|
|
228
|
+
|
|
229
|
+
class WarmupCosineLR(object):
    """Factory for a linear-warmup + cosine-annealing LR schedule.

    The instance stores the schedule parameters; calling it with an optimizer
    builds the actual torch scheduler.

    Args:
        epochs: Number of training epochs.
        step_each_epoch: Scheduler steps per epoch; the schedule length is
            ``epochs * step_each_epoch``.
        warmup_steps: Number of linear warmup steps. Values ``<= 0`` disable
            warmup.
        eta_min: Minimum learning rate of the cosine phase.
        last_epoch: Passed to ``SequentialLR`` only; it is not used when
            there is no warmup phase.
        **kwargs: Ignored; lets callers pass a shared config dict.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 warmup_steps=0,
                 eta_min=0.0,
                 last_epoch=-1,
                 **kwargs):
        super(WarmupCosineLR, self).__init__()
        self.total_steps = epochs * step_each_epoch
        self.warmup_steps = warmup_steps
        self.eta_min = eta_min
        self.last_epoch = last_epoch

    def __call__(self, optimizer):
        """Build the scheduler for ``optimizer``.

        Returns:
            A ``CosineAnnealingLR`` when warmup is disabled, otherwise a
            ``SequentialLR`` chaining ``LinearLR`` (factor 1e-7 -> 1.0 over
            the warmup steps) and ``CosineAnnealingLR``.
        """
        # A negative warmup must not lengthen the cosine phase beyond the
        # total number of steps, so treat it as "no warmup".
        warmup_steps = max(0, self.warmup_steps)

        # Cosine phase covers whatever is left after warmup (at least 1 step
        # so T_max stays valid when warmup consumes the whole schedule).
        remain_steps = max(1, self.total_steps - warmup_steps)

        if warmup_steps == 0:
            return lr_scheduler.CosineAnnealingLR(optimizer,
                                                  T_max=remain_steps,
                                                  eta_min=self.eta_min)

        # 1. Warmup Phase
        warmup_scheduler = lr_scheduler.LinearLR(optimizer,
                                                 start_factor=1e-7,
                                                 end_factor=1.0,
                                                 total_iters=warmup_steps)

        # 2. Cosine Phase
        cosine_scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=remain_steps, eta_min=self.eta_min)

        return lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_steps],
            last_epoch=self.last_epoch)
|
|
@@ -18,6 +18,8 @@ module_mapping = {
|
|
|
18
18
|
'SRNLabelDecode': '.srn_postprocess',
|
|
19
19
|
'LISTERLabelDecode': '.lister_postprocess',
|
|
20
20
|
'MPGLabelDecode': '.mgp_postprocess',
|
|
21
|
+
'UniRecLabelDecode': '.unirec_postprocess',
|
|
22
|
+
'CMERLabelDecode': '.cmer_postprocess',
|
|
21
23
|
'GTCLabelDecode': '.' # 当前模块中的类
|
|
22
24
|
}
|
|
23
25
|
|
|
@@ -29,7 +29,7 @@ class ABINetLabelDecode(NRTRLabelDecode):
|
|
|
29
29
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
|
30
30
|
if batch is None:
|
|
31
31
|
return text
|
|
32
|
-
label = self.decode(batch[1]
|
|
32
|
+
label = self.decode(batch[1])
|
|
33
33
|
return text, label
|
|
34
34
|
|
|
35
35
|
def add_special_char(self, dict_character):
|
|
@@ -30,7 +30,7 @@ class ARLabelDecode(BaseRecLabelDecode):
|
|
|
30
30
|
if batch is None:
|
|
31
31
|
return text
|
|
32
32
|
label = batch[1]
|
|
33
|
-
label = self.decode(label[:, 1:]
|
|
33
|
+
label = self.decode(label[:, 1:])
|
|
34
34
|
return text, label
|
|
35
35
|
|
|
36
36
|
def add_special_char(self, dict_character):
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from transformers import PreTrainedTokenizerFast
|
|
3
|
+
from .ctc_postprocess import BaseRecLabelDecode
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CMERLabelDecode(BaseRecLabelDecode):
    """
    Decodes model output Token IDs into text.
    Refactored to match UniRecLabelDecode style and return format (text, score).
    """

    def __init__(
            self,
            character_dict_path=None,
            use_space_char=False,
            tokenizer_file='./configs/rec/cmer/cmer_tokenizer/tokenizer.json',
            remove_spaces=True,
            **kwargs):
        """
        Args:
            character_dict_path: Path to character dict (inherited param).
            use_space_char: Whether to use space char (inherited param).
            tokenizer_file: Path to the tokenizer json file. The default is
                relative, so it resolves against the current working
                directory.
            remove_spaces: Whether to delete every space character from the
                decoded text. Defaults to True.
            **kwargs: Other configurations (ignored).
        """
        # 1. Call super constructor to match UniRec style
        super(CMERLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

        # 2. CMER specific logic
        self.remove_spaces = remove_spaces

        self.tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=tokenizer_file,
            padding_side='right',
            truncation_side='right',
            pad_token='<|pad|>',
            bos_token='<|bos|>',
            eos_token='<|eos|>',
            unk_token='<|unk|>',
        )

    def get_character_num(self):
        """
        Called by Trainer to determine classification layer output dimension (vocab_size).

        Returns the tokenizer's ``vocab_size`` when it has that attribute,
        else ``len(tokenizer)``, else 0.
        """
        if hasattr(self.tokenizer, 'vocab_size'):
            return self.tokenizer.vocab_size
        elif hasattr(self.tokenizer, '__len__'):
            return len(self.tokenizer)
        return 0

    def __call__(self, preds, batch=None, *args, **kwargs):
        """
        Args:
            preds: Token ids (tensor or nested sequence), or a dict holding
                them under 'cmer_pred' or 'maps'; for any other dict the
                first value is used.
            batch: Raw batch data. Not used; no label decoding is performed.
        Returns:
            list: List of tuples [(text, score), ...]. The score is a
            constant placeholder of 0.0.
        """
        # Unwrap dict inputs from Trainer
        if isinstance(preds, dict):
            if 'cmer_pred' in preds:
                token_ids = preds['cmer_pred']
            elif 'maps' in preds:
                token_ids = preds['maps']
            else:
                token_ids = next(iter(preds.values()))
        else:
            token_ids = preds

        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.cpu()

        # Batch decode using tokenizer
        decoded_texts = self.tokenizer.batch_decode(token_ids,
                                                    skip_special_tokens=True)

        result_list = []
        for text in decoded_texts:
            if self.remove_spaces:
                text = text.replace(' ', '')
            text = text.strip()

            result_list.append((text, 0.0))

        return result_list
|
|
@@ -37,7 +37,7 @@ class MPGLabelDecode(BaseRecLabelDecode):
|
|
|
37
37
|
if batch is None:
|
|
38
38
|
return char_text
|
|
39
39
|
label = batch[1]
|
|
40
|
-
label = self.char_decode(label[:, 1:]
|
|
40
|
+
label = self.char_decode(label[:, 1:])
|
|
41
41
|
if self.only_char:
|
|
42
42
|
return char_text, label
|
|
43
43
|
else:
|
|
@@ -33,7 +33,7 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
|
|
33
33
|
is_remove_duplicate=False)
|
|
34
34
|
if batch is None:
|
|
35
35
|
return text
|
|
36
|
-
label = self.decode(batch[1][:, 1:]
|
|
36
|
+
label = self.decode(batch[1][:, 1:])
|
|
37
37
|
else:
|
|
38
38
|
if isinstance(preds, torch.Tensor):
|
|
39
39
|
preds = preds.detach().cpu().numpy()
|
|
@@ -44,7 +44,7 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
|
|
44
44
|
is_remove_duplicate=False)
|
|
45
45
|
if batch is None:
|
|
46
46
|
return text
|
|
47
|
-
label = self.decode(batch[1][:, 1:]
|
|
47
|
+
label = self.decode(batch[1][:, 1:])
|
|
48
48
|
return text, label
|
|
49
49
|
|
|
50
50
|
def add_special_char(self, dict_character):
|