openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,433 @@
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import CrossEntropyLoss
6
+
7
+ # transformers == 4.45.1
8
+ from .configuration_unirec import UniRecConfig
9
+ from transformers import M2M100PreTrainedModel
10
+ from transformers.models.m2m_100.modeling_m2m_100 import M2M100ScaledWordEmbedding, M2M100Decoder
11
+ from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput, BaseModelOutput
12
+ from transformers.generation import GenerationMixin
13
+ from openrec.modeling.encoders.focalsvtr import FocalSVTR
14
+
15
+
16
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int,
                       decoder_start_token_id: int):
    """Build decoder inputs by shifting ``input_ids`` one position right.

    Column 0 of the result is ``decoder_start_token_id``; columns 1.. hold
    ``input_ids`` without its last column. Any ``-100`` left in the result is
    replaced by ``pad_token_id``. ``input_ids`` itself is not modified.

    Raises:
        ValueError: if ``pad_token_id`` is None.
    """
    shifted = torch.zeros(input_ids.shape,
                          dtype=input_ids.dtype,
                          device=input_ids.device)
    # Everything except the last token moves one slot to the right ...
    shifted[:, 1:] = input_ids[:, :-1]
    # ... and the freed first slot receives the start token.
    shifted[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError('self.model.config.pad_token_id has to be defined.')

    # -100 is the label "ignore" marker; it is not a valid token id, so it is
    # mapped to the padding id.
    return shifted.masked_fill(shifted == -100, pad_token_id)
31
+
32
+
33
class UniRecEncoder(M2M100PreTrainedModel):
    """Image encoder for UniRec: a FocalSVTR backbone plus a linear projection.

    Although it derives from ``M2M100PreTrainedModel``, this module holds no
    token embedding and no Transformer encoder layers. It only maps
    ``pixel_values`` through ``vision_encoder`` and ``vision_fc``.

    Args:
        config: UniRecConfig
    """

    def __init__(self, config: UniRecConfig):
        super().__init__(config)

        self.config = config

        # The backbone hyper-parameters are fixed here and not read from
        # ``config``.
        # NOTE(review): presumably they have to match the released
        # checkpoints - confirm before changing any of them.
        self.vision_encoder = FocalSVTR(img_size=[1408, 960],
                                        depths=[2, 2, 9, 2],
                                        embed_dim=96,
                                        sub_k=[[2, 2], [2, 2], [2, 2],
                                               [-1, -1]],
                                        focal_levels=[3, 3, 3, 3],
                                        max_khs=[7, 3, 3, 3],
                                        focal_windows=[3, 3, 3, 3],
                                        last_stage=False,
                                        feat2d=False)

        # assumes the backbone emits features of width config.d_model
        # - TODO confirm against FocalSVTR
        self.vision_fc = nn.Linear(config.d_model, config.d_model)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""Encode images into a sequence of hidden states.

        Args:
            pixel_values (`torch.Tensor`):
                Image batch handed to the FocalSVTR backbone. Required.
            input_ids, attention_mask, head_mask, inputs_embeds:
                Accepted only for signature compatibility with text encoders;
                none of them is read by this method.
            output_attentions (`bool`, *optional*):
                When true, the ``attentions`` field is an empty tuple (this
                encoder collects no attention maps); otherwise it is None.
                Defaults to ``config.output_attentions``.
            output_hidden_states (`bool`, *optional*):
                When true, ``hidden_states`` is a 1-tuple holding the final
                projected features. Defaults to
                ``config.output_hidden_states``.
            return_dict (`bool`, *optional*):
                Return a ``BaseModelOutput`` instead of a plain tuple.
                Defaults to ``config.use_return_dict``.

        Returns:
            ``BaseModelOutput`` or, when ``return_dict`` is false, a tuple of
            the non-None values among (last_hidden_state, hidden_states,
            attentions).

        Raises:
            ValueError: if ``pixel_values`` is None.
        """
        if pixel_values is None:
            # Fail with a clear message instead of an opaque error from
            # inside the backbone.
            raise ValueError('pixel_values must be provided to UniRecEncoder.')

        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        all_attentions = () if output_attentions else None

        hidden_states = self.vision_encoder(pixel_values)
        hidden_states = self.vision_fc(hidden_states)

        # Only the final features are exposed as "hidden states".
        encoder_states = (hidden_states, ) if output_hidden_states else None

        if not return_dict:
            return tuple(
                v for v in [hidden_states, encoder_states, all_attentions]
                if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_states,
                               hidden_states=encoder_states,
                               attentions=all_attentions)
133
+
134
+
135
class UniRecModel(M2M100PreTrainedModel):
    """Image-to-sequence model: a ``UniRecEncoder`` feeding an ``M2M100Decoder``.

    The decoder's token embedding is the shared ``M2M100ScaledWordEmbedding``
    stored in ``self.shared``.
    """

    # NOTE(review): UniRecEncoder defines no ``embed_tokens``; the first entry
    # looks inherited from M2M100 - confirm before relying on it.
    _tied_weights_keys = [
        'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'
    ]

    def __init__(self, config: UniRecConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0
        self.shared = M2M100ScaledWordEmbedding(vocab_size,
                                                config.d_model,
                                                padding_idx,
                                                embed_scale=embed_scale)

        self.encoder = UniRecEncoder(config)
        self.decoder = M2M100Decoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the shared token embedding."""
        return self.shared

    def set_input_embeddings(self, value):
        """Replace the shared token embedding and re-point the decoder at it."""
        self.shared = value
        self.decoder.embed_tokens = self.shared

    def _tie_weights(self):
        """Tie the decoder embedding to ``self.shared`` when the config asks for it."""
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def get_encoder(self):
        """Return the image encoder."""
        return self.encoder

    def get_decoder(self):
        """Return the text decoder."""
        return self.decoder

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
        """Encode ``pixel_values`` (unless ``encoder_outputs`` is given) and decode.

        ``input_ids``, ``head_mask`` and ``inputs_embeds`` are accepted for
        interface compatibility but are not read. Any ``attention_mask``
        passed in is discarded: the decoder always cross-attends to every
        encoder position.

        Returns:
            ``Seq2SeqModelOutput`` or, when ``return_dict`` is false, the
            decoder output tuple followed by the encoder output tuple.
        """
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        if encoder_outputs is None:
            # Forward the output flags. Without ``return_dict`` the encoder
            # fell back to the config default and returned a ModelOutput even
            # when a tuple was requested, which broke the tuple concatenation
            # below; requested encoder hidden states were dropped as well.
            encoder_outputs = self.encoder(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a
        # BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        # All-ones cross-attention mask over (batch, encoder_seq_len); this
        # replaces whatever ``attention_mask`` the caller supplied.
        attention_mask = torch.ones(encoder_outputs[0].shape[:2],
                                    dtype=torch.long,
                                    device=encoder_outputs[0].device)

        # decoder outputs consists of
        # (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
244
+
245
+
246
class UniRecForConditionalGenerationNew(M2M100PreTrainedModel,
                                        GenerationMixin):
    """``UniRecModel`` with a language-model head and a built-in training loss.

    The loss is ``CrossEntropyLoss`` with ``ignore_index=config.pad_token_id``
    and ``label_smoothing=config.label_smoothing``.
    """

    base_model_prefix = 'model'
    # NOTE(review): the first two entries lack the ``model.`` prefix that the
    # wrapped ``UniRecModel`` parameters would carry - confirm this is intended.
    _tied_weights_keys = [
        'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight',
        'lm_head.weight'
    ]

    def __init__(self, config: UniRecConfig):
        super().__init__(config)
        self.model = UniRecModel(config)
        self.lm_head = nn.Linear(config.d_model,
                                 self.model.shared.num_embeddings,
                                 bias=False)
        self.loss_fct = CrossEntropyLoss(
            ignore_index=config.pad_token_id,
            label_smoothing=config.label_smoothing)
        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        """Return the image encoder of the wrapped model."""
        return self.model.get_encoder()

    def get_decoder(self):
        """Return the text decoder of the wrapped model."""
        return self.model.get_decoder()

    def get_output_embeddings(self):
        """Return the language-model head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-model head."""
        self.lm_head = new_embeddings

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        length: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""Run the model and, in training mode, compute the LM loss.

        Args:
            pixel_values: image batch for the encoder.
            input_ids: accepted for interface compatibility; never forwarded
                (the wrapped model is always called with ``input_ids=None``).
            labels: target token ids. When ``decoder_input_ids`` is None the
                labels serve as both decoder input and target:
                ``decoder_input_ids = labels[:, :-1]`` and the targets become
                ``labels[:, 1:]``. If ``length`` is also given, both slices
                are additionally truncated to ``1 + length.max()`` columns.
                Because the labels are fed to the decoder they must be valid
                token ids; positions equal to ``config.pad_token_id`` are
                ignored by the loss (not ``-100``).
            length: optional per-sample lengths, only used to truncate
                ``labels`` as described above.

        The remaining arguments are passed through to ``UniRecModel.forward``.

        Returns:
            ``Seq2SeqLMOutput`` or, when ``return_dict`` is false, a tuple
            ``(loss?, logits, *model_outputs[1:])``. The loss is only computed
            when the module is in training mode and ``labels`` is given.
        """
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        if labels is not None and decoder_input_ids is None:
            if length is not None:
                max_len = length.max()
                decoder_input_ids = labels[:, :1 + max_len]
                labels = labels[:, 1:2 + max_len]
            else:
                decoder_input_ids = labels[:, :-1]
                labels = labels[:, 1:]

        # The forward pass is identical for training and inference; only the
        # loss computation below differs.
        outputs = self.model(
            pixel_values=pixel_values,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0])

        masked_lm_loss = None
        if self.training and labels is not None:
            masked_lm_loss = self.loss_fct(
                lm_logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1))

        if not return_dict:
            output = (lm_logits, ) + outputs[1:]
            return ((masked_lm_loss, ) +
                    output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        pixel_values=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        When a cache is present, ``decoder_input_ids`` is trimmed to the
        tokens the cache does not cover yet.
        """
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        return {
            'input_ids':
            None,  # encoder_outputs is defined. input_ids not needed
            'encoder_outputs': encoder_outputs,
            'past_key_values': past_key_values,
            'decoder_input_ids': decoder_input_ids,
            'attention_mask': attention_mask,
            'head_mask': head_mask,
            'decoder_head_mask': decoder_head_mask,
            'cross_attn_head_mask': cross_attn_head_mask,
            'use_cache':
            use_cache,  # change this to avoid caching (presumably for debugging)
            'pixel_values': pixel_values,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along dim 0 according to ``beam_idx``."""
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past)
            for layer_past in past_key_values)
@@ -2,7 +2,6 @@ import copy
2
2
 
3
3
  import torch
4
4
  from torch import nn
5
-
6
5
  __all__ = ['build_optimizer']
7
6
 
8
7
 
@@ -63,11 +62,13 @@ def build_optimizer(optim_config, lr_scheduler_config, epochs, step_each_epoch,
63
62
  **config)
64
63
 
65
64
  lr_config = copy.deepcopy(lr_scheduler_config)
65
+ scheduler_name = lr_config.pop('name')
66
+
66
67
  lr_config.update({
67
68
  'epochs': epochs,
68
69
  'step_each_epoch': step_each_epoch,
69
70
  'lr': config['lr']
70
71
  })
71
- lr_scheduler = getattr(lr,
72
- lr_config.pop('name'))(**lr_config)(optimizer=optim)
72
+ lr_scheduler = getattr(lr, scheduler_name)(**lr_config)(optimizer=optim)
73
+
73
74
  return optim, lr_scheduler
@@ -225,3 +225,52 @@ class CdistNetLR(object):
225
225
  np.power(self.n_warmup_steps, -1.5) * current_step,
226
226
  ])
227
227
  return self.step2_lr / self.init_lr
228
+
229
class WarmupCosineLR(object):
    """Factory for a linear-warmup followed by cosine-annealing schedule.

    Construct it with the schedule parameters, then call the instance with an
    optimizer to obtain the actual torch scheduler.

    Args:
        epochs: number of training epochs.
        step_each_epoch: scheduler steps per epoch; the schedule spans
            ``epochs * step_each_epoch`` steps.
        warmup_steps: length of the linear warmup; values <= 0 disable it.
        eta_min: minimum learning rate of the cosine phase.
        last_epoch: forwarded to ``SequentialLR`` only.
        **kwargs: ignored.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 warmup_steps=0,
                 eta_min=0.0,
                 last_epoch=-1,
                 **kwargs):
        super().__init__()
        self.total_steps = epochs * step_each_epoch
        self.warmup_steps = warmup_steps
        self.eta_min = eta_min
        self.last_epoch = last_epoch

    def __call__(self, optimizer):
        """Return the scheduler bound to ``optimizer``.

        Without warmup this is a bare ``CosineAnnealingLR``; with warmup it is
        a ``SequentialLR`` that switches from ``LinearLR`` to the cosine phase
        after ``warmup_steps`` steps.
        """
        has_warmup = self.warmup_steps > 0

        # The warmup scheduler (if any) is created before the cosine one so
        # that the optimizer is touched in the same order as before.
        warmup = None
        if has_warmup:
            warmup = lr_scheduler.LinearLR(optimizer,
                                           start_factor=1e-7,
                                           end_factor=1.0,
                                           total_iters=self.warmup_steps)

        # Keep at least one cosine step even if warmup covers the whole run.
        cosine_span = max(1, self.total_steps - self.warmup_steps)
        cosine = lr_scheduler.CosineAnnealingLR(optimizer,
                                                T_max=cosine_span,
                                                eta_min=self.eta_min)

        if not has_warmup:
            # NOTE(review): ``last_epoch`` is not applied on this path.
            return cosine

        return lr_scheduler.SequentialLR(optimizer,
                                         schedulers=[warmup, cosine],
                                         milestones=[self.warmup_steps],
                                         last_epoch=self.last_epoch)
@@ -18,6 +18,8 @@ module_mapping = {
18
18
  'SRNLabelDecode': '.srn_postprocess',
19
19
  'LISTERLabelDecode': '.lister_postprocess',
20
20
  'MPGLabelDecode': '.mgp_postprocess',
21
+ 'UniRecLabelDecode': '.unirec_postprocess',
22
+ 'CMERLabelDecode': '.cmer_postprocess',
21
23
  'GTCLabelDecode': '.' # 当前模块中的类
22
24
  }
23
25
 
@@ -29,7 +29,7 @@ class ABINetLabelDecode(NRTRLabelDecode):
29
29
  text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
30
30
  if batch is None:
31
31
  return text
32
- label = self.decode(batch[1].cpu().numpy())
32
+ label = self.decode(batch[1])
33
33
  return text, label
34
34
 
35
35
  def add_special_char(self, dict_character):
@@ -30,7 +30,7 @@ class ARLabelDecode(BaseRecLabelDecode):
30
30
  if batch is None:
31
31
  return text
32
32
  label = batch[1]
33
- label = self.decode(label[:, 1:].detach().cpu().numpy())
33
+ label = self.decode(label[:, 1:])
34
34
  return text, label
35
35
 
36
36
  def add_special_char(self, dict_character):
@@ -0,0 +1,86 @@
1
+ import torch
2
+ from transformers import PreTrainedTokenizerFast
3
+ from .ctc_postprocess import BaseRecLabelDecode
4
+
5
+
6
class CMERLabelDecode(BaseRecLabelDecode):
    """Turn CMER token-id predictions into ``(text, score)`` pairs.

    Decoding is delegated to a HuggingFace ``PreTrainedTokenizerFast`` loaded
    from ``tokenizer_file``. The score slot is always ``0.0``.
    """

    def __init__(
            self,
            character_dict_path=None,
            use_space_char=False,
            tokenizer_file='./configs/rec/cmer/cmer_tokenizer/tokenizer.json',
            **kwargs):
        """
        Args:
            character_dict_path: forwarded to the base class constructor.
            use_space_char: forwarded to the base class constructor.
            tokenizer_file: path of the tokenizer JSON file.
            **kwargs: accepted and ignored.
        """
        super().__init__(character_dict_path, use_space_char)

        # Spaces are always stripped from decoded text.
        self.remove_spaces = True

        self.tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=tokenizer_file,
            padding_side='right',
            truncation_side='right',
            pad_token='<|pad|>',
            bos_token='<|bos|>',
            eos_token='<|eos|>',
            unk_token='<|unk|>',
        )

    def get_character_num(self):
        """Return the tokenizer's vocabulary size (0 if it cannot be determined)."""
        tok = self.tokenizer
        if hasattr(tok, 'vocab_size'):
            return tok.vocab_size
        if hasattr(tok, '__len__'):
            return len(tok)
        return 0

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode a batch of predictions.

        Args:
            preds: token ids, or a dict containing them. For a dict the key
                ``'cmer_pred'`` is preferred, then ``'maps'``, then the first
                value.
            batch: accepted but not used.

        Returns:
            list: ``[(text, 0.0), ...]``, one entry per decoded sequence.
        """
        ids = preds
        if isinstance(preds, dict):
            for key in ('cmer_pred', 'maps'):
                if key in preds:
                    ids = preds[key]
                    break
            else:
                # Neither known key is present: fall back to the first value.
                ids = next(iter(preds.values()))

        if isinstance(ids, torch.Tensor):
            ids = ids.cpu()

        texts = self.tokenizer.batch_decode(ids, skip_special_tokens=True)

        return [((raw.replace(' ', '') if self.remove_spaces else raw).strip(),
                 0.0) for raw in texts]
@@ -34,7 +34,7 @@ class CPPDLabelDecode(NRTRLabelDecode):
34
34
  if batch is None:
35
35
  return text
36
36
  label = batch[1]
37
- label = self.decode(label.detach().cpu().numpy())
37
+ label = self.decode(label)
38
38
  return text, label
39
39
 
40
40
  def add_special_char(self, dict_character):
@@ -53,7 +53,7 @@ class IGTRLabelDecode(NRTRLabelDecode):
53
53
  if batch is None:
54
54
  return text
55
55
  label = batch[1]
56
- label = self.decode(label.detach().cpu().numpy())
56
+ label = self.decode(label)
57
57
  return text, label
58
58
 
59
59
  def add_special_char(self, dict_character):
@@ -26,7 +26,7 @@ class LISTERLabelDecode(BaseRecLabelDecode):
26
26
  if batch is None:
27
27
  return text
28
28
  label = batch[1]
29
- label = self.decode(label.detach().cpu().numpy())
29
+ label = self.decode(label)
30
30
  return text, label
31
31
 
32
32
  def add_special_char(self, dict_character):
@@ -37,7 +37,7 @@ class MPGLabelDecode(BaseRecLabelDecode):
37
37
  if batch is None:
38
38
  return char_text
39
39
  label = batch[1]
40
- label = self.char_decode(label[:, 1:].detach().cpu().numpy())
40
+ label = self.char_decode(label[:, 1:])
41
41
  if self.only_char:
42
42
  return char_text, label
43
43
  else:
@@ -33,7 +33,7 @@ class NRTRLabelDecode(BaseRecLabelDecode):
33
33
  is_remove_duplicate=False)
34
34
  if batch is None:
35
35
  return text
36
- label = self.decode(batch[1][:, 1:].cpu().numpy())
36
+ label = self.decode(batch[1][:, 1:])
37
37
  else:
38
38
  if isinstance(preds, torch.Tensor):
39
39
  preds = preds.detach().cpu().numpy()
@@ -44,7 +44,7 @@ class NRTRLabelDecode(BaseRecLabelDecode):
44
44
  is_remove_duplicate=False)
45
45
  if batch is None:
46
46
  return text
47
- label = self.decode(batch[1][:, 1:].cpu().numpy())
47
+ label = self.decode(batch[1][:, 1:])
48
48
  return text, label
49
49
 
50
50
  def add_special_char(self, dict_character):