openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,643 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from collections import OrderedDict
|
|
3
|
+
from contextlib import nullcontext
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
from torch.nn.attention import sdpa_kernel, SDPBackend
|
|
10
|
+
from torch.utils.checkpoint import checkpoint
|
|
11
|
+
|
|
12
|
+
from transformers import (
|
|
13
|
+
GenerationMixin,
|
|
14
|
+
MBartConfig,
|
|
15
|
+
PretrainedConfig,
|
|
16
|
+
PreTrainedModel,
|
|
17
|
+
)
|
|
18
|
+
from transformers.modeling_outputs import (
|
|
19
|
+
BaseModelOutput,
|
|
20
|
+
CausalLMOutputWithCrossAttentions,
|
|
21
|
+
)
|
|
22
|
+
from transformers.models.mbart.modeling_mbart import MBartDecoder
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with a residual shortcut.

    The shortcut is an identity mapping unless the block changes the
    spatial stride or the channel count, in which case a 1x1 convolution
    followed by batch-norm projects the input to the output shape.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=3,
                               stride=stride,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels,
                               out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.short = nn.Identity()
        else:
            # Project the input so it can be added to the main branch.
            self.short = nn.Sequential(
                nn.Conv2d(in_channels,
                          out_channels,
                          kernel_size=1,
                          stride=stride),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = self.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))
        shortcut = self.short(x)
        return self.relu(main + shortcut)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The input is divided by ``sqrt(mean(x ** 2) + eps)`` and then scaled by
    a learnable per-feature ``weight`` initialised to ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along its last axis and apply the learned gain."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalised = x * torch.rsqrt(mean_square + self.eps)
        return normalised * self.weight
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class SwiGLU(nn.Module):
    """Gated linear unit with a SiLU-activated gate.

    Computes ``up(x) * silu(gate(x))`` where ``up`` and ``gate`` are two
    independent linear projections from ``in_features`` to
    ``hidden_features``.
    """

    def __init__(self,
                 in_features: int,
                 hidden_features: int,
                 bias: bool = True):
        super().__init__()
        self.up = nn.Linear(in_features, hidden_features, bias=bias)
        self.gate = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gated projection of ``x``."""
        values = self.up(x)
        gates = self.act(self.gate(x))
        return values * gates
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
|
75
|
+
x_even = x[..., ::2]
|
|
76
|
+
x_odd = x[..., 1::2]
|
|
77
|
+
return torch.stack((-x_odd, x_even), dim=-1).reshape_as(x)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def apply_rope2d(q: torch.Tensor, k: torch.Tensor, cos_sin_cache):
    """Apply 2D rotary position embedding to query and key tensors.

    The head dimension is split in two: the first half is rotated with the
    row (``y``) tables and the second half with the column (``x``) tables.

    Args:
        q: Query tensor; unpacked as ``(B, nH, M, dH)``.
        k: Key tensor, sliced and rotated exactly like ``q``.
        cos_sin_cache: 4-tuple ``(cos_y, sin_y, cos_x, sin_x)``; each entry
            is viewed as ``(1, 1, M, dH // 2)`` so it broadcasts over batch
            and heads.

    Returns:
        Tuple ``(q_rotated, k_rotated)`` with the same shapes as the inputs.
    """
    cos_y, sin_y, cos_x, sin_x = cos_sin_cache
    _, _, num_tokens, head_dim = q.shape
    half = head_dim // 2
    table_shape = (1, 1, num_tokens, half)
    cos_y = cos_y.view(table_shape)
    sin_y = sin_y.view(table_shape)
    cos_x = cos_x.view(table_shape)
    sin_x = sin_x.view(table_shape)

    def _rotate(part, cos, sin):
        return part * cos + _rotate_half(part) * sin

    q_rot = torch.cat(
        [
            _rotate(q[..., :half], cos_y, sin_y),
            _rotate(q[..., half:], cos_x, sin_x),
        ],
        dim=-1,
    )
    k_rot = torch.cat(
        [
            _rotate(k[..., :half], cos_y, sin_y),
            _rotate(k[..., half:], cos_x, sin_x),
        ],
        dim=-1,
    )
    return q_rot, k_rot
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class RoPEMHA(nn.Module):
    """Multi-head self-attention with 2D rotary position embedding.

    Queries and keys are rotated by ``apply_rope2d`` before being passed to
    ``F.scaled_dot_product_attention``. Attention dropout is only active in
    training mode.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 attn_drop: float = 0.1,
                 proj_drop: float = 0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.q_proj = nn.Linear(dim, dim, bias=True)
        self.k_proj = nn.Linear(dim, dim, bias=True)
        self.v_proj = nn.Linear(dim, dim, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Linear(dim, dim, bias=True)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, cos_sin_cache):
        """Attend over the token axis of ``x`` (unpacked as ``(B, M, D)``).

        Args:
            x: Input sequence.
            cos_sin_cache: Rotary tables forwarded to ``apply_rope2d``.

        Returns:
            Tensor with the same ``(B, M, D)`` shape as ``x``.
        """
        batch, tokens, width = x.shape
        heads, per_head = self.num_heads, self.head_dim
        assert width == heads * per_head, f'D={width}, H*Hd={heads * per_head}'

        def _split_heads(projection):
            # (B, M, D) -> (B, H, M, Hd)
            projected = projection(x).view(batch, tokens, heads, per_head)
            return projected.transpose(1, 2).contiguous()

        q = _split_heads(self.q_proj)
        k = _split_heads(self.k_proj)
        v = _split_heads(self.v_proj)
        q, k = apply_rope2d(q, k, cos_sin_cache)

        # Only constrain the SDPA backend choice when CUDA is present.
        if torch.cuda.is_available():
            backend_ctx = sdpa_kernel([
                SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.MATH,
            ])
        else:
            backend_ctx = nullcontext()

        with backend_ctx:
            attended = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                is_causal=False,
                scale=self.scale,
            )

        # (B, H, M, Hd) -> (B, M, D)
        merged = attended.transpose(1, 2).contiguous().view(
            batch, tokens, width)
        return self.proj_drop(self.out_proj(merged))
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class PreNormDecoderLayer(nn.Module):
    """Pre-norm transformer layer: RoPE self-attention then a SwiGLU FFN.

    Both sub-layers normalise their input with ``RMSNorm`` and add the
    result back onto the residual stream. ``attn_drop_rate`` is used for
    attention dropout, the attention output projection dropout and the
    FFN output dropout.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 attn_drop_rate: float = 0.1,
                 ffn_ratio: float = 4.0):
        super().__init__()
        self.norm1 = RMSNorm(hidden_dim, eps=1e-6)
        self.mha = RoPEMHA(hidden_dim,
                           num_heads,
                           attn_drop=attn_drop_rate,
                           proj_drop=attn_drop_rate)
        self.norm2 = RMSNorm(hidden_dim, eps=1e-6)
        # FFN width is at least 1 even for tiny hidden_dim * ffn_ratio.
        ffn_width = max(1, int(hidden_dim * ffn_ratio))
        self.ffn = SwiGLU(hidden_dim, ffn_width)
        self.fc_out = nn.Linear(ffn_width, hidden_dim)
        self.drop = nn.Dropout(attn_drop_rate)

    def forward(self, x: torch.Tensor, cos_sin_cache):
        """Run attention and feed-forward sub-layers with residual adds."""
        attn_out = self.mha(self.norm1(x), cos_sin_cache)
        x = x + attn_out
        ffn_out = self.fc_out(self.ffn(self.norm2(x)))
        return x + self.drop(ffn_out)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class CMEREncoder(nn.Module):
    """Vision encoder: residual conv stem + pre-norm transformer with 2D RoPE.

    The stem is seven ``ResidualBlock`` stages taking a 3-channel image to
    768 channels. Their strides multiply to 16, or to 32 when
    ``down_sample_ratio > 16``. The feature map is flattened to a token
    sequence, projected to ``hidden_dim`` and passed through ``num_layers``
    ``PreNormDecoderLayer`` blocks that share one set of rotary tables.

    Args:
        num_layers: Number of transformer layers.
        num_heads: Attention heads per layer.
        hidden_dim: Transformer width.
        down_sample_ratio: Only compared against 16 to pick the stride of
            the sixth stem block.
        rope_base: Base of the rotary frequency schedule.
        gradient_checkpointing: Checkpoint each transformer layer while
            training with gradients enabled.
    """

    def __init__(self,
                 num_layers: int,
                 num_heads: int,
                 hidden_dim: int,
                 *,
                 down_sample_ratio: int = 16,
                 rope_base: float = 10000.0,
                 gradient_checkpointing: bool = False):
        super().__init__()
        self.down_sample_ratio = int(down_sample_ratio)
        self.hidden_dim = int(hidden_dim)
        self.gradient_checkpointing = bool(gradient_checkpointing)
        self.rope_base = float(rope_base)
        self.head_dim = hidden_dim // num_heads
        channels = [3, 12, 24, 48, 96, 192, 384, 768]
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(channels[0], channels[1], stride=2),
            ResidualBlock(channels[1], channels[2], stride=1),
            ResidualBlock(channels[2], channels[3], stride=2),
            ResidualBlock(channels[3], channels[4], stride=1),
            ResidualBlock(channels[4], channels[5], stride=2),
            ResidualBlock(channels[5],
                          channels[6],
                          stride=2 if down_sample_ratio > 16 else 1),
            ResidualBlock(channels[6], channels[7], stride=2),
        ])
        self.fc = nn.Linear(channels[-1], hidden_dim)
        self.vit = nn.ModuleList([
            PreNormDecoderLayer(hidden_dim, num_heads)
            for _ in range(num_layers)
        ])
        # LRU cache of CPU-side rotary tables keyed by (H, W, head_dim).
        self.rope_cache = OrderedDict()
        # A subclass may predefine ``max_rope_cache``; otherwise keep 32.
        self.max_rope_cache = getattr(self, 'max_rope_cache', 32)

    def train(self, mode: bool = True):
        """Switch mode and drop cached rotary tables when the mode changes."""
        prev = self.training
        super().train(mode)
        if mode != prev:
            self.rope_cache.clear()
        return self

    def eval(self):
        """Switch to eval mode, clearing the rotary cache if it was training."""
        prev = self.training
        super().eval()
        if prev:
            self.rope_cache.clear()
        return self

    def clear_rope_cache(self):
        """Drop every cached rotary table."""
        self.rope_cache.clear()

    def _build_rope2d_cache(self, H: int, W: int, device, dtype):
        """Return ``(cos_y, sin_y, cos_x, sin_x)`` for an ``H x W`` grid.

        Each table has shape ``(H * W, head_dim // 2)`` and is moved to
        ``device`` / ``dtype``. Tables are built on the CPU, stored in
        float16 and kept in an LRU cache holding at most
        ``max_rope_cache`` entries.

        Raises:
            AssertionError: If ``head_dim`` is not divisible by 4.
        """
        H = int(H)
        W = int(W)
        key = (H, W, int(self.head_dim))
        if key in self.rope_cache:
            entry = self.rope_cache[key]
            self.rope_cache.move_to_end(key)
        else:
            head_dim = self.head_dim
            assert head_dim % 4 == 0, '2D RoPE 需要 head_dim 能被 4 整除'
            half = head_dim // 2
            inv_freq = 1.0 / (self.rope_base**(torch.arange(
                0, half, 2, device='cpu', dtype=torch.float32) / half))
            pos_y = torch.arange(H, device='cpu', dtype=torch.float32)
            pos_x = torch.arange(W, device='cpu', dtype=torch.float32)
            freqs_y = torch.einsum('i,j->ij', pos_y, inv_freq)
            freqs_x = torch.einsum('i,j->ij', pos_x, inv_freq)
            # Duplicate each frequency so it covers one rotated pair.
            cos_y_1d = torch.cos(freqs_y).repeat_interleave(2, dim=-1)
            sin_y_1d = torch.sin(freqs_y).repeat_interleave(2, dim=-1)
            cos_x_1d = torch.cos(freqs_x).repeat_interleave(2, dim=-1)
            sin_x_1d = torch.sin(freqs_x).repeat_interleave(2, dim=-1)
            # Broadcast row tables over columns and column tables over
            # rows, then flatten to one row per grid position.
            cos_y = cos_y_1d[:, None, :].expand(H, W,
                                                half).reshape(H * W, half)
            sin_y = sin_y_1d[:, None, :].expand(H, W,
                                                half).reshape(H * W, half)
            cos_x = cos_x_1d[None, :, :].expand(H, W,
                                                half).reshape(H * W, half)
            sin_x = sin_x_1d[None, :, :].expand(H, W,
                                                half).reshape(H * W, half)
            # Pinning page-locks host memory for asynchronous copies and
            # raises on hosts without CUDA, so only pin when CUDA exists.
            pin = torch.cuda.is_available()
            tables = []
            for t in (cos_y, sin_y, cos_x, sin_x):
                t = t.to(torch.float16)
                tables.append(t.pin_memory() if pin else t)
            entry = tuple(tables)
            self.rope_cache[key] = entry
            # Evict least-recently-used entries beyond the size limit.
            while len(self.rope_cache) > int(self.max_rope_cache):
                self.rope_cache.popitem(last=False)
        cos_y, sin_y, cos_x, sin_x = (t.to(device=device,
                                           dtype=dtype,
                                           non_blocking=True) for t in entry)
        if self.training and torch.is_grad_enabled():
            # Hand out fresh copies while training with grad enabled.
            # NOTE(review): presumably this keeps the cached tensors out of
            # the autograd graph — confirm the original intent.
            cos_y = cos_y.clone()
            sin_y = sin_y.clone()
            cos_x = cos_x.clone()
            sin_x = sin_x.clone()
        return (cos_y, sin_y, cos_x, sin_x)

    def forward(self, pixel_values: torch.Tensor):
        """Encode images into a token sequence.

        Args:
            pixel_values: Image batch fed to the conv stem, whose first
                block expects 3 input channels.

        Returns:
            Tensor of shape ``(N, Hc * Wc, hidden_dim)`` where ``Hc x Wc``
            is the spatial size of the stem output.
        """
        x = pixel_values
        for blk in self.residual_blocks:
            x = blk(x)
        N, C, Hc, Wc = x.shape
        # (N, C, Hc, Wc) -> (N, Hc * Wc, C) -> (N, Hc * Wc, hidden_dim)
        seq = x.flatten(2).transpose(1, 2)
        seq = self.fc(seq)
        cos_sin_cache = self._build_rope2d_cache(Hc, Wc, seq.device, seq.dtype)
        use_checkpoint = (self.gradient_checkpointing and self.training
                          and torch.is_grad_enabled())
        if use_checkpoint:

            def _run_layer(layer, s, cache):
                return layer(s, cache)

            for layer in self.vit:
                seq = checkpoint(_run_layer,
                                 layer,
                                 seq,
                                 cos_sin_cache,
                                 use_reentrant=False)
        else:
            for layer in self.vit:
                seq = layer(seq, cos_sin_cache)
        return seq
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
class CMERConfig(PretrainedConfig):
    """Configuration for the ``CMER`` model.

    Args:
        vision_config: Keyword arguments for the vision encoder; stored
            as-is (an empty dict when ``None``).
        decoder_config: Keyword arguments for the text decoder; stored
            as-is (an empty dict when ``None``). Every key is also mirrored
            as a top-level attribute of this config.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = 'CMER'

    def __init__(self, vision_config=None, decoder_config=None, **kwargs):
        self.vision_config = vision_config if vision_config is not None else {}
        self.decoder_config = decoder_config if decoder_config is not None else {}
        for key, value in self.decoder_config.items():
            setattr(self, key, value)
        if hasattr(self, 'decoder_layers'):
            self.num_hidden_layers = self.decoder_layers
        # Decoder keys are mirrored as top-level attributes, so a serialised
        # config passes them back both in ``kwargs`` and ``decoder_config``.
        # Merge first (decoder_config wins) instead of expanding both dicts
        # into one call, which raises TypeError on any duplicate key.
        merged_kwargs = {**kwargs, **self.decoder_config}
        super().__init__(**merged_kwargs)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
class CMER(PreTrainedModel, GenerationMixin):
    """Image-to-sequence model: ``CMEREncoder`` + ``MBartDecoder`` + LM head.

    The encoder output is used as cross-attention memory for the decoder,
    and ``lm_head`` projects decoder states to vocabulary logits. Word
    embedding tying is switched on in the config and ``tie_weights()`` is
    called at construction and after ``load_state_dict``.
    """

    config_class = CMERConfig
    base_model_prefix = 'cmer'
    main_input_name = 'pixel_values'

    def __init__(self, config: CMERConfig):
        super().__init__(config)
        self.config = config
        decoder_config = MBartConfig(**config.decoder_config)
        self.vision_model = CMEREncoder(**config.vision_config)
        self.llm_model = MBartDecoder(decoder_config)
        self.lm_head = torch.nn.Linear(decoder_config.d_model,
                                       decoder_config.vocab_size,
                                       bias=False)
        setattr(self.config, 'tie_word_embeddings', True)
        self.tie_weights()
        setattr(self.lm_head, '_dynamic_tied_weights_keys', ['weight'])
        setattr(self.llm_model.embed_tokens, '_dynamic_tied_weights_keys',
                ['weight'])

    def set_gradient_checkpointing(self, enable: bool = True):
        """Enable or disable gradient checkpointing on encoder and decoder.

        When enabled, the decoder config's ``use_cache`` is set to False.
        """
        self.gradient_checkpointing = bool(enable)
        # The encoder exposes a plain ``gradient_checkpointing`` attribute,
        # not a ``set_gradient_checkpointing`` method, so test for the
        # attribute that is actually written.
        if hasattr(self.vision_model, 'gradient_checkpointing'):
            self.vision_model.gradient_checkpointing = self.gradient_checkpointing
        # Prefer the decoder's public Transformers API, which also installs
        # the checkpoint function its layers call; fall back to the flag.
        toggle = getattr(
            self.llm_model, 'gradient_checkpointing_enable'
            if enable else 'gradient_checkpointing_disable', None)
        if callable(toggle):
            toggle()
        elif hasattr(self.llm_model, 'gradient_checkpointing'):
            self.llm_model.gradient_checkpointing = self.gradient_checkpointing
        if enable:
            self.llm_model.config.use_cache = False

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_emb):
        """Replace the LM head."""
        self.lm_head = new_emb

    def state_dict(self, *args, **kwargs):
        """Return the state dict with both tied weight keys present."""
        sd = super().state_dict(*args, **kwargs)
        if 'llm_model.embed_tokens.weight' not in sd and 'lm_head.weight' in sd:
            sd['llm_model.embed_tokens.weight'] = sd['lm_head.weight']
        elif 'lm_head.weight' not in sd and 'llm_model.embed_tokens.weight' in sd:
            sd['lm_head.weight'] = sd['llm_model.embed_tokens.weight']
        return sd

    def load_state_dict(self, state_dict, strict=True):
        """Load weights, filling in whichever tied weight key is missing.

        ``strict`` is accepted for signature compatibility only; loading is
        always non-strict. The caller's mapping is not modified.
        """
        # Work on a shallow copy so the aliased key is not written into the
        # caller's dict; keep the ``_metadata`` attribute torch may attach.
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            state_dict._metadata = metadata
        if 'llm_model.embed_tokens.weight' not in state_dict and 'lm_head.weight' in state_dict:
            state_dict['llm_model.embed_tokens.weight'] = state_dict[
                'lm_head.weight']
        if 'lm_head.weight' not in state_dict and 'llm_model.embed_tokens.weight' in state_dict:
            state_dict['lm_head.weight'] = state_dict[
                'llm_model.embed_tokens.weight']
        out = super().load_state_dict(state_dict, strict=False)
        self.tie_weights()
        return out

    def get_input_embeddings(self):
        """Return the decoder's input embedding module."""
        return self.llm_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the decoder's input embedding module."""
        self.llm_model.set_input_embeddings(value)

    def get_decoder(self):
        """Return the decoder as reported by ``llm_model.get_decoder()``."""
        return self.llm_model.get_decoder()

    @staticmethod
    def _unwrap_encoder_outputs(encoder_outputs):
        """Extract the hidden-state tensor from supported encoder outputs.

        Accepts a tensor (returned unchanged), a tuple/list (first item), an
        object with ``last_hidden_state``, or a dict with that key.

        Raises:
            ValueError: If none of the supported forms matches.
        """
        if torch.is_tensor(encoder_outputs):
            return encoder_outputs
        if isinstance(encoder_outputs, (tuple, list)):
            return encoder_outputs[0]
        if hasattr(encoder_outputs, 'last_hidden_state'):
            return encoder_outputs.last_hidden_state
        if isinstance(encoder_outputs,
                      dict) and 'last_hidden_state' in encoder_outputs:
            return encoder_outputs['last_hidden_state']
        raise ValueError('`encoder_outputs` 格式不正确,缺少 last_hidden_state。')

    def _swin_stride_and_winsize(self):
        """Return ``(stride, window_size)`` from the vision model's config.

        Falls back to patch size 4, depths ``[2, 2, 6, 2]`` and window size
        7 when the vision model has no ``config`` or lacks a field.
        """
        # CMEREncoder defines no ``config`` attribute; avoid AttributeError.
        cfg = getattr(self.vision_model, 'config', None)
        patch = int(getattr(cfg, 'patch_size', 4))
        depths = getattr(cfg, 'depths', [2, 2, 6, 2])
        stride = patch * (2**(len(depths) - 1))
        wsize = int(getattr(cfg, 'window_size', 7))
        return stride, wsize

    def _ensure_swin_safe(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Resize images so both sides are stride multiples of minimum size.

        If the shorter side is below ``window_size * stride`` the image is
        upscaled bilinearly; both sides are then rounded up to a multiple of
        ``stride``.
        """
        stride, wsize = self._swin_stride_and_winsize()
        B, C, H, W = pixel_values.shape
        need_min = wsize * stride
        if min(H, W) < need_min:
            s = need_min / float(min(H, W))
            new_h = math.ceil(H * s / stride) * stride
            new_w = math.ceil(W * s / stride) * stride
            pixel_values = F.interpolate(pixel_values,
                                         size=(new_h, new_w),
                                         mode='bilinear',
                                         align_corners=False)
            H, W = new_h, new_w
        new_h = math.ceil(H / stride) * stride
        new_w = math.ceil(W / stride) * stride
        if (new_h, new_w) != (H, W):
            pixel_values = F.interpolate(pixel_values,
                                         size=(new_h, new_w),
                                         mode='bilinear',
                                         align_corners=False)
        return pixel_values

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[BaseModelOutput] = None,
        past_key_values: Optional[tuple] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithCrossAttentions:
        """Run encoder (if needed) and decoder; compute the LM loss if asked.

        Args:
            pixel_values: Input images; also accepted as ``image`` kwarg.
            decoder_input_ids: Decoder input tokens. When omitted and
                ``labels`` is given, a copy of ``labels`` with ``-100``
                replaced by the pad token id is used.
            encoder_outputs: Precomputed encoder states (tensor, tuple/list,
                dict, or object with ``last_hidden_state``).
            past_key_values: Forwarded to the decoder.
            labels: Target tokens; also accepted as ``label`` kwarg.
            **kwargs: Other keys are ignored and not passed to the decoder.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` with ``loss`` (``None``
            without labels), ``logits`` and ``past_key_values``.

        Raises:
            ValueError: If neither ``pixel_values`` nor ``encoder_outputs``
                is provided, or ``encoder_outputs`` has an unsupported form.
        """
        # 1. Compatibility: the trainer may pass 'image' for 'pixel_values'.
        if pixel_values is None and 'image' in kwargs:
            pixel_values = kwargs.pop('image')

        # 2. Compatibility: the trainer may pass 'label' for 'labels'.
        if labels is None and 'label' in kwargs:
            labels = kwargs.pop('label')

        # 3. Encoder forward.
        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError(
                    '`pixel_values` must be provided when `encoder_outputs` is not.'
                )
            # NOTE: self._ensure_swin_safe(pixel_values) could be applied
            # here if Swin-style size alignment is ever required.
            encoder_outputs = self.vision_model(pixel_values)
        # The decoder needs a plain tensor, so unwrap ModelOutput-like
        # containers the same way generate() does.
        encoder_hidden_states = self._unwrap_encoder_outputs(encoder_outputs)

        # 4. Teacher forcing: without explicit decoder_input_ids, feed the
        # labels themselves. The loss below pairs logits[:, :-1] with
        # labels[:, 1:], so each position predicts the next token.
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = labels.clone()
            # Replace -100 (ignore_index) with the pad token id so the
            # embedding lookup stays in range.
            pad_token_id = self.config.decoder_config.get(
                'pad_token_id',
                self.config.decoder_config.get('eos_token_id',
                                               1))  # default fallback
            decoder_input_ids.masked_fill_(decoder_input_ids == -100,
                                           pad_token_id)

        # 5. Decoder forward. **kwargs is deliberately not forwarded: it
        # may hold keys such as 'decoder_inputs_embeds' that conflict.
        decoder_outputs = self.llm_model(
            input_ids=decoder_input_ids,
            inputs_embeds=None,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            use_cache=False,
            return_dict=True,
        )

        logits = self.lm_head(decoder_outputs.last_hidden_state)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            eps = getattr(self.config, 'label_smoothing', 0.1)
            loss = F.cross_entropy(
                shift_logits.view(-1,
                                  self.config.decoder_config['vocab_size']),
                shift_labels.view(-1),
                ignore_index=-100,
                label_smoothing=eps,
            )

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=None
            if self.training else decoder_outputs.past_key_values,
            hidden_states=None,
            attentions=None,
            cross_attentions=None,
        )

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      **kwargs):
        """Build ``forward`` kwargs for one decoding step.

        With a cache present only the last token of ``input_ids`` is passed.
        """
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {
            'decoder_input_ids': input_ids,
            'past_key_values': past_key_values,
            'encoder_outputs': kwargs.get('encoder_outputs'),
            'attention_mask': kwargs.get('attention_mask'),
        }

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        max_new_tokens: int = 256,
        do_sample: bool = False,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        bos_token_id: Optional[int] = 2,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        return_only_new_tokens: bool = True,
        num_beams: int = 1,
        **kwargs,
    ):
        """Greedy or sampling-based autoregressive decoding with a KV cache.

        Args:
            pixel_values: Images to encode; required unless
                ``encoder_outputs`` is passed in ``kwargs``.
            decoder_input_ids: Optional prompt tokens; defaults to one BOS
                token per sample.
            max_new_tokens: Maximum number of decoding steps.
            do_sample: Sample with temperature / top-k / top-p instead of
                taking the argmax.
            temperature: Logit divisor when sampling (floored at 1e-6).
            top_k: Keep only the ``top_k`` highest logits when sampling.
            top_p: Nucleus threshold when sampling.
            bos_token_id: Start token used when no prompt is given.
            eos_token_id: Token that marks a sequence as finished; ``None``
                disables early stopping.
            pad_token_id: Token emitted for finished sequences; defaults to
                ``bos_token_id``.
            return_only_new_tokens: Strip the prompt from the result.
            num_beams: Must be 1.
            **kwargs: Only ``encoder_outputs`` is read.

        Returns:
            Tensor of generated token ids, with or without the prompt.

        Raises:
            NotImplementedError: If ``num_beams != 1``.
            ValueError: If no encoder input is available or
                ``encoder_outputs`` has an unsupported form.
        """
        if num_beams != 1:
            raise NotImplementedError(
                '当前极简 generate 未实现 beam search(num_beams>1)。')
        device = pixel_values.device if pixel_values is not None else next(
            self.parameters()).device
        encoder_outputs = kwargs.get('encoder_outputs', None)
        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError(
                    '`pixel_values` is required if `encoder_outputs` is not provided.'
                )
            encoder_hidden_states = self.vision_model(pixel_values)
        else:
            encoder_hidden_states = self._unwrap_encoder_outputs(
                encoder_outputs)
        encoder_hidden_states = encoder_hidden_states.to(device)
        batch_size = encoder_hidden_states.size(0)

        bos_id = bos_token_id

        # A negative id disables the EOS check in the loop below.
        if eos_token_id is None:
            eos_token_id = -1
        if pad_token_id is None:
            pad_token_id = bos_id
        if decoder_input_ids is None:
            input_ids = torch.full((batch_size, 1),
                                   bos_id,
                                   dtype=torch.long,
                                   device=device)
        else:
            input_ids = decoder_input_ids.to(device)
        self.llm_model.config.use_cache = True
        past_key_values = None
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        def _top_k_top_p_filtering(logits,
                                   top_k=0,
                                   top_p=1.0,
                                   min_tokens_to_keep=1):
            """Mask logits outside the top-k set and the top-p nucleus."""
            top_k = min(max(top_k, 0), logits.size(-1))
            if top_k > 0:
                kth_vals, _ = torch.topk(logits, top_k)
                min_thresh = kth_vals[..., -1, None]
                logits = torch.where(logits < min_thresh,
                                     torch.full_like(logits, float('-inf')),
                                     logits)
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits,
                                                           descending=True)
                probs = torch.softmax(sorted_logits, dim=-1)
                cumulative_probs = probs.cumsum(dim=-1)
                sorted_mask = cumulative_probs > top_p
                if min_tokens_to_keep > 0:
                    sorted_mask[..., :min_tokens_to_keep] = 0
                sorted_logits = sorted_logits.masked_fill(
                    sorted_mask, float('-inf'))
                # Undo the sort so logits line up with vocabulary ids.
                logits = torch.zeros_like(logits).scatter(dim=-1,
                                                          index=sorted_indices,
                                                          src=sorted_logits)
            return logits

        for _ in range(max_new_tokens):
            # The first step has no cache, so feed the whole prompt;
            # otherwise every prompt token but the last would be dropped.
            # Later steps only need the newest token.
            dec_in = input_ids if past_key_values is None else input_ids[:, -1:]
            dec_out = self.llm_model(
                input_ids=dec_in,
                encoder_hidden_states=encoder_hidden_states,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )
            past_key_values = dec_out.past_key_values
            hidden = dec_out.last_hidden_state
            logits = self.lm_head(hidden[:, -1, :])
            if do_sample:
                logits = logits / max(temperature, 1e-6)
                logits = _top_k_top_p_filtering(logits,
                                                top_k=top_k,
                                                top_p=top_p,
                                                min_tokens_to_keep=1)
                probs = torch.softmax(logits, dim=-1)
                next_tokens = torch.multinomial(probs,
                                                num_samples=1).squeeze(-1)
            else:
                next_tokens = torch.argmax(logits, dim=-1)
            # Sequences that already hit EOS keep emitting the pad token.
            next_tokens = torch.where(
                finished, torch.full_like(next_tokens, pad_token_id),
                next_tokens)
            input_ids = torch.cat(
                [input_ids, next_tokens.unsqueeze(-1)], dim=-1)
            if eos_token_id >= 0:
                finished = finished | (next_tokens == eos_token_id)
                if torch.all(finished):
                    break
        if return_only_new_tokens:
            if decoder_input_ids is None:
                return input_ids[:, 1:]
            else:
                return input_ids[:, decoder_input_ids.size(1):]
        else:
            return input_ids
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
def build_model_cmer(config):
    """Build a ``CMER`` model from the ``Backbone`` section of ``config``.

    ``Backbone.vision_config`` and ``Backbone.decoder_config`` are passed
    to ``CMERConfig``; each missing section defaults to an empty dict.
    """
    backbone = config.get('Backbone', {})
    model_config = CMERConfig(
        vision_config=backbone.get('vision_config', {}),
        decoder_config=backbone.get('decoder_config', {}),
    )
    return CMER(model_config)
|
|
@@ -132,7 +132,7 @@ class EncoderWithSVTR(nn.Module):
|
|
|
132
132
|
z = self.conv2(z)
|
|
133
133
|
# SVTR global block
|
|
134
134
|
B, C, H, W = z.shape
|
|
135
|
-
z = z.flatten(2).transpose(1, 2)
|
|
135
|
+
z = z.flatten(2).transpose(1, 2).contiguous()
|
|
136
136
|
for blk in self.svtr_block:
|
|
137
137
|
z = blk(z)
|
|
138
138
|
z = self.norm(z)
|
|
@@ -186,10 +186,10 @@ class DANDecoder(nn.Module):
|
|
|
186
186
|
torch.zeros(nB, dtype=torch.int64, device=feature.device) +
|
|
187
187
|
self.bos)
|
|
188
188
|
dec_seq = torch.full((nB, nT),
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
189
|
+
self.ignore_index,
|
|
190
|
+
dtype=torch.int64,
|
|
191
|
+
device=feature.get_device())
|
|
192
|
+
|
|
193
193
|
for i in range(0, nT):
|
|
194
194
|
hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
|
|
195
195
|
hidden)
|