openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,643 @@
1
+ import math
2
+ from collections import OrderedDict
3
+ from contextlib import nullcontext
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn.attention import sdpa_kernel, SDPBackend
10
+ from torch.utils.checkpoint import checkpoint
11
+
12
+ from transformers import (
13
+ GenerationMixin,
14
+ MBartConfig,
15
+ PretrainedConfig,
16
+ PreTrainedModel,
17
+ )
18
+ from transformers.modeling_outputs import (
19
+ BaseModelOutput,
20
+ CausalLMOutputWithCrossAttentions,
21
+ )
22
+ from transformers.models.mbart.modeling_mbart import MBartDecoder
23
+
24
+
25
class ResidualBlock(nn.Module):
    """Two-convolution residual unit with an optional projection shortcut.

    The main path is ``conv3x3 -> BN -> ReLU -> conv3x3 -> BN``. Its result is
    added to the shortcut branch and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution and, when a projection
            shortcut is used, of its 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # A plain identity shortcut only works when the main path keeps both
        # the spatial size and the channel count.
        needs_projection = stride != 1 or in_channels != out_channels

        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=3,
                               stride=stride,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels,
                               out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if needs_projection:
            self.short = nn.Sequential(
                nn.Conv2d(in_channels,
                          out_channels,
                          kernel_size=1,
                          stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.short = nn.Identity()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        return self.relu(out + self.short(x))
44
+
45
+
46
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Each vector is divided by ``sqrt(mean(x ** 2) + eps)`` and then scaled by
    a learned per-feature ``weight`` (initialised to ones). No mean is
    subtracted and there is no bias term.

    Args:
        dim: Size of the last (feature) dimension.
        eps: Constant added to the mean square before the reciprocal
            square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along its last axis and apply the learned scale."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalised = x * torch.rsqrt(mean_square + self.eps)
        return normalised * self.weight
57
+
58
+
59
class SwiGLU(nn.Module):
    """Gated linear unit with a SiLU gate: ``up(x) * silu(gate(x))``.

    Only the expansion is performed here; the output keeps
    ``hidden_features`` channels.

    Args:
        in_features: Size of the input feature dimension.
        hidden_features: Size of the gated output dimension.
        bias: Whether both linear projections carry a bias term.
    """

    def __init__(self,
                 in_features: int,
                 hidden_features: int,
                 bias: bool = True):
        super().__init__()
        self.up = nn.Linear(in_features, hidden_features, bias=bias)
        self.gate = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the value branch multiplied by the activated gate branch."""
        value = self.up(x)
        gate = self.act(self.gate(x))
        return value * gate
72
+
73
+
74
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
75
+ x_even = x[..., ::2]
76
+ x_odd = x[..., 1::2]
77
+ return torch.stack((-x_odd, x_even), dim=-1).reshape_as(x)
78
+
79
+
80
def apply_rope2d(q: torch.Tensor, k: torch.Tensor, cos_sin_cache):
    """Apply 2D rotary position embedding to query and key tensors.

    The head dimension is split in two: the first half is rotated with the
    ``y`` tables and the second half with the ``x`` tables.

    Args:
        q: Query tensor with four dimensions; the last two are
            ``(tokens, head_dim)``.
        k: Key tensor, sliced and rotated the same way as ``q``.
        cos_sin_cache: Four tensors ``(cos_y, sin_y, cos_x, sin_x)``, each
            viewable as ``(1, 1, tokens, head_dim // 2)``.

    Returns:
        Tuple ``(q_rotated, k_rotated)`` with the shapes of ``q`` and ``k``.
    """
    _, _, M, dH = q.shape
    half = dH // 2
    # Reshape the tables so they broadcast over the leading two dimensions.
    cos_y, sin_y, cos_x, sin_x = (t.view(1, 1, M, half)
                                  for t in cos_sin_cache)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        t_y = t[..., :half]
        t_x = t[..., half:]
        t_y = t_y * cos_y + _rotate_half(t_y) * sin_y
        t_x = t_x * cos_x + _rotate_half(t_x) * sin_x
        return torch.cat([t_y, t_x], dim=-1)

    return _rotate(q), _rotate(k)
95
+
96
+
97
class RoPEMHA(nn.Module):
    """Multi-head self-attention with 2D rotary position embedding.

    Queries and keys are rotated by ``apply_rope2d`` before
    ``F.scaled_dot_product_attention`` is evaluated without a mask.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        attn_drop: Dropout probability handed to the attention kernel while
            the module is in training mode.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 attn_drop: float = 0.1,
                 proj_drop: float = 0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.q_proj = nn.Linear(dim, dim, bias=True)
        self.k_proj = nn.Linear(dim, dim, bias=True)
        self.v_proj = nn.Linear(dim, dim, bias=True)
        # Only the probability ``p`` of this module is read in ``forward``;
        # the dropout itself is applied inside the fused attention call.
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Linear(dim, dim, bias=True)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, cos_sin_cache):
        """Attend over ``x`` of shape ``(batch, tokens, dim)``.

        Args:
            x: Input sequence.
            cos_sin_cache: Rotary tables forwarded to ``apply_rope2d``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, M, D = x.shape
        H, Hd = self.num_heads, self.head_dim
        assert D == H * Hd, f'D={D}, H*Hd={H * Hd}'

        def _split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, M, D) -> (B, H, M, Hd)
            return t.view(B, M, H, Hd).transpose(1, 2).contiguous()

        q = _split_heads(self.q_proj(x))
        k = _split_heads(self.k_proj(x))
        v = _split_heads(self.v_proj(x))
        q, k = apply_rope2d(q, k, cos_sin_cache)

        if self.training:
            drop_p = self.attn_drop.p
        else:
            drop_p = 0.0

        # The backend list is only installed when CUDA is available.
        if torch.cuda.is_available():
            backend_ctx = sdpa_kernel([
                SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.MATH,
            ])
        else:
            backend_ctx = nullcontext()

        with backend_ctx:
            attn = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=drop_p,
                is_causal=False,
                scale=self.scale,
            )

        # (B, H, M, Hd) -> (B, M, D)
        merged = attn.transpose(1, 2).contiguous().view(B, M, D)
        return self.proj_drop(self.out_proj(merged))
144
+
145
+
146
class PreNormDecoderLayer(nn.Module):
    """Pre-norm transformer layer: RoPE self-attention plus a SwiGLU FFN.

    Both sub-layers are residual and normalise their input with ``RMSNorm``
    before processing it.

    Args:
        hidden_dim: Model width.
        num_heads: Number of attention heads.
        attn_drop_rate: Dropout probability used for the attention weights,
            the attention output projection and the FFN output.
        ffn_ratio: The FFN inner width is ``int(hidden_dim * ffn_ratio)``,
            with a floor of 1.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 attn_drop_rate: float = 0.1,
                 ffn_ratio: float = 4.0):
        super().__init__()
        ffn_width = max(1, int(hidden_dim * ffn_ratio))

        self.norm1 = RMSNorm(hidden_dim, eps=1e-6)
        self.mha = RoPEMHA(hidden_dim,
                           num_heads,
                           attn_drop=attn_drop_rate,
                           proj_drop=attn_drop_rate)
        self.norm2 = RMSNorm(hidden_dim, eps=1e-6)
        self.ffn = SwiGLU(hidden_dim, ffn_width)
        self.fc_out = nn.Linear(ffn_width, hidden_dim)
        self.drop = nn.Dropout(attn_drop_rate)

    def forward(self, x: torch.Tensor, cos_sin_cache):
        """Run the attention and feed-forward sub-layers on ``x``.

        Args:
            x: Input sequence.
            cos_sin_cache: Rotary tables forwarded to the attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention sub-layer.
        x = x + self.mha(self.norm1(x), cos_sin_cache)
        # Feed-forward sub-layer.
        ffn_out = self.fc_out(self.ffn(self.norm2(x)))
        return x + self.drop(ffn_out)
172
+
173
+
174
class CMEREncoder(nn.Module):
    """Convolutional stem followed by a stack of 2D-RoPE attention layers.

    Seven ``ResidualBlock`` stages downsample the image by a total stride of
    16 (or 32 when ``down_sample_ratio > 16``). The resulting feature map is
    flattened into a token sequence, projected to ``hidden_dim`` and refined
    by ``num_layers`` ``PreNormDecoderLayer`` modules.

    Rotary cos/sin tables are built per feature-map size and kept in a small
    LRU cache on the CPU, stored as float16 and cast to the requested dtype
    on lookup.

    Args:
        num_layers: Number of attention layers.
        num_heads: Attention heads per layer.
        hidden_dim: Token width of the attention stack.
        down_sample_ratio: Values above 16 give the sixth residual stage a
            stride of 2 instead of 1.
        rope_base: Base of the rotary frequency schedule.
        gradient_checkpointing: If true, attention layers are run under
            ``torch.utils.checkpoint`` while training with grad enabled.
    """

    def __init__(self,
                 num_layers: int,
                 num_heads: int,
                 hidden_dim: int,
                 *,
                 down_sample_ratio: int = 16,
                 rope_base: float = 10000.0,
                 gradient_checkpointing: bool = False):
        super().__init__()
        self.down_sample_ratio = int(down_sample_ratio)
        self.hidden_dim = int(hidden_dim)
        self.gradient_checkpointing = bool(gradient_checkpointing)
        self.rope_base = float(rope_base)
        self.head_dim = hidden_dim // num_heads
        channels = [3, 12, 24, 48, 96, 192, 384, 768]
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(channels[0], channels[1], stride=2),
            ResidualBlock(channels[1], channels[2], stride=1),
            ResidualBlock(channels[2], channels[3], stride=2),
            ResidualBlock(channels[3], channels[4], stride=1),
            ResidualBlock(channels[4], channels[5], stride=2),
            ResidualBlock(channels[5],
                          channels[6],
                          stride=2 if down_sample_ratio > 16 else 1),
            ResidualBlock(channels[6], channels[7], stride=2),
        ])
        self.fc = nn.Linear(channels[-1], hidden_dim)
        self.vit = nn.ModuleList([
            PreNormDecoderLayer(hidden_dim, num_heads)
            for _ in range(num_layers)
        ])
        # LRU cache: (H, W, head_dim) -> (cos_y, sin_y, cos_x, sin_x) on CPU.
        self.rope_cache = OrderedDict()
        # A class-level ``max_rope_cache`` (e.g. on a subclass) wins over the
        # default of 32 entries.
        self.max_rope_cache = getattr(self, 'max_rope_cache', 32)

    def train(self, mode: bool = True):
        """Switch the training flag and drop cached tables when it changes."""
        prev = self.training
        super().train(mode)
        if mode != prev:
            self.rope_cache.clear()
        return self

    def eval(self):
        """Switch to eval mode and drop cached tables if it was training."""
        prev = self.training
        super().eval()
        if prev:
            self.rope_cache.clear()
        return self

    def clear_rope_cache(self):
        """Remove every cached rotary table."""
        self.rope_cache.clear()

    def _build_rope2d_cache(self, H: int, W: int, device, dtype):
        """Return rotary tables for an ``H x W`` token grid.

        Args:
            H: Number of token rows.
            W: Number of token columns.
            device: Device the returned tensors are moved to.
            dtype: Dtype the returned tensors are cast to.

        Returns:
            Tuple ``(cos_y, sin_y, cos_x, sin_x)``; each tensor has shape
            ``(H * W, head_dim // 2)``.
        """
        H = int(H)
        W = int(W)
        key = (H, W, int(self.head_dim))
        if key in self.rope_cache:
            cos_y_cpu, sin_y_cpu, cos_x_cpu, sin_x_cpu = self.rope_cache[key]
            # Mark the entry as most recently used.
            self.rope_cache.move_to_end(key)
        else:
            head_dim = self.head_dim
            assert head_dim % 4 == 0, '2D RoPE 需要 head_dim 能被 4 整除'
            half = head_dim // 2
            inv_freq = 1.0 / (self.rope_base**(torch.arange(
                0, half, 2, device='cpu', dtype=torch.float32) / half))
            pos_y = torch.arange(H, device='cpu', dtype=torch.float32)
            pos_x = torch.arange(W, device='cpu', dtype=torch.float32)
            freqs_y = torch.einsum('i,j->ij', pos_y, inv_freq)
            freqs_x = torch.einsum('i,j->ij', pos_x, inv_freq)
            # Duplicate each frequency so it lines up with the interleaved
            # (even, odd) feature pairs rotated by ``_rotate_half``.
            cos_y_1d = torch.cos(freqs_y).repeat_interleave(2, dim=-1)
            sin_y_1d = torch.sin(freqs_y).repeat_interleave(2, dim=-1)
            cos_x_1d = torch.cos(freqs_x).repeat_interleave(2, dim=-1)
            sin_x_1d = torch.sin(freqs_x).repeat_interleave(2, dim=-1)
            # Broadcast the row / column tables over the grid, then flatten
            # in row-major order to match ``x.flatten(2)`` in ``forward``.
            cos_y = cos_y_1d[:, None, :].expand(H, W,
                                                half).reshape(H * W, half)
            sin_y = sin_y_1d[:, None, :].expand(H, W,
                                                half).reshape(H * W, half)
            cos_x = cos_x_1d[None, :, :].expand(H, W,
                                                half).reshape(H * W, half)
            sin_x = sin_x_1d[None, :, :].expand(H, W,
                                                half).reshape(H * W, half)
            # Page-locking needs an accelerator runtime; on CPU-only hosts
            # ``pin_memory()`` raises, so the tables are only pinned when
            # CUDA is available.
            can_pin = torch.cuda.is_available()
            entry = tuple(
                t.to(torch.float16).pin_memory() if can_pin else t.to(
                    torch.float16) for t in (cos_y, sin_y, cos_x, sin_x))
            self.rope_cache[key] = entry
            # Evict least recently used entries beyond the size limit.
            while len(self.rope_cache) > int(self.max_rope_cache):
                self.rope_cache.popitem(last=False)
            cos_y_cpu, sin_y_cpu, cos_x_cpu, sin_x_cpu = entry
        cos_y = cos_y_cpu.to(device=device, dtype=dtype, non_blocking=True)
        sin_y = sin_y_cpu.to(device=device, dtype=dtype, non_blocking=True)
        cos_x = cos_x_cpu.to(device=device, dtype=dtype, non_blocking=True)
        sin_x = sin_x_cpu.to(device=device, dtype=dtype, non_blocking=True)
        if self.training and torch.is_grad_enabled():
            # Hand out private copies while training with grad enabled.
            cos_y = cos_y.clone()
            sin_y = sin_y.clone()
            cos_x = cos_x.clone()
            sin_x = sin_x.clone()
        return (cos_y, sin_y, cos_x, sin_x)

    def forward(self, pixel_values: torch.Tensor):
        """Encode an image batch into a token sequence.

        Args:
            pixel_values: Image batch; the first residual block expects three
                input channels.

        Returns:
            Tensor of shape ``(batch, Hc * Wc, hidden_dim)`` where
            ``Hc x Wc`` is the size of the downsampled feature map.
        """
        x = pixel_values
        for blk in self.residual_blocks:
            x = blk(x)
        _, _, Hc, Wc = x.shape
        # (N, C, Hc, Wc) -> (N, Hc * Wc, C) -> (N, Hc * Wc, hidden_dim)
        seq = x.flatten(2).transpose(1, 2)
        seq = self.fc(seq)
        cos_sin_cache = self._build_rope2d_cache(Hc, Wc, seq.device, seq.dtype)
        use_checkpoint = (self.gradient_checkpointing and self.training
                          and torch.is_grad_enabled())
        if use_checkpoint:

            def _run_layer(layer, s, cache):
                return layer(s, cache)

            for layer in self.vit:
                seq = checkpoint(_run_layer,
                                 layer,
                                 seq,
                                 cos_sin_cache,
                                 use_reentrant=False)
        else:
            for layer in self.vit:
                seq = layer(seq, cos_sin_cache)
        return seq
298
+
299
+
300
class CMERConfig(PretrainedConfig):
    """Configuration for ``CMER``: encoder kwargs plus decoder settings.

    Args:
        vision_config: Keyword arguments for the vision encoder; defaults to
            an empty dict.
        decoder_config: Keyword arguments for the decoder configuration;
            defaults to an empty dict. Every entry is also mirrored as an
            attribute of this config.
        **kwargs: Extra options forwarded to ``PretrainedConfig``.
    """

    model_type = 'CMER'

    def __init__(self, vision_config=None, decoder_config=None, **kwargs):
        self.vision_config = vision_config if vision_config is not None else {}
        self.decoder_config = decoder_config if decoder_config is not None else {}
        # Mirror the decoder settings as top-level attributes.
        for key, value in self.decoder_config.items():
            setattr(self, key, value)
        if hasattr(self, 'decoder_layers'):
            self.num_hidden_layers = self.decoder_layers
        # Merge before forwarding: a reloaded config carries the mirrored
        # decoder keys in ``kwargs`` as well as in ``decoder_config``, and
        # passing both with ``**`` raises "got multiple values for keyword
        # argument". On a clash the ``decoder_config`` value wins.
        merged_kwargs = {**kwargs, **self.decoder_config}
        super().__init__(**merged_kwargs)
312
+
313
+
314
class CMER(PreTrainedModel, GenerationMixin):
    """Image-to-sequence model: a ``CMEREncoder`` feeding an ``MBartDecoder``.

    The encoder output is passed to the decoder as ``encoder_hidden_states``.
    ``lm_head`` maps decoder states to vocabulary logits and is tied to the
    decoder token embedding.
    """

    config_class = CMERConfig
    base_model_prefix = 'cmer'
    main_input_name = 'pixel_values'

    def __init__(self, config: CMERConfig):
        super().__init__(config)
        self.config = config
        decoder_config = MBartConfig(**config.decoder_config)
        self.vision_model = CMEREncoder(**config.vision_config)
        self.llm_model = MBartDecoder(decoder_config)
        self.lm_head = torch.nn.Linear(decoder_config.d_model,
                                       decoder_config.vocab_size,
                                       bias=False)
        # Tie the output projection to the decoder token embedding.
        setattr(self.config, 'tie_word_embeddings', True)
        self.tie_weights()
        setattr(self.lm_head, '_dynamic_tied_weights_keys', ['weight'])
        setattr(self.llm_model.embed_tokens, '_dynamic_tied_weights_keys',
                ['weight'])

    def set_gradient_checkpointing(self, enable: bool = True):
        """Turn gradient checkpointing on or off.

        Args:
            enable: Desired state. Enabling also sets the decoder config's
                ``use_cache`` to ``False``.
        """
        self.gradient_checkpointing = bool(enable)
        # ``CMEREncoder`` exposes a plain ``gradient_checkpointing`` attribute
        # and no ``set_gradient_checkpointing`` method, so the attribute is
        # what has to be checked for.
        if (hasattr(self.vision_model, 'gradient_checkpointing') or hasattr(
                self.vision_model, 'set_gradient_checkpointing')):
            self.vision_model.gradient_checkpointing = self.gradient_checkpointing
        # NOTE(review): the decoder is only touched when it exposes
        # `set_gradient_checkpointing`; confirm against the installed
        # transformers version whether that is the intended hook.
        if hasattr(self.llm_model, 'set_gradient_checkpointing'):
            self.llm_model.gradient_checkpointing = self.gradient_checkpointing
        if enable:
            self.llm_model.config.use_cache = False

    def get_output_embeddings(self):
        """Return the vocabulary projection layer."""
        return self.lm_head

    def set_output_embeddings(self, new_emb):
        """Replace the vocabulary projection layer."""
        self.lm_head = new_emb

    def state_dict(self, *args, **kwargs):
        """Return the state dict with both tied weight keys present."""
        sd = super().state_dict(*args, **kwargs)
        if 'llm_model.embed_tokens.weight' not in sd and 'lm_head.weight' in sd:
            sd['llm_model.embed_tokens.weight'] = sd['lm_head.weight']
        elif 'lm_head.weight' not in sd and 'llm_model.embed_tokens.weight' in sd:
            sd['lm_head.weight'] = sd['llm_model.embed_tokens.weight']
        return sd

    def load_state_dict(self, state_dict, strict=True):
        """Load weights, filling in whichever tied key is missing.

        Args:
            state_dict: Mapping of parameter names to tensors. It is not
                modified; a shallow copy is patched and loaded.
            strict: Accepted for signature compatibility only. Loading is
                always done with ``strict=False``.

        Returns:
            The result of ``torch.nn.Module.load_state_dict``.
        """
        # Work on a shallow copy so the caller's mapping keeps its keys;
        # carry over the optional ``_metadata`` attribute of torch state
        # dicts.
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            state_dict._metadata = metadata
        if 'llm_model.embed_tokens.weight' not in state_dict and 'lm_head.weight' in state_dict:
            state_dict['llm_model.embed_tokens.weight'] = state_dict[
                'lm_head.weight']
        if 'lm_head.weight' not in state_dict and 'llm_model.embed_tokens.weight' in state_dict:
            state_dict['lm_head.weight'] = state_dict[
                'llm_model.embed_tokens.weight']
        out = super().load_state_dict(state_dict, strict=False)
        self.tie_weights()
        return out

    def get_input_embeddings(self):
        """Return the decoder's token embedding."""
        return self.llm_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding."""
        self.llm_model.set_input_embeddings(value)

    def get_decoder(self):
        """Return the decoder as reported by ``llm_model.get_decoder()``."""
        return self.llm_model.get_decoder()

    def _swin_stride_and_winsize(self):
        """Return ``(stride, window_size)`` used by ``_ensure_swin_safe``.

        Values are read from ``vision_model.config`` when that attribute
        exists; otherwise the defaults (patch size 4, depths ``[2, 2, 6, 2]``,
        window size 7) are used.
        """
        # ``CMEREncoder`` does not define ``config``; fall back to defaults
        # instead of raising ``AttributeError``.
        cfg = getattr(self.vision_model, 'config', None)
        patch = int(getattr(cfg, 'patch_size', 4))
        depths = getattr(cfg, 'depths', [2, 2, 6, 2])
        stride = patch * (2**(len(depths) - 1))
        wsize = int(getattr(cfg, 'window_size', 7))
        return stride, wsize

    def _ensure_swin_safe(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Resize images so both sides are large enough stride multiples.

        The short side is scaled up to at least ``window_size * stride`` and
        both sides are then rounded up to a multiple of ``stride`` using
        bilinear interpolation. ``forward`` does not call this helper.

        Args:
            pixel_values: Image batch of shape ``(B, C, H, W)``.

        Returns:
            The input tensor, resized when necessary.
        """
        stride, wsize = self._swin_stride_and_winsize()
        _, _, H, W = pixel_values.shape
        need_min = wsize * stride
        if min(H, W) < need_min:
            s = need_min / float(min(H, W))
            new_h = math.ceil(H * s / stride) * stride
            new_w = math.ceil(W * s / stride) * stride
            pixel_values = F.interpolate(pixel_values,
                                         size=(new_h, new_w),
                                         mode='bilinear',
                                         align_corners=False)
            H, W = new_h, new_w
        new_h = math.ceil(H / stride) * stride
        new_w = math.ceil(W / stride) * stride
        if (new_h, new_w) != (H, W):
            pixel_values = F.interpolate(pixel_values,
                                         size=(new_h, new_w),
                                         mode='bilinear',
                                         align_corners=False)
        return pixel_values

    @staticmethod
    def _unwrap_encoder_outputs(encoder_outputs):
        """Extract the hidden-state tensor from ``encoder_outputs``.

        Accepts a tensor, a tuple/list whose first item is the hidden state,
        an object with a ``last_hidden_state`` attribute, or a dict with that
        key.

        Raises:
            ValueError: If none of the supported layouts matches.
        """
        if isinstance(encoder_outputs, torch.Tensor):
            return encoder_outputs
        if isinstance(encoder_outputs, (tuple, list)):
            return encoder_outputs[0]
        if hasattr(encoder_outputs, 'last_hidden_state'):
            return encoder_outputs.last_hidden_state
        if isinstance(encoder_outputs,
                      dict) and 'last_hidden_state' in encoder_outputs:
            return encoder_outputs['last_hidden_state']
        raise ValueError('`encoder_outputs` 格式不正确,缺少 last_hidden_state。')

    @staticmethod
    def _top_k_top_p_filtering(logits,
                               top_k=0,
                               top_p=1.0,
                               min_tokens_to_keep=1):
        """Mask logits outside the top-k / top-p candidate set with ``-inf``.

        Args:
            logits: Scores with the vocabulary on the last axis.
            top_k: Keep only the ``top_k`` highest scores; ``0`` disables it.
            top_p: Cumulative-probability threshold; ``1.0`` disables it.
            min_tokens_to_keep: Number of best tokens never masked by top-p.

        Returns:
            Tensor shaped like ``logits`` with rejected entries at ``-inf``.
        """
        top_k = min(max(top_k, 0), logits.size(-1))
        if top_k > 0:
            kth_vals, _ = torch.topk(logits, top_k)
            min_thresh = kth_vals[..., -1, None]
            logits = torch.where(logits < min_thresh,
                                 torch.full_like(logits, float('-inf')),
                                 logits)
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits,
                                                       descending=True)
            probs = torch.softmax(sorted_logits, dim=-1)
            cumulative_probs = probs.cumsum(dim=-1)
            # NOTE(review): the token that first pushes the cumulative
            # probability above `top_p` is masked as well; confirm this is
            # the intended nucleus cut-off.
            sorted_mask = cumulative_probs > top_p
            if min_tokens_to_keep > 0:
                sorted_mask[..., :min_tokens_to_keep] = False
            sorted_logits = sorted_logits.masked_fill(sorted_mask,
                                                      float('-inf'))
            # Undo the sort so scores line up with vocabulary ids again.
            logits = torch.zeros_like(logits).scatter(dim=-1,
                                                      index=sorted_indices,
                                                      src=sorted_logits)
        return logits

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[BaseModelOutput] = None,
        past_key_values: Optional[tuple] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithCrossAttentions:
        """Run encoder and decoder and, when labels are given, the loss.

        Args:
            pixel_values: Image batch; may also be passed as ``image``.
            decoder_input_ids: Decoder token ids. When omitted and ``labels``
                is given, a copy of ``labels`` (with ``-100`` replaced by the
                pad id) is used as the decoder input.
            encoder_outputs: Precomputed encoder result in any layout
                accepted by ``_unwrap_encoder_outputs``; skips the encoder.
            past_key_values: Forwarded to the decoder.
            labels: Target ids; may also be passed as ``label``. Positions
                equal to ``-100`` are ignored by the loss.
            **kwargs: Only ``image`` and ``label`` are read; everything else
                is deliberately not forwarded to the decoder.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` carrying ``loss`` (or
            ``None``), ``logits`` and, outside training, the decoder's
            ``past_key_values``.

        Raises:
            ValueError: If neither ``pixel_values`` nor ``encoder_outputs``
                is provided.
        """
        # 1. Compatibility: the trainer may pass 'image' instead of
        #    'pixel_values'.
        if pixel_values is None and 'image' in kwargs:
            pixel_values = kwargs.pop('image')

        # 2. Compatibility: the trainer may pass 'label' instead of 'labels'.
        if labels is None and 'label' in kwargs:
            labels = kwargs.pop('label')

        # 3. Encoder forward.
        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError(
                    '`pixel_values` must be provided when `encoder_outputs` is not.'
                )
            encoder_hidden_states = self.vision_model(pixel_values)
        else:
            # Accept the same wrapped layouts as ``generate``.
            encoder_hidden_states = self._unwrap_encoder_outputs(
                encoder_outputs)

        # 4. Teacher forcing: derive decoder_input_ids from labels when they
        #    are not given. The labels are fed unshifted; the shift happens
        #    in the loss below (logits[:, :-1] against labels[:, 1:]).
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = labels.clone()
            # Replace -100 (ignore_index) with the pad id so the embedding
            # lookup stays in range.
            pad_token_id = self.config.decoder_config.get(
                'pad_token_id',
                self.config.decoder_config.get('eos_token_id', 1))
            decoder_input_ids.masked_fill_(decoder_input_ids == -100,
                                           pad_token_id)

        # 5. Decoder forward. ``inputs_embeds`` is forced to None and
        #    ``**kwargs`` is not forwarded, so stray keys such as
        #    'decoder_inputs_embeds' cannot conflict with ``input_ids``.
        decoder_outputs = self.llm_model(
            input_ids=decoder_input_ids,
            inputs_embeds=None,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            use_cache=False,
            return_dict=True,
        )

        logits = self.lm_head(decoder_outputs.last_hidden_state)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            eps = getattr(self.config, 'label_smoothing', 0.1)
            # Take the class count from the logits so a decoder config
            # without an explicit 'vocab_size' entry still works.
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
                label_smoothing=eps,
            )

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=None
            if self.training else decoder_outputs.past_key_values,
            hidden_states=None,
            attentions=None,
            cross_attentions=None,
        )

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      **kwargs):
        """Build the keyword arguments for one decoding step.

        When a cache is supplied only the last token of ``input_ids`` is
        forwarded.
        """
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {
            'decoder_input_ids': input_ids,
            'past_key_values': past_key_values,
            'encoder_outputs': kwargs.get('encoder_outputs'),
            'attention_mask': kwargs.get('attention_mask'),
        }

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        max_new_tokens: int = 256,
        do_sample: bool = False,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        bos_token_id: Optional[int] = 2,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        return_only_new_tokens: bool = True,
        num_beams: int = 1,
        **kwargs,
    ):
        """Minimal greedy / sampling decoder with key-value caching.

        Args:
            pixel_values: Image batch; required unless ``encoder_outputs`` is
                passed through ``kwargs``.
            decoder_input_ids: Optional prompt ids. Defaults to a single
                ``bos_token_id`` per sample.
            max_new_tokens: Maximum number of decoding steps.
            do_sample: Sample from the filtered distribution instead of
                taking the arg-max.
            temperature: Divisor applied to the logits when sampling.
            top_k: Top-k filter size used when sampling.
            top_p: Nucleus threshold used when sampling.
            bos_token_id: Start token used when no prompt is given.
            eos_token_id: Stop token. ``None`` disables early stopping.
            pad_token_id: Token written for finished sequences; defaults to
                ``bos_token_id``.
            return_only_new_tokens: If true, strip the prompt (or the BOS
                token) from the returned ids.
            num_beams: Must be 1.
            **kwargs: Only ``encoder_outputs`` is read.

        Returns:
            ``LongTensor`` of generated token ids, one row per sample.

        Raises:
            NotImplementedError: If ``num_beams != 1``.
            ValueError: If no encoder input is available or
                ``encoder_outputs`` has an unsupported layout.
        """
        if num_beams != 1:
            raise NotImplementedError(
                '当前极简 generate 未实现 beam search(num_beams>1)。')
        device = pixel_values.device if pixel_values is not None else next(
            self.parameters()).device
        encoder_outputs = kwargs.get('encoder_outputs', None)
        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError(
                    '`pixel_values` is required if `encoder_outputs` is not provided.'
                )
            encoder_hidden_states = self.vision_model(pixel_values)
        else:
            encoder_hidden_states = self._unwrap_encoder_outputs(
                encoder_outputs)
        encoder_hidden_states = encoder_hidden_states.to(device)
        batch_size = encoder_hidden_states.size(0)

        bos_id = bos_token_id
        # A negative id means "never stop early".
        if eos_token_id is None:
            eos_token_id = -1
        if pad_token_id is None:
            pad_token_id = bos_id
        if decoder_input_ids is None:
            input_ids = torch.full((batch_size, 1),
                                   bos_id,
                                   dtype=torch.long,
                                   device=device)
        else:
            input_ids = decoder_input_ids.to(device)
        self.llm_model.config.use_cache = True
        past_key_values = None
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_new_tokens):
            # The first step has no cache yet, so the whole prompt must be
            # fed; afterwards only the newest token is needed.
            if past_key_values is None:
                dec_in = input_ids
            else:
                dec_in = input_ids[:, -1:]
            dec_out = self.llm_model(
                input_ids=dec_in,
                encoder_hidden_states=encoder_hidden_states,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )
            past_key_values = dec_out.past_key_values
            hidden = dec_out.last_hidden_state
            logits = self.lm_head(hidden[:, -1, :])
            if do_sample:
                logits = logits / max(temperature, 1e-6)
                logits = self._top_k_top_p_filtering(logits,
                                                     top_k=top_k,
                                                     top_p=top_p,
                                                     min_tokens_to_keep=1)
                probs = torch.softmax(logits, dim=-1)
                next_tokens = torch.multinomial(probs,
                                                num_samples=1).squeeze(-1)
            else:
                next_tokens = torch.argmax(logits, dim=-1)
            # Sequences that already emitted EOS keep receiving padding.
            next_tokens = torch.where(
                finished, torch.full_like(next_tokens, pad_token_id),
                next_tokens)
            input_ids = torch.cat(
                [input_ids, next_tokens.unsqueeze(-1)], dim=-1)
            if eos_token_id >= 0:
                finished = finished | (next_tokens == eos_token_id)
                if torch.all(finished):
                    break
        if not return_only_new_tokens:
            return input_ids
        if decoder_input_ids is None:
            return input_ids[:, 1:]
        return input_ids[:, decoder_input_ids.size(1):]
630
+
631
+
632
def build_model_cmer(config):
    """Build a ``CMER`` model from a config mapping.

    Args:
        config: Mapping with an optional ``'Backbone'`` entry that may hold
            ``'vision_config'`` and ``'decoder_config'`` dicts. Missing
            entries default to empty dicts.

    Returns:
        A newly constructed ``CMER`` instance.
    """
    backbone = config.get('Backbone', {})
    return CMER(
        CMERConfig(
            vision_config=backbone.get('vision_config', {}),
            decoder_config=backbone.get('decoder_config', {}),
        ))
@@ -28,6 +28,7 @@ class_to_module = {
28
28
  'OTEDecoder': '.ote_decoder',
29
29
  'BUSDecoder': '.bus_decoder',
30
30
  'DptrParseq': '.dptr_parseq_clip_b_decoder',
31
+ 'MDiffDecoder': '.mdiff_decoder',
31
32
  }
32
33
 
33
34
 
@@ -132,7 +132,7 @@ class EncoderWithSVTR(nn.Module):
132
132
  z = self.conv2(z)
133
133
  # SVTR global block
134
134
  B, C, H, W = z.shape
135
- z = z.flatten(2).transpose(1, 2)
135
+ z = z.flatten(2).transpose(1, 2).contiguous()
136
136
  for blk in self.svtr_block:
137
137
  z = blk(z)
138
138
  z = self.norm(z)
@@ -186,10 +186,10 @@ class DANDecoder(nn.Module):
186
186
  torch.zeros(nB, dtype=torch.int64, device=feature.device) +
187
187
  self.bos)
188
188
  dec_seq = torch.full((nB, nT),
189
- self.ignore_index,
190
- dtype=torch.int64,
191
- device=feature.get_device())
192
-
189
+ self.ignore_index,
190
+ dtype=torch.int64,
191
+ device=feature.get_device())
192
+
193
193
  for i in range(0, nT):
194
194
  hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
195
195
  hidden)