mineru 2.5.3-py3-none-any.whl → 2.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mineru/backend/pipeline/model_init.py +25 -3
- mineru/backend/pipeline/model_json_to_middle_json.py +2 -2
- mineru/backend/pipeline/model_list.py +0 -1
- mineru/backend/utils.py +24 -0
- mineru/backend/vlm/model_output_to_middle_json.py +2 -2
- mineru/backend/vlm/{custom_logits_processors.py → utils.py} +36 -2
- mineru/backend/vlm/vlm_analyze.py +43 -50
- mineru/backend/vlm/vlm_magic_model.py +155 -1
- mineru/cli/common.py +26 -23
- mineru/cli/fast_api.py +2 -8
- mineru/cli/gradio_app.py +104 -13
- mineru/cli/models_download.py +1 -0
- mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py +152 -0
- mineru/model/mfr/pp_formulanet_plus_m/processors.py +657 -0
- mineru/model/mfr/unimernet/unimernet_hf/modeling_unimernet.py +1 -326
- mineru/model/mfr/utils.py +338 -0
- mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py +103 -16
- mineru/model/table/rec/unet_table/main.py +1 -1
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/imaug/operators.py +5 -5
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/__init__.py +2 -1
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_lcnetv3.py +7 -7
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_pphgnetv2.py +2 -2
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/__init__.py +2 -0
- mineru/model/utils/pytorchocr/modeling/heads/rec_ppformulanet_head.py +1383 -0
- mineru/model/utils/pytorchocr/modeling/heads/rec_unimernet_head.py +2631 -0
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/rec_postprocess.py +25 -28
- mineru/model/utils/pytorchocr/utils/__init__.py +0 -0
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/arch_config.yaml +130 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_arabic_dict.txt +747 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_cyrillic_dict.txt +850 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_devanagari_dict.txt +568 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_ta_dict.txt +513 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_te_dict.txt +540 -0
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/models_config.yml +15 -15
- mineru/model/utils/pytorchocr/utils/resources/pp_formulanet_arch_config.yaml +24 -0
- mineru/model/utils/tools/infer/__init__.py +1 -0
- mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_det.py +6 -3
- mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_rec.py +16 -25
- mineru/model/vlm_vllm_model/server.py +4 -1
- mineru/resources/header.html +2 -2
- mineru/utils/enum_class.py +1 -0
- mineru/utils/guess_suffix_or_lang.py +9 -1
- mineru/utils/llm_aided.py +4 -2
- mineru/utils/ocr_utils.py +16 -0
- mineru/utils/table_merge.py +102 -13
- mineru/version.py +1 -1
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/METADATA +33 -6
- mineru-2.6.0.dist-info/RECORD +195 -0
- mineru-2.5.3.dist-info/RECORD +0 -181
- /mineru/model/{ocr/paddleocr2pytorch/pytorchocr → mfr/pp_formulanet_plus_m}/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch/tools/infer → utils}/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch/pytorchocr/modeling → utils/pytorchocr}/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/base_ocr_v20.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/imaug/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch/pytorchocr/utils → utils/pytorchocr/modeling}/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/architectures/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/architectures/base_model.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/det_mobilenet_v3.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_donut_swin.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_hgnet.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_mv1_enhance.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_svtrnet.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/common.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/cls_head.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/det_db_head.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/rec_ctc_head.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/rec_multi_head.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/db_fpn.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/intracl.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/rnn.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/cls_postprocess.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/db_postprocess.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/arabic_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/cyrillic_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/devanagari_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/en_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/japan_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ka_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/korean_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/latin_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv4_doc_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_el_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_en_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_eslav_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_korean_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_latin_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_th_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ta_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/te_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_cls.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_system.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/pytorchocr_utility.py +0 -0
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/WHEEL +0 -0
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/entry_points.txt +0 -0
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/licenses/LICENSE.md +0 -0
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/top_level.txt +0 -0
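Most of the churn in this release comes from two changes visible in the list above: the `pytorchocr` and `tools/infer` packages move from `mineru/model/ocr/paddleocr2pytorch/` to the shared `mineru/model/utils/` package, and a PP-FormulaNet Plus M formula-recognition model is added alongside UniMERNet. Code that imported the old internal paths directly has to follow the move; the sketch below is a minimal illustration using module paths taken from the rename entries above, and it assumes these internals are importable at all — they are not documented as a public API.

```python
from importlib import import_module

# Resolve the OCR inference utility module across the 2.5.3 -> 2.6.0 package move.
# The dotted paths mirror the rename entries in this diff; treating them as
# stable import points is an assumption, not documented mineru behavior.
try:
    # mineru >= 2.6.0: OCR helpers now live under mineru.model.utils
    utility = import_module("mineru.model.utils.tools.infer.pytorchocr_utility")
except ModuleNotFoundError:
    # mineru <= 2.5.3: the old location under model/ocr/paddleocr2pytorch
    utility = import_module(
        "mineru.model.ocr.paddleocr2pytorch.tools.infer.pytorchocr_utility"
    )
```

The largest single addition is the new PP-FormulaNet decoder head, shown below.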
|
@@ -0,0 +1,1383 @@
|
|
|
1
|
+
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import math
|
|
16
|
+
import re
|
|
17
|
+
import numpy as np
|
|
18
|
+
import inspect
|
|
19
|
+
import torch
|
|
20
|
+
import torch.nn as nn
|
|
21
|
+
from typing import Optional, Tuple, Union, List, Dict, Any
|
|
22
|
+
from dataclasses import dataclass, fields, is_dataclass
|
|
23
|
+
|
|
24
|
+
from sympy import totient
|
|
25
|
+
|
|
26
|
+
from mineru.utils.config_reader import get_device
|
|
27
|
+
from .rec_unimernet_head import (
|
|
28
|
+
MBartForCausalLM,
|
|
29
|
+
MBartDecoder,
|
|
30
|
+
MBartConfig,
|
|
31
|
+
ModelOutput,
|
|
32
|
+
BaseModelOutputWithPastAndCrossAttentions,
|
|
33
|
+
Seq2SeqLMOutput,
|
|
34
|
+
CausalLMOutputWithCrossAttentions,
|
|
35
|
+
LogitsProcessorList,
|
|
36
|
+
ForcedEOSTokenLogitsProcessor,
|
|
37
|
+
UniMERNetHead,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
|
|
42
|
+
class AttentionMaskConverter:
|
|
43
|
+
"""
|
|
44
|
+
A class to convert attention masks based on specific configurations.
|
|
45
|
+
|
|
46
|
+
This class is designed to handle the conversion of attention masks with options for causal masking
|
|
47
|
+
and sliding window attention, which are commonly used in transformer models.
|
|
48
|
+
|
|
49
|
+
Attributes:
|
|
50
|
+
is_causal (bool): Flag indicating whether the attention mask should enforce causal masking,
|
|
51
|
+
which ensures each position can only attend to previous positions.
|
|
52
|
+
sliding_window (int, optional): Size of the sliding window for local attention. If set,
|
|
53
|
+
attention is restricted to a local window of this size.
|
|
54
|
+
|
|
55
|
+
"""
|
|
56
|
+
|
|
57
|
+
is_causal: bool
|
|
58
|
+
sliding_window: int
|
|
59
|
+
|
|
60
|
+
def __init__(self, is_causal: bool, sliding_window=None):
|
|
61
|
+
self.is_causal = is_causal
|
|
62
|
+
self.sliding_window = sliding_window
|
|
63
|
+
|
|
64
|
+
if self.sliding_window is not None and self.sliding_window <= 0:
|
|
65
|
+
raise ValueError(
|
|
66
|
+
f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
@staticmethod
|
|
70
|
+
def _make_causal_mask(
|
|
71
|
+
input_ids_shape,
|
|
72
|
+
dtype,
|
|
73
|
+
past_key_values_length=0,
|
|
74
|
+
sliding_window=None,
|
|
75
|
+
is_export=False,
|
|
76
|
+
):
|
|
77
|
+
"""
|
|
78
|
+
Make causal mask used for bi-directional self-attention.
|
|
79
|
+
"""
|
|
80
|
+
bsz, tgt_len = input_ids_shape
|
|
81
|
+
if is_export:
|
|
82
|
+
mask = torch.full(
|
|
83
|
+
(tgt_len, tgt_len), torch.finfo(dtype).min, dtype=torch.float64
|
|
84
|
+
)
|
|
85
|
+
mask_cond = torch.arange(mask.shape[-1])
|
|
86
|
+
mask.masked_fill_(
|
|
87
|
+
mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
|
|
88
|
+
)
|
|
89
|
+
else:
|
|
90
|
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
|
|
91
|
+
mask_cond = torch.arange(mask.shape[-1])
|
|
92
|
+
mask.masked_fill_(
|
|
93
|
+
mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
|
|
94
|
+
)
|
|
95
|
+
mask = mask.to(dtype)
|
|
96
|
+
|
|
97
|
+
if past_key_values_length > 0:
|
|
98
|
+
mask = torch.concat(
|
|
99
|
+
[torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask],
|
|
100
|
+
dim=-1,
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
# add lower triangular sliding window mask if necessary
|
|
104
|
+
if sliding_window is not None:
|
|
105
|
+
diagonal = past_key_values_length - sliding_window - 1
|
|
106
|
+
|
|
107
|
+
context_mask = torch.tril(
|
|
108
|
+
torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
|
|
109
|
+
)
|
|
110
|
+
mask.masked_fill_(context_mask, torch.finfo(dtype).min)
|
|
111
|
+
|
|
112
|
+
return mask[None, None, :, :].expand(
|
|
113
|
+
[bsz, 1, tgt_len, tgt_len + past_key_values_length]
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
@staticmethod
|
|
117
|
+
def _make_causal_mask_parallel(
|
|
118
|
+
input_ids_shape,
|
|
119
|
+
dtype,
|
|
120
|
+
past_key_values_length=0,
|
|
121
|
+
sliding_window=None,
|
|
122
|
+
parallel_step=1,
|
|
123
|
+
is_export=False,
|
|
124
|
+
):
|
|
125
|
+
"""
|
|
126
|
+
Make causal mask used for bi-directional self-attention.
|
|
127
|
+
"""
|
|
128
|
+
bsz, tgt_len = input_ids_shape
|
|
129
|
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
|
|
130
|
+
mask_cond = torch.arange(mask.shape[-1])
|
|
131
|
+
mask_cond_parallel = torch.arange(mask.shape[-1])
|
|
132
|
+
|
|
133
|
+
mask_parallel = torch.arange(0, tgt_len, step=parallel_step).reshape([1, -1])
|
|
134
|
+
mask_parallel = torch.repeat_interleave(mask_parallel, parallel_step, 1)[
|
|
135
|
+
:, :tgt_len
|
|
136
|
+
]
|
|
137
|
+
mask.masked_fill_(
|
|
138
|
+
mask_cond < (mask_parallel + parallel_step).reshape([mask.shape[-1], 1]), 0
|
|
139
|
+
)
|
|
140
|
+
mask = mask.to(dtype)
|
|
141
|
+
|
|
142
|
+
if past_key_values_length > 0:
|
|
143
|
+
mask = torch.concat(
|
|
144
|
+
[torch.zeros([tgt_len, past_key_values_length], dtype=dtype), mask],
|
|
145
|
+
dim=-1,
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
# add lower triangular sliding window mask if necessary
|
|
149
|
+
if sliding_window is not None:
|
|
150
|
+
diagonal = past_key_values_length - sliding_window - 1
|
|
151
|
+
|
|
152
|
+
context_mask = torch.tril(
|
|
153
|
+
torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
|
|
154
|
+
)
|
|
155
|
+
mask.masked_fill_(context_mask, torch.finfo(dtype).min)
|
|
156
|
+
|
|
157
|
+
return mask[None, None, :, :].expand(
|
|
158
|
+
[bsz, 1, tgt_len, tgt_len + past_key_values_length]
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
def to_4d(
|
|
162
|
+
self,
|
|
163
|
+
attention_mask_2d,
|
|
164
|
+
query_length,
|
|
165
|
+
dtype,
|
|
166
|
+
key_value_length,
|
|
167
|
+
use_parallel=False,
|
|
168
|
+
parallel_step=3,
|
|
169
|
+
is_export=False,
|
|
170
|
+
):
|
|
171
|
+
"""
|
|
172
|
+
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
|
|
173
|
+
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
|
|
174
|
+
causal, a causal mask will be added.
|
|
175
|
+
"""
|
|
176
|
+
input_shape = (attention_mask_2d.shape[0], query_length)
|
|
177
|
+
|
|
178
|
+
causal_4d_mask = None
|
|
179
|
+
if use_parallel:
|
|
180
|
+
step = parallel_step
|
|
181
|
+
else:
|
|
182
|
+
step = 1
|
|
183
|
+
if (
|
|
184
|
+
input_shape[-1] > step or self.sliding_window is not None
|
|
185
|
+
) and self.is_causal:
|
|
186
|
+
|
|
187
|
+
if key_value_length is None:
|
|
188
|
+
raise ValueError(
|
|
189
|
+
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
past_key_values_length = key_value_length - query_length
|
|
193
|
+
|
|
194
|
+
if use_parallel:
|
|
195
|
+
causal_4d_mask = self._make_causal_mask_parallel(
|
|
196
|
+
input_shape,
|
|
197
|
+
dtype,
|
|
198
|
+
past_key_values_length=past_key_values_length,
|
|
199
|
+
sliding_window=self.sliding_window,
|
|
200
|
+
parallel_step=parallel_step,
|
|
201
|
+
is_export=is_export,
|
|
202
|
+
)
|
|
203
|
+
else:
|
|
204
|
+
causal_4d_mask = self._make_causal_mask(
|
|
205
|
+
input_shape,
|
|
206
|
+
dtype,
|
|
207
|
+
past_key_values_length=past_key_values_length,
|
|
208
|
+
sliding_window=self.sliding_window,
|
|
209
|
+
is_export=is_export,
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
elif self.sliding_window is not None:
|
|
213
|
+
raise NotImplementedError(
|
|
214
|
+
"Sliding window is currently only implemented for causal masking"
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
expanded_attn_mask = self._expand_mask(
|
|
218
|
+
attention_mask_2d, dtype, tgt_len=input_shape[-1]
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
if causal_4d_mask is not None:
|
|
222
|
+
expanded_attn_mask = causal_4d_mask.masked_fill_(
|
|
223
|
+
expanded_attn_mask.to(torch.bool), torch.finfo(dtype).min
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
expanded_4d_mask = expanded_attn_mask
|
|
227
|
+
return expanded_4d_mask
|
|
228
|
+
|
|
229
|
+
def to_4d_export(
|
|
230
|
+
self,
|
|
231
|
+
attention_mask_2d,
|
|
232
|
+
query_length,
|
|
233
|
+
dtype,
|
|
234
|
+
key_value_length,
|
|
235
|
+
use_parallel=False,
|
|
236
|
+
parallel_step=3,
|
|
237
|
+
is_export=False,
|
|
238
|
+
):
|
|
239
|
+
input_shape = (attention_mask_2d.shape[0], query_length)
|
|
240
|
+
|
|
241
|
+
expanded_attn_mask = self._expand_mask_export(
|
|
242
|
+
attention_mask_2d, dtype, tgt_len=input_shape[-1]
|
|
243
|
+
)
|
|
244
|
+
expanded_4d_mask = expanded_attn_mask
|
|
245
|
+
|
|
246
|
+
return expanded_4d_mask
|
|
247
|
+
|
|
248
|
+
def _expand_mask(self, mask, dtype, tgt_len=None):
|
|
249
|
+
"""
|
|
250
|
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
|
251
|
+
"""
|
|
252
|
+
bsz, src_len = mask.shape
|
|
253
|
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
|
254
|
+
expanded_mask = (
|
|
255
|
+
mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).to(dtype)
|
|
256
|
+
)
|
|
257
|
+
|
|
258
|
+
inverted_mask = 1.0 - expanded_mask
|
|
259
|
+
|
|
260
|
+
return inverted_mask.masked_fill_(
|
|
261
|
+
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
def _expand_mask_export(self, mask, dtype, tgt_len=None):
|
|
265
|
+
"""
|
|
266
|
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
|
267
|
+
"""
|
|
268
|
+
bsz, src_len = mask.shape
|
|
269
|
+
expanded_mask = (
|
|
270
|
+
mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).to(dtype)
|
|
271
|
+
)
|
|
272
|
+
inverted_mask = 1.0 - expanded_mask
|
|
273
|
+
return inverted_mask.masked_fill_(
|
|
274
|
+
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
|
|
279
|
+
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _prepare_4d_causal_attention_mask(
|
|
283
|
+
attention_mask,
|
|
284
|
+
input_shape,
|
|
285
|
+
inputs_embeds,
|
|
286
|
+
past_key_values_length,
|
|
287
|
+
sliding_window=None,
|
|
288
|
+
use_parallel=False,
|
|
289
|
+
parallel_step=3,
|
|
290
|
+
is_export=False,
|
|
291
|
+
):
|
|
292
|
+
"""
|
|
293
|
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
|
294
|
+
`(batch_size, key_value_length)`
|
|
295
|
+
|
|
296
|
+
Args:
|
|
297
|
+
attention_mask (`paddle.Tensor` or `None`):
|
|
298
|
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
|
299
|
+
input_shape (`tuple(int)` or `list(int)` or `paddle.Size`):
|
|
300
|
+
The input shape should be a tuple that defines `(batch_size, query_length)`.
|
|
301
|
+
inputs_embeds (`paddle.Tensor`):
|
|
302
|
+
The embedded inputs as a paddle Tensor.
|
|
303
|
+
past_key_values_length (`int`):
|
|
304
|
+
The length of the key value cache.
|
|
305
|
+
sliding_window (`int`, *optional*):
|
|
306
|
+
If the model uses windowed attention, a sliding window should be passed.
|
|
307
|
+
"""
|
|
308
|
+
attn_mask_converter = AttentionMaskConverter(
|
|
309
|
+
is_causal=True, sliding_window=sliding_window
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
key_value_length = input_shape[-1] + past_key_values_length
|
|
313
|
+
|
|
314
|
+
# 4d mask is passed through the layers
|
|
315
|
+
if attention_mask is not None and len(attention_mask.shape) == 2:
|
|
316
|
+
attention_mask = attn_mask_converter.to_4d(
|
|
317
|
+
attention_mask,
|
|
318
|
+
input_shape[-1],
|
|
319
|
+
key_value_length=key_value_length,
|
|
320
|
+
dtype=inputs_embeds.dtype,
|
|
321
|
+
use_parallel=use_parallel,
|
|
322
|
+
parallel_step=parallel_step,
|
|
323
|
+
is_export=is_export,
|
|
324
|
+
)
|
|
325
|
+
elif attention_mask is not None and len(attention_mask.shape) == 4:
|
|
326
|
+
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
|
327
|
+
if tuple(attention_mask.shape) != expected_shape:
|
|
328
|
+
raise ValueError(
|
|
329
|
+
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
|
|
330
|
+
)
|
|
331
|
+
else:
|
|
332
|
+
# if the 4D mask has correct shape - invert it and fill with negative infinity
|
|
333
|
+
inverted_mask = 1.0 - attention_mask
|
|
334
|
+
attention_mask = inverted_mask.masked_fill_(
|
|
335
|
+
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
|
|
336
|
+
)
|
|
337
|
+
else:
|
|
338
|
+
attention_mask = attn_mask_converter.to_causal_4d(
|
|
339
|
+
input_shape[0],
|
|
340
|
+
input_shape[-1],
|
|
341
|
+
key_value_length,
|
|
342
|
+
dtype=inputs_embeds.dtype,
|
|
343
|
+
)
|
|
344
|
+
|
|
345
|
+
return attention_mask
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def _prepare_4d_causal_attention_mask_export(
|
|
349
|
+
attention_mask,
|
|
350
|
+
input_shape,
|
|
351
|
+
inputs_embeds,
|
|
352
|
+
past_key_values_length,
|
|
353
|
+
sliding_window=None,
|
|
354
|
+
use_parallel=False,
|
|
355
|
+
parallel_step=3,
|
|
356
|
+
is_export=False,
|
|
357
|
+
):
|
|
358
|
+
"""
|
|
359
|
+
Prepare a 4D causal attention mask for export.
|
|
360
|
+
|
|
361
|
+
This function prepares a 4-dimensional causal attention mask, which is used to ensure that each position in the
|
|
362
|
+
sequence can only attend to previous positions. It is specifically designed to handle scenarios where the model
|
|
363
|
+
is being exported, potentially with additional options like sliding window or parallel processing.
|
|
364
|
+
|
|
365
|
+
Args:
|
|
366
|
+
attention_mask: The initial attention mask, typically used to avoid attending to padding tokens.
|
|
367
|
+
input_shape: Shape of the input tensor, usually in the form (batch_size, sequence_length).
|
|
368
|
+
inputs_embeds: Embeddings of the input sequence, used to derive certain dimensions if needed.
|
|
369
|
+
past_key_values_length: Length of past key values, used in contexts like transformer decoders with caching.
|
|
370
|
+
sliding_window: Optional parameter. If provided, specifies the size of a sliding window for local attention.
|
|
371
|
+
use_parallel: Flag indicating whether to use parallel processing for attention computation.
|
|
372
|
+
parallel_step: Number of steps to use in parallel processing, relevant if `use_parallel` is True.
|
|
373
|
+
is_export: Flag indicating whether the attention mask is being prepared for model export.
|
|
374
|
+
|
|
375
|
+
Returns:
|
|
376
|
+
A 4D causal attention mask suitable for use in transformer models, ensuring correct causal masking.
|
|
377
|
+
"""
|
|
378
|
+
attn_mask_converter = AttentionMaskConverter(
|
|
379
|
+
is_causal=True, sliding_window=sliding_window
|
|
380
|
+
)
|
|
381
|
+
key_value_length = input_shape[-1] + past_key_values_length
|
|
382
|
+
|
|
383
|
+
shape = attention_mask.shape
|
|
384
|
+
len_shape = len(shape)
|
|
385
|
+
|
|
386
|
+
attention_mask = attn_mask_converter.to_4d_export(
|
|
387
|
+
attention_mask,
|
|
388
|
+
input_shape[-1],
|
|
389
|
+
key_value_length=key_value_length,
|
|
390
|
+
dtype=inputs_embeds.dtype,
|
|
391
|
+
use_parallel=use_parallel,
|
|
392
|
+
parallel_step=parallel_step,
|
|
393
|
+
is_export=is_export,
|
|
394
|
+
)
|
|
395
|
+
return attention_mask
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
class CustomMBartDecoder(MBartDecoder):
|
|
399
|
+
def __init__(self, config):
|
|
400
|
+
super().__init__(config)
|
|
401
|
+
hidden_size = config.d_model
|
|
402
|
+
self.is_export = config.is_export
|
|
403
|
+
self.config_decoder = config
|
|
404
|
+
|
|
405
|
+
def forward(
|
|
406
|
+
self,
|
|
407
|
+
input_ids=None,
|
|
408
|
+
attention_mask=None,
|
|
409
|
+
encoder_hidden_states=None,
|
|
410
|
+
encoder_attention_mask=None,
|
|
411
|
+
head_mask=None,
|
|
412
|
+
cross_attn_head_mask=None,
|
|
413
|
+
past_key_values=None,
|
|
414
|
+
inputs_embeds=None,
|
|
415
|
+
use_cache=None,
|
|
416
|
+
output_attentions=None,
|
|
417
|
+
output_hidden_states=None,
|
|
418
|
+
return_dict=None,
|
|
419
|
+
):
|
|
420
|
+
self.is_export = False if self.training else True
|
|
421
|
+
|
|
422
|
+
output_attentions = (
|
|
423
|
+
output_attentions
|
|
424
|
+
if output_attentions is not None
|
|
425
|
+
else self.config.output_attentions
|
|
426
|
+
)
|
|
427
|
+
output_hidden_states = (
|
|
428
|
+
output_hidden_states
|
|
429
|
+
if output_hidden_states is not None
|
|
430
|
+
else self.config.output_hidden_states
|
|
431
|
+
)
|
|
432
|
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
433
|
+
return_dict = (
|
|
434
|
+
return_dict if return_dict is not None else self.config.use_return_dict
|
|
435
|
+
)
|
|
436
|
+
|
|
437
|
+
# retrieve input_ids and inputs_embeds
|
|
438
|
+
if input_ids is not None and inputs_embeds is not None:
|
|
439
|
+
raise ValueError(
|
|
440
|
+
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
|
441
|
+
)
|
|
442
|
+
elif input_ids is not None:
|
|
443
|
+
input = input_ids
|
|
444
|
+
input_shape = input.shape
|
|
445
|
+
input_ids = input_ids.reshape([-1, input_shape[-1]])
|
|
446
|
+
elif inputs_embeds is not None:
|
|
447
|
+
input_shape = inputs_embeds.shape[:-1]
|
|
448
|
+
input = inputs_embeds[:, :, -1]
|
|
449
|
+
else:
|
|
450
|
+
raise ValueError(
|
|
451
|
+
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
|
452
|
+
)
|
|
453
|
+
|
|
454
|
+
# past_key_values_length
|
|
455
|
+
past_key_values_length = (
|
|
456
|
+
past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
|
457
|
+
)
|
|
458
|
+
|
|
459
|
+
if inputs_embeds is None:
|
|
460
|
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
|
461
|
+
|
|
462
|
+
if self._use_flash_attention_2:
|
|
463
|
+
# 2d mask is passed through the layers
|
|
464
|
+
attention_mask = (
|
|
465
|
+
attention_mask
|
|
466
|
+
if (attention_mask is not None and 0 in attention_mask)
|
|
467
|
+
else None
|
|
468
|
+
)
|
|
469
|
+
else:
|
|
470
|
+
# 4d mask is passed through the layers
|
|
471
|
+
if self.is_export:
|
|
472
|
+
attention_mask = _prepare_4d_causal_attention_mask_export(
|
|
473
|
+
attention_mask,
|
|
474
|
+
input_shape,
|
|
475
|
+
inputs_embeds,
|
|
476
|
+
past_key_values_length,
|
|
477
|
+
use_parallel=self.config_decoder.use_parallel,
|
|
478
|
+
parallel_step=self.config_decoder.parallel_step,
|
|
479
|
+
is_export=self.is_export,
|
|
480
|
+
)
|
|
481
|
+
else:
|
|
482
|
+
attention_mask = _prepare_4d_causal_attention_mask(
|
|
483
|
+
attention_mask,
|
|
484
|
+
input_shape,
|
|
485
|
+
inputs_embeds,
|
|
486
|
+
past_key_values_length,
|
|
487
|
+
use_parallel=self.config_decoder.use_parallel,
|
|
488
|
+
parallel_step=self.config_decoder.parallel_step,
|
|
489
|
+
is_export=self.is_export,
|
|
490
|
+
)
|
|
491
|
+
|
|
492
|
+
# expand encoder attention mask
|
|
493
|
+
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
|
494
|
+
if self._use_flash_attention_2:
|
|
495
|
+
encoder_attention_mask = (
|
|
496
|
+
encoder_attention_mask if 0 in encoder_attention_mask else None
|
|
497
|
+
)
|
|
498
|
+
else:
|
|
499
|
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
500
|
+
encoder_attention_mask = _prepare_4d_attention_mask(
|
|
501
|
+
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
|
502
|
+
)
|
|
503
|
+
|
|
504
|
+
# embed positions
|
|
505
|
+
positions = self.embed_positions(input, past_key_values_length)
|
|
506
|
+
|
|
507
|
+
hidden_states = inputs_embeds + positions
|
|
508
|
+
|
|
509
|
+
hidden_states = self.layernorm_embedding(hidden_states)
|
|
510
|
+
hidden_states = nn.functional.dropout(
|
|
511
|
+
hidden_states, p=self.dropout, training=self.training
|
|
512
|
+
)
|
|
513
|
+
if self.gradient_checkpointing and self.training:
|
|
514
|
+
if use_cache:
|
|
515
|
+
print(
|
|
516
|
+
"`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
|
|
517
|
+
)
|
|
518
|
+
use_cache = False
|
|
519
|
+
|
|
520
|
+
# decoder layers
|
|
521
|
+
all_hidden_states = () if output_hidden_states else None
|
|
522
|
+
all_self_attns = () if output_attentions else None
|
|
523
|
+
all_cross_attentions = (
|
|
524
|
+
() if (output_attentions and encoder_hidden_states is not None) else None
|
|
525
|
+
)
|
|
526
|
+
next_decoder_cache = () if use_cache else None
|
|
527
|
+
|
|
528
|
+
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
|
529
|
+
for attn_mask, mask_name in zip(
|
|
530
|
+
[head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
|
|
531
|
+
):
|
|
532
|
+
if attn_mask is not None:
|
|
533
|
+
if attn_mask.size()[0] != len(self.layers):
|
|
534
|
+
raise ValueError(
|
|
535
|
+
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
|
536
|
+
f" {attn_mask.size()[0]}."
|
|
537
|
+
)
|
|
538
|
+
for idx, decoder_layer in enumerate(self.layers):
|
|
539
|
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
|
540
|
+
if output_hidden_states:
|
|
541
|
+
all_hidden_states += (hidden_states,)
|
|
542
|
+
if self.training:
|
|
543
|
+
dropout_probability = torch.rand([])
|
|
544
|
+
if dropout_probability < self.layerdrop:
|
|
545
|
+
continue
|
|
546
|
+
|
|
547
|
+
past_key_value = (
|
|
548
|
+
past_key_values[idx] if past_key_values is not None else None
|
|
549
|
+
)
|
|
550
|
+
if self.gradient_checkpointing and self.training:
|
|
551
|
+
layer_outputs = self._gradient_checkpointing_func(
|
|
552
|
+
decoder_layer.__call__,
|
|
553
|
+
hidden_states,
|
|
554
|
+
attention_mask,
|
|
555
|
+
encoder_hidden_states,
|
|
556
|
+
encoder_attention_mask,
|
|
557
|
+
head_mask[idx] if head_mask is not None else None,
|
|
558
|
+
(
|
|
559
|
+
cross_attn_head_mask[idx]
|
|
560
|
+
if cross_attn_head_mask is not None
|
|
561
|
+
else None
|
|
562
|
+
),
|
|
563
|
+
None,
|
|
564
|
+
output_attentions,
|
|
565
|
+
use_cache,
|
|
566
|
+
)
|
|
567
|
+
else:
|
|
568
|
+
layer_outputs = decoder_layer(
|
|
569
|
+
hidden_states,
|
|
570
|
+
attention_mask=attention_mask,
|
|
571
|
+
encoder_hidden_states=encoder_hidden_states,
|
|
572
|
+
encoder_attention_mask=encoder_attention_mask,
|
|
573
|
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
|
574
|
+
cross_attn_layer_head_mask=(
|
|
575
|
+
cross_attn_head_mask[idx]
|
|
576
|
+
if cross_attn_head_mask is not None
|
|
577
|
+
else None
|
|
578
|
+
),
|
|
579
|
+
past_key_value=past_key_value,
|
|
580
|
+
output_attentions=output_attentions,
|
|
581
|
+
use_cache=use_cache,
|
|
582
|
+
)
|
|
583
|
+
hidden_states = layer_outputs[0]
|
|
584
|
+
|
|
585
|
+
if self.is_export:
|
|
586
|
+
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
|
587
|
+
else:
|
|
588
|
+
if use_cache:
|
|
589
|
+
next_decoder_cache += (
|
|
590
|
+
layer_outputs[3 if output_attentions else 1],
|
|
591
|
+
)
|
|
592
|
+
|
|
593
|
+
if output_attentions:
|
|
594
|
+
all_self_attns += (layer_outputs[1],)
|
|
595
|
+
|
|
596
|
+
if encoder_hidden_states is not None:
|
|
597
|
+
all_cross_attentions += (layer_outputs[2],)
|
|
598
|
+
|
|
599
|
+
hidden_states = self.layer_norm(hidden_states)
|
|
600
|
+
|
|
601
|
+
# add hidden states from the last decoder layer
|
|
602
|
+
if output_hidden_states:
|
|
603
|
+
all_hidden_states += (hidden_states,)
|
|
604
|
+
|
|
605
|
+
if self.is_export:
|
|
606
|
+
next_cache = next_decoder_cache
|
|
607
|
+
else:
|
|
608
|
+
next_cache = next_decoder_cache if use_cache else None
|
|
609
|
+
if not return_dict:
|
|
610
|
+
return tuple(
|
|
611
|
+
v
|
|
612
|
+
for v in [
|
|
613
|
+
hidden_states,
|
|
614
|
+
next_cache,
|
|
615
|
+
all_hidden_states,
|
|
616
|
+
all_self_attns,
|
|
617
|
+
all_cross_attentions,
|
|
618
|
+
]
|
|
619
|
+
if v is not None
|
|
620
|
+
)
|
|
621
|
+
|
|
622
|
+
return BaseModelOutputWithPastAndCrossAttentions(
|
|
623
|
+
last_hidden_state=hidden_states,
|
|
624
|
+
past_key_values=next_cache,
|
|
625
|
+
hidden_states=all_hidden_states,
|
|
626
|
+
attentions=all_self_attns,
|
|
627
|
+
cross_attentions=all_cross_attentions,
|
|
628
|
+
)
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
class CustomMBartForCausalLM(MBartForCausalLM):
|
|
632
|
+
def __init__(self, config):
|
|
633
|
+
super().__init__(config)
|
|
634
|
+
# Modify the decoder within MBartDecoderWrapper
|
|
635
|
+
self.model.decoder = CustomMBartDecoder(config)
|
|
636
|
+
|
|
637
|
+
def forward(
|
|
638
|
+
self,
|
|
639
|
+
input_ids=None,
|
|
640
|
+
attention_mask=None,
|
|
641
|
+
encoder_hidden_states=None,
|
|
642
|
+
encoder_attention_mask=None,
|
|
643
|
+
head_mask=None,
|
|
644
|
+
cross_attn_head_mask=None,
|
|
645
|
+
past_key_values=None,
|
|
646
|
+
inputs_embeds=None,
|
|
647
|
+
labels=None,
|
|
648
|
+
use_cache=None,
|
|
649
|
+
output_attentions=None,
|
|
650
|
+
output_hidden_states=None,
|
|
651
|
+
return_dict=None,
|
|
652
|
+
):
|
|
653
|
+
output_attentions = (
|
|
654
|
+
output_attentions
|
|
655
|
+
if output_attentions is not None
|
|
656
|
+
else self.config.output_attentions
|
|
657
|
+
)
|
|
658
|
+
output_hidden_states = (
|
|
659
|
+
output_hidden_states
|
|
660
|
+
if output_hidden_states is not None
|
|
661
|
+
else self.config.output_hidden_states
|
|
662
|
+
)
|
|
663
|
+
return_dict = (
|
|
664
|
+
return_dict if return_dict is not None else self.config.use_return_dict
|
|
665
|
+
)
|
|
666
|
+
|
|
667
|
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
668
|
+
outputs = self.model.decoder(
|
|
669
|
+
input_ids=input_ids,
|
|
670
|
+
attention_mask=attention_mask,
|
|
671
|
+
encoder_hidden_states=encoder_hidden_states,
|
|
672
|
+
encoder_attention_mask=encoder_attention_mask,
|
|
673
|
+
head_mask=head_mask,
|
|
674
|
+
cross_attn_head_mask=cross_attn_head_mask,
|
|
675
|
+
past_key_values=past_key_values,
|
|
676
|
+
inputs_embeds=inputs_embeds,
|
|
677
|
+
use_cache=use_cache,
|
|
678
|
+
output_attentions=output_attentions,
|
|
679
|
+
output_hidden_states=output_hidden_states,
|
|
680
|
+
return_dict=return_dict,
|
|
681
|
+
)
|
|
682
|
+
logits = self.lm_head(outputs[0])
|
|
683
|
+
|
|
684
|
+
return CausalLMOutputWithCrossAttentions(
|
|
685
|
+
logits=logits,
|
|
686
|
+
past_key_values=outputs.past_key_values,
|
|
687
|
+
hidden_states=outputs.hidden_states,
|
|
688
|
+
attentions=outputs.attentions,
|
|
689
|
+
cross_attentions=outputs.cross_attentions,
|
|
690
|
+
)
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
class PPFormulaNet_Head(UniMERNetHead):
|
|
694
|
+
"""
|
|
695
|
+
PPFormulaNet_Head
|
|
696
|
+
Args:
|
|
697
|
+
max_new_tokens (int): Maximum number of new tokens to generate. Default is 1536.
|
|
698
|
+
decoder_start_token_id (int): Start token ID for the decoder. Default is 0.
|
|
699
|
+
temperature (float): Temperature parameter for controlling randomness in sampling. Default is 0.2.
|
|
700
|
+
do_sample (bool): Flag to determine whether to use sampling for generation. Default is False.
|
|
701
|
+
top_p (float): Top-p (nucleus) sampling parameter for controlling diversity. Default is 0.95.
|
|
702
|
+
in_channels (int): Number of input channels for the model. Default is 1024.
|
|
703
|
+
decoder_layers (int): Number of layers in the decoder. Default is 8.
|
|
704
|
+
encoder_hidden_size (int): Size of the hidden layer in the encoder. Default is 1024.
|
|
705
|
+
decoder_ffn_dim (int): Dimension of the feed-forward network in the decoder. Default is 4096.
|
|
706
|
+
decoder_hidden_size (int): Size of the hidden layer in the decoder. Default is 1024.
|
|
707
|
+
is_export (bool): Flag indicating whether the model is to be exported. Default is False.
|
|
708
|
+
length_aware (bool): Flag to determine if the model should be aware of input sequence length. Default is True.
|
|
709
|
+
use_parallel (bool): Flag to enable or disable parallel processing. Default is False.
|
|
710
|
+
parallel_step (int): Number of steps to use in parallel processing. Default is 3.
|
|
711
|
+
"""
|
|
712
|
+
|
|
713
|
+
def __init__(
|
|
714
|
+
self,
|
|
715
|
+
max_new_tokens=1536,
|
|
716
|
+
decoder_start_token_id=0,
|
|
717
|
+
temperature=0.2,
|
|
718
|
+
do_sample=False,
|
|
719
|
+
top_p=0.95,
|
|
720
|
+
in_channels=1024,
|
|
721
|
+
decoder_layers=8,
|
|
722
|
+
encoder_hidden_size=1024,
|
|
723
|
+
decoder_ffn_dim=4096,
|
|
724
|
+
decoder_hidden_size=1024,
|
|
725
|
+
is_export=False,
|
|
726
|
+
length_aware=True,
|
|
727
|
+
use_parallel=False,
|
|
728
|
+
parallel_step=3,
|
|
729
|
+
):
|
|
730
|
+
|
|
731
|
+
super().__init__()
|
|
732
|
+
|
|
733
|
+
mbart_config_dict = {
|
|
734
|
+
"activation_dropout": 0.0,
|
|
735
|
+
"activation_function": "gelu",
|
|
736
|
+
"add_cross_attention": True,
|
|
737
|
+
"add_final_layer_norm": True,
|
|
738
|
+
"attention_dropout": 0.0,
|
|
739
|
+
"bos_token_id": 0,
|
|
740
|
+
"classifier_dropout": 0.0,
|
|
741
|
+
"d_model": decoder_hidden_size,
|
|
742
|
+
"decoder_attention_heads": 16,
|
|
743
|
+
"decoder_ffn_dim": decoder_ffn_dim,
|
|
744
|
+
"decoder_layerdrop": 0.0,
|
|
745
|
+
"decoder_layers": decoder_layers,
|
|
746
|
+
"dropout": 0.1,
|
|
747
|
+
"encoder_attention_heads": 16,
|
|
748
|
+
"encoder_ffn_dim": 4096,
|
|
749
|
+
"encoder_layerdrop": 0.0,
|
|
750
|
+
"encoder_layers": 12,
|
|
751
|
+
"eos_token_id": 2,
|
|
752
|
+
"forced_eos_token_id": 2,
|
|
753
|
+
"init_std": 0.02,
|
|
754
|
+
"is_decoder": True,
|
|
755
|
+
"is_encoder_decoder": False,
|
|
756
|
+
"output_hidden_states": False,
|
|
757
|
+
"max_position_embeddings": (
|
|
758
|
+
max_new_tokens + parallel_step if use_parallel else max_new_tokens
|
|
759
|
+
),
|
|
760
|
+
"model_type": "mbart",
|
|
761
|
+
"num_hidden_layers": 12,
|
|
762
|
+
"pad_token_id": 1,
|
|
763
|
+
"scale_embedding": True,
|
|
764
|
+
"tie_word_embeddings": False,
|
|
765
|
+
"transformers_version": "4.40.0",
|
|
766
|
+
"use_cache": True,
|
|
767
|
+
"use_return_dict": True,
|
|
768
|
+
"vocab_size": 50000,
|
|
769
|
+
"_attn_implementation": "eager",
|
|
770
|
+
"hidden_size": decoder_hidden_size,
|
|
771
|
+
"use_parallel": use_parallel,
|
|
772
|
+
"parallel_step": int(parallel_step),
|
|
773
|
+
"is_export": is_export,
|
|
774
|
+
}
|
|
775
|
+
self.decoder_start_token_id = decoder_start_token_id
|
|
776
|
+
self.temperature = temperature
|
|
777
|
+
self.do_sample = do_sample
|
|
778
|
+
self.top_p = top_p
|
|
779
|
+
self.is_export = is_export
|
|
780
|
+
self.max_seq_len = max_new_tokens
|
|
781
|
+
self.config_decoder = MBartConfig(**mbart_config_dict)
|
|
782
|
+
self.encoder_hidden_size = encoder_hidden_size
|
|
783
|
+
self.decoder = CustomMBartForCausalLM(self.config_decoder)
|
|
784
|
+
if self.config_decoder.hidden_size != self.encoder_hidden_size:
|
|
785
|
+
self.enc_to_dec_proj = nn.Linear(
|
|
786
|
+
self.encoder_hidden_size, self.config_decoder.hidden_size
|
|
787
|
+
)
|
|
788
|
+
generation_config = {
|
|
789
|
+
"max_length": 1537,
|
|
790
|
+
"forced_eos_token_id": 2,
|
|
791
|
+
}
|
|
792
|
+
self.eos_token_id = generation_config["forced_eos_token_id"]
|
|
793
|
+
self.pad_token_id = self.config_decoder.pad_token_id
|
|
794
|
+
self.logits_processor = LogitsProcessorList()
|
|
795
|
+
self.logits_processor.append(
|
|
796
|
+
ForcedEOSTokenLogitsProcessor(
|
|
797
|
+
generation_config["max_length"],
|
|
798
|
+
generation_config["forced_eos_token_id"],
|
|
799
|
+
)
|
|
800
|
+
)
|
|
801
|
+
self.device = torch.device(get_device())
|
|
802
|
+
|
|
803
|
+
def prepare_inputs_for_generation(
|
|
804
|
+
self,
|
|
805
|
+
input_ids,
|
|
806
|
+
past_key_values=None,
|
|
807
|
+
attention_mask=None,
|
|
808
|
+
use_cache=None,
|
|
809
|
+
encoder_outputs=None,
|
|
810
|
+
**kwargs,
|
|
811
|
+
):
|
|
812
|
+
decoder_inputs = self.prepare_inputs_for_generation_mbart(
|
|
813
|
+
input_ids, past_key_values=past_key_values
|
|
814
|
+
)
|
|
815
|
+
decoder_attention_mask = (
|
|
816
|
+
decoder_inputs["attention_mask"]
|
|
817
|
+
if "attention_mask" in decoder_inputs
|
|
818
|
+
else None
|
|
819
|
+
)
|
|
820
|
+
input_dict = {
|
|
821
|
+
"attention_mask": attention_mask,
|
|
822
|
+
"decoder_attention_mask": decoder_attention_mask,
|
|
823
|
+
"decoder_input_ids": decoder_inputs["input_ids"],
|
|
824
|
+
"past_key_values": decoder_inputs["past_key_values"],
|
|
825
|
+
"use_cache": use_cache,
|
|
826
|
+
}
|
|
827
|
+
return input_dict
|
|
828
|
+
|
|
829
|
+
def _extract_past_from_model_output(
|
|
830
|
+
self, outputs: ModelOutput, standardize_cache_format: bool = False
|
|
831
|
+
):
|
|
832
|
+
past_key_values = None
|
|
833
|
+
if "past_key_values" in outputs:
|
|
834
|
+
past_key_values = outputs.past_key_values
|
|
835
|
+
elif "mems" in outputs:
|
|
836
|
+
past_key_values = outputs.mems
|
|
837
|
+
elif "past_buckets_states" in outputs:
|
|
838
|
+
past_key_values = outputs.past_buckets_states
|
|
839
|
+
return past_key_values
|
|
840
|
+
|
|
841
|
+
def _update_model_kwargs_for_generation(
|
|
842
|
+
self,
|
|
843
|
+
outputs: ModelOutput,
|
|
844
|
+
model_kwargs: Dict[str, Any],
|
|
845
|
+
is_encoder_decoder: bool = False,
|
|
846
|
+
standardize_cache_format: bool = False,
|
|
847
|
+
) -> Dict[str, Any]:
|
|
848
|
+
# update past_key_values
|
|
849
|
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
|
850
|
+
outputs, standardize_cache_format=standardize_cache_format
|
|
851
|
+
)
|
|
852
|
+
if getattr(outputs, "state", None) is not None:
|
|
853
|
+
model_kwargs["state"] = outputs.state
|
|
854
|
+
|
|
855
|
+
# update token_type_ids with last value
|
|
856
|
+
if "token_type_ids" in model_kwargs:
|
|
857
|
+
token_type_ids = model_kwargs["token_type_ids"]
|
|
858
|
+
model_kwargs["token_type_ids"] = torch.concat(
|
|
859
|
+
[token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
|
|
860
|
+
)
|
|
861
|
+
|
|
862
|
+
if not is_encoder_decoder:
|
|
863
|
+
# update attention mask
|
|
864
|
+
if "attention_mask" in model_kwargs:
|
|
865
|
+
attention_mask = model_kwargs["attention_mask"]
|
|
866
|
+
model_kwargs["attention_mask"] = torch.concat(
|
|
867
|
+
[
|
|
868
|
+
attention_mask,
|
|
869
|
+
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
|
870
|
+
],
|
|
871
|
+
dim=-1,
|
|
872
|
+
)
|
|
873
|
+
else:
|
|
874
|
+
# update decoder attention mask
|
|
875
|
+
if "decoder_attention_mask" in model_kwargs:
|
|
876
|
+
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
|
877
|
+
model_kwargs["decoder_attention_mask"] = torch.concat(
|
|
878
|
+
[
|
|
879
|
+
decoder_attention_mask,
|
|
880
|
+
decoder_attention_mask.new_ones(
|
|
881
|
+
(decoder_attention_mask.shape[0], 1)
|
|
882
|
+
),
|
|
883
|
+
],
|
|
884
|
+
dim=-1,
|
|
885
|
+
)
|
|
886
|
+
|
|
887
|
+
if (
|
|
888
|
+
"cache_position" in model_kwargs
|
|
889
|
+
and model_kwargs["cache_position"] is not None
|
|
890
|
+
):
|
|
891
|
+
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
|
|
892
|
+
return model_kwargs
|
|
893
|
+
|
|
894
|
+
def stopping_criteria(self, input_ids):
|
|
895
|
+
if self.is_export:
|
|
896
|
+
return input_ids[:, -1].cpu() == torch.Tensor([self.eos_token_id])
|
|
897
|
+
is_done = torch.isin(input_ids[:, -1].cpu(), torch.Tensor([self.eos_token_id]))
|
|
898
|
+
return is_done
|
|
899
|
+
|
|
900
|
+
def stopping_criteria_parallel(self, input_ids):
|
|
901
|
+
parallel_step = self.config_decoder.parallel_step
|
|
902
|
+
|
|
903
|
+
if self.is_export:
|
|
904
|
+
is_done_list = []
|
|
905
|
+
for i in range(parallel_step, 0, -1):
|
|
906
|
+
cur_is_done = input_ids[:, -i] == torch.Tensor([self.eos_token_id])
|
|
907
|
+
is_done_list.append(cur_is_done)
|
|
908
|
+
is_done_list = torch.Tensor(is_done_list).permute([1, 0])
|
|
909
|
+
return is_done_list
|
|
910
|
+
else:
|
|
911
|
+
is_done = torch.isin(
|
|
912
|
+
input_ids[:, -parallel_step:],
|
|
913
|
+
torch.Tensor([self.eos_token_id]).reshape([1, 1]),
|
|
914
|
+
)
|
|
915
|
+
return torch.Tensor(is_done)
|
|
916
|
+
|
|
917
|
+
def generate_single_iter(
|
|
918
|
+
self,
|
|
919
|
+
decoder_input_ids=None,
|
|
920
|
+
decoder_attention_mask=None,
|
|
921
|
+
encoder_outputs=None,
|
|
922
|
+
past_key_values=None,
|
|
923
|
+
decoder_inputs_embeds=None,
|
|
924
|
+
labels=None,
|
|
925
|
+
use_cache=None,
|
|
926
|
+
output_attentions=None,
|
|
927
|
+
output_hidden_states=None,
|
|
928
|
+
return_dict=None,
|
|
929
|
+
**kwargs,
|
|
930
|
+
):
|
|
931
|
+
|
|
932
|
+
encoder_hidden_states = encoder_outputs[0]
|
|
933
|
+
if self.config_decoder.hidden_size != self.encoder_hidden_size:
|
|
934
|
+
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
|
935
|
+
kwargs_decoder = {}
|
|
936
|
+
decoder_outputs = self.decoder(
|
|
937
|
+
input_ids=decoder_input_ids,
|
|
938
|
+
attention_mask=decoder_attention_mask,
|
|
939
|
+
encoder_hidden_states=encoder_hidden_states,
|
|
940
|
+
encoder_attention_mask=None,
|
|
941
|
+
inputs_embeds=None,
|
|
942
|
+
output_attentions=False,
|
|
943
|
+
output_hidden_states=output_hidden_states,
|
|
944
|
+
use_cache=use_cache,
|
|
945
|
+
past_key_values=past_key_values,
|
|
946
|
+
return_dict=return_dict,
|
|
947
|
+
**kwargs_decoder,
|
|
948
|
+
)
|
|
949
|
+
|
|
950
|
+
return Seq2SeqLMOutput(
|
|
951
|
+
loss=None,
|
|
952
|
+
logits=decoder_outputs.logits,
|
|
953
|
+
past_key_values=decoder_outputs.past_key_values,
|
|
954
|
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
|
955
|
+
decoder_attentions=decoder_outputs.attentions,
|
|
956
|
+
cross_attentions=decoder_outputs.cross_attentions,
|
|
957
|
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
|
958
|
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
|
959
|
+
encoder_attentions=encoder_outputs.attentions,
|
|
960
|
+
)
|
|
961
|
+
|
|
962
|
+
def _prepare_decoder_input_ids_for_generation(
|
|
963
|
+
self,
|
|
964
|
+
batch_size,
|
|
965
|
+
model_kwargs,
|
|
966
|
+
decoder_start_token_id=None,
|
|
967
|
+
bos_token_id=None,
|
|
968
|
+
):
|
|
969
|
+
|
|
970
|
+
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
|
|
971
|
+
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
|
|
972
|
+
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
|
|
973
|
+
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
|
|
974
|
+
elif "input_ids" in model_kwargs:
|
|
975
|
+
decoder_input_ids = model_kwargs.pop("input_ids")
|
|
976
|
+
else:
|
|
977
|
+
decoder_input_ids = None
|
|
978
|
+
|
|
979
|
+
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
|
|
980
|
+
decoder_start_token_id = self._get_decoder_start_token_id(
|
|
981
|
+
decoder_start_token_id, bos_token_id
|
|
982
|
+
)
|
|
983
|
+
|
|
984
|
+
if isinstance(decoder_start_token_id, list):
|
|
985
|
+
if len(decoder_start_token_id) != batch_size:
|
|
986
|
+
raise ValueError(
|
|
987
|
+
f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
|
|
988
|
+
)
|
|
989
|
+
decoder_input_ids_start = torch.Tensor(
|
|
990
|
+
decoder_start_token_id
|
|
991
|
+
).to(torch.int64)
|
|
992
|
+
decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
|
|
993
|
+
else:
|
|
994
|
+
use_parallel = self.config_decoder.use_parallel
|
|
995
|
+
parallel_step = self.config_decoder.parallel_step
|
|
996
|
+
|
|
997
|
+
if use_parallel:
|
|
998
|
+
decoder_input_ids_start = (
|
|
999
|
+
torch.ones(
|
|
1000
|
+
(batch_size, parallel_step),
|
|
1001
|
+
dtype=torch.int64,
|
|
1002
|
+
device=self.device,
|
|
1003
|
+
)
|
|
1004
|
+
* decoder_start_token_id
|
|
1005
|
+
)
|
|
1006
|
+
else:
|
|
1007
|
+
decoder_input_ids_start = (
|
|
1008
|
+
torch.ones(
|
|
1009
|
+
(batch_size, 1),
|
|
1010
|
+
dtype=torch.int64,
|
|
1011
|
+
device=self.device,
|
|
1012
|
+
)
|
|
1013
|
+
* decoder_start_token_id
|
|
1014
|
+
)
|
|
1015
|
+
# no user input -> use decoder_start_token_id as decoder_input_ids
|
|
1016
|
+
if decoder_input_ids is None:
|
|
1017
|
+
decoder_input_ids = decoder_input_ids_start
|
|
1018
|
+
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
|
|
1019
|
+
elif (
|
|
1020
|
+
self.config.model_type == "vision-encoder-decoder"
|
|
1021
|
+
and "donut" in self.name_or_path.lower()
|
|
1022
|
+
):
|
|
1023
|
+
pass
|
|
1024
|
+
elif self.config.model_type in ["whisper"]:
|
|
1025
|
+
pass
|
|
1026
|
+
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
|
|
1027
|
+
# decoder_attention_mask if provided)
|
|
1028
|
+
elif (
|
|
1029
|
+
isinstance(decoder_start_token_id, int)
|
|
1030
|
+
and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
|
|
1031
|
+
) or (
|
|
1032
|
+
isinstance(decoder_start_token_id, torch.Tensor)
|
|
1033
|
+
and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
|
|
1034
|
+
):
|
|
1035
|
+
decoder_input_ids = torch.concat(
|
|
1036
|
+
[decoder_input_ids_start, decoder_input_ids], dim=-1
|
|
1037
|
+
)
|
|
1038
|
+
if "decoder_attention_mask" in model_kwargs:
|
|
1039
|
+
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
|
1040
|
+
decoder_attention_mask = torch.cat(
|
|
1041
|
+
(
|
|
1042
|
+
torch.ones_like(decoder_attention_mask)[:, :1],
|
|
1043
|
+
decoder_attention_mask,
|
|
1044
|
+
),
|
|
1045
|
+
dim=-1,
|
|
1046
|
+
)
|
|
1047
|
+
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
|
|
1048
|
+
|
|
1049
|
+
return decoder_input_ids, model_kwargs
|
|
1050
|
+
|
|
1051
|
+
@torch.no_grad()
|
|
1052
|
+
def generate_export(
|
|
1053
|
+
self,
|
|
1054
|
+
encoder_outputs,
|
|
1055
|
+
model_kwargs,
|
|
1056
|
+
):
|
|
1057
|
+
use_parallel = self.config_decoder.use_parallel
|
|
1058
|
+
parallel_step = self.config_decoder.parallel_step
|
|
1059
|
+
batch_size = encoder_outputs["last_hidden_state"].shape[0]
|
|
1060
|
+
generation_config = {
|
|
1061
|
+
"decoder_start_token_id": 0,
|
|
1062
|
+
"bos_token_id": 0,
|
|
1063
|
+
}
|
|
1064
|
+
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
|
|
1065
|
+
batch_size=batch_size,
|
|
1066
|
+
model_kwargs=model_kwargs,
|
|
1067
|
+
decoder_start_token_id=generation_config["decoder_start_token_id"],
|
|
1068
|
+
bos_token_id=generation_config["bos_token_id"],
|
|
1069
|
+
)
|
|
1070
|
+
if not use_parallel:
|
|
1071
|
+
input_ids = input_ids.reshape([-1, 1])
|
|
1072
|
+
decoder_input_ids = input_ids
|
|
1073
|
+
model_kwargs["key use_cache"] = True
|
|
1074
|
+
batch_size, cur_len = input_ids.shape
|
|
1075
|
+
|
|
1076
|
+
if "inputs_embeds" in model_kwargs:
|
|
1077
|
+
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
|
1078
|
+
|
|
1079
|
+
cache_position = torch.arange(cur_len)
|
|
1080
|
+
pad_token_id = self.pad_token_id
|
|
1081
|
+
eos_token_id = [self.eos_token_id]
|
|
1082
|
+
eos_token = self.eos_token_id
|
|
1083
|
+
if use_parallel:
|
|
1084
|
+
unfinished_sequences = torch.ones(
|
|
1085
|
+
[batch_size, parallel_step], dtype=torch.int64, device=self.device
|
|
1086
|
+
)
|
|
1087
|
+
parallel_length = math.ceil(self.max_seq_len // parallel_step)
|
|
1088
|
+
else:
|
|
1089
|
+
unfinished_sequences = torch.ones(batch_size, dtype=torch.int64, device=self.device)
|
|
1090
|
+
parallel_length = self.max_seq_len
|
|
1091
|
+
|
|
1092
|
+
i_idx = 0
|
|
1093
|
+
past_key_values = []
|
|
1094
|
+
decoder_attention_heads = self.config_decoder.decoder_attention_heads
|
|
1095
|
+
decoder_attention_heads_dim = int(
|
|
1096
|
+
self.config_decoder.d_model / decoder_attention_heads
|
|
1097
|
+
)
|
|
1098
|
+
for i in range(self.config_decoder.decoder_layers):
|
|
1099
|
+
init_arr = torch.zeros(
|
|
1100
|
+
[batch_size, decoder_attention_heads, 0, decoder_attention_heads_dim]
|
|
1101
|
+
)
|
|
1102
|
+
cache = (init_arr, init_arr, init_arr, init_arr)
|
|
1103
|
+
past_key_values.append(cache)
|
|
1104
|
+
|
|
1105
|
+
while i_idx < parallel_length:
|
|
1106
|
+
|
|
1107
|
+
model_inputs = self.prepare_inputs_for_generation_export(
|
|
1108
|
+
past_key_values=past_key_values, **model_kwargs
|
|
1109
|
+
)
|
|
1110
|
+
decoder_attention_mask = torch.ones(input_ids.shape, device=self.device)
|
|
1111
|
+
|
|
1112
|
+
outputs = self.generate_single_iter(
|
|
1113
|
+
decoder_input_ids=decoder_input_ids,
|
|
1114
|
+
decoder_attention_mask=decoder_attention_mask,
|
|
1115
|
+
encoder_outputs=encoder_outputs,
|
|
1116
|
+
past_key_values=past_key_values,
|
|
1117
|
+
return_dict=True,
|
|
1118
|
+
output_attentions=False,
|
|
1119
|
+
output_hidden_states=False,
|
|
1120
|
+
)
|
|
1121
|
+
|
|
1122
|
+
if use_parallel:
|
|
1123
|
+
next_token_logits = outputs.logits[:, -parallel_step:, :]
|
|
1124
|
+
else:
|
|
1125
|
+
next_token_logits = outputs.logits[:, -1, :]
|
|
1126
|
+
next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
|
|
1127
|
+
next_tokens = torch.argmax(next_tokens_scores, dim=-1)
|
|
1128
|
+
|
|
1129
|
+
if eos_token_id is not None:
|
|
1130
|
+
# False
|
|
1131
|
+
if pad_token_id is None:
|
|
1132
|
+
raise ValueError(
|
|
1133
|
+
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
|
1134
|
+
)
|
|
1135
|
+
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
|
|
1136
|
+
1 - unfinished_sequences
|
|
1137
|
+
)
|
|
1138
|
+
if use_parallel:
|
|
1139
|
+
input_ids = torch.concat([input_ids, next_tokens], dim=-1)
|
|
1140
|
+
decoder_input_ids = next_tokens
|
|
1141
|
+
else:
|
|
1142
|
+
input_ids = torch.concat(
|
|
1143
|
+
[input_ids, next_tokens.unsqueeze(1)], dim=-1
|
|
1144
|
+
)
|
|
1145
|
+
decoder_input_ids = next_tokens.unsqueeze(1)
|
|
1146
|
+
|
|
1147
|
+
past_length = past_key_values[0][0].shape[2]
|
|
1148
|
+
|
|
1149
|
+
past_key_values = outputs.past_key_values
|
|
1150
|
+
cache_position = cache_position[-1:] + 1
|
|
1151
|
+
if use_parallel:
|
|
1152
|
+
unfinished_sequences = (
|
|
1153
|
+
unfinished_sequences
|
|
1154
|
+
& ~self.stopping_criteria_parallel(input_ids).to(torch.int64).to(self.device)
|
|
1155
|
+
)
|
|
1156
|
+
else:
|
|
1157
|
+
unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
|
|
1158
|
+
input_ids
|
|
1159
|
+
).to(torch.int64).to(self.device)
|
|
1160
|
+
|
|
1161
|
+
if (
|
|
1162
|
+
eos_token is not None
|
|
1163
|
+
and (
|
|
1164
|
+
torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
|
|
1165
|
+
>= 1
|
|
1166
|
+
).all()
|
|
1167
|
+
):
|
|
1168
|
+
break
|
|
1169
|
+
i_idx += 1
|
|
1170
|
+
# break
|
|
1171
|
+
|
|
1172
|
+
return input_ids
|
|
1173
|
+
|
|
1174
|
+
+    @torch.no_grad()
+    def generate(
+        self,
+        encoder_outputs,
+        model_kwargs,
+    ):
+        """
+        Generate sequences from the model without computing gradients.
+
+        This method is used to generate sequences from the model based on the given encoder outputs.
+        It does not compute gradients, making it suitable for inference.
+
+        Args:
+            encoder_outputs: The outputs from the encoder, typically including hidden states necessary for generation.
+            model_kwargs: Additional keyword arguments that may include parameters such as maximum length,
+                temperature, top-k/top-p sampling parameters, and other generation-specific settings.
+
+        Returns:
+            Generated sequences based on the encoder outputs and specified generation parameters.
+        """
+        use_parallel = self.config_decoder.use_parallel
+        parallel_step = self.config_decoder.parallel_step
+        batch_size = encoder_outputs["last_hidden_state"].shape[0]
+        generation_config = {
+            "decoder_start_token_id": 0,
+            "bos_token_id": 0,
+        }
+
+        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+            batch_size=batch_size,
+            model_kwargs=model_kwargs,
+            decoder_start_token_id=generation_config["decoder_start_token_id"],
+            bos_token_id=generation_config["bos_token_id"],
+        )
+
+        decoder_input_ids = input_ids
+        model_kwargs["key use_cache"] = True
+        batch_size, cur_len = input_ids.shape
+
+        if "inputs_embeds" in model_kwargs:
+            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        model_kwargs["cache_position"] = torch.arange(cur_len)
+        pad_token_id = self.pad_token_id
+        eos_token_id = [self.eos_token_id]
+        eos_token = self.eos_token_id
+        if use_parallel:
+            unfinished_sequences = torch.ones(
+                [batch_size, parallel_step], dtype=torch.int64
+            )
+            parallel_length = math.ceil(self.max_seq_len // parallel_step)
+        else:
+            unfinished_sequences = torch.ones(batch_size, dtype=torch.int64)
+            parallel_length = self.max_seq_len
+        past_key_values = []
+
+        for idx in range(parallel_length):
+
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            outputs = self.generate_single_iter(
+                **model_inputs,
+                encoder_outputs=encoder_outputs,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+            )
+
+            if use_parallel:
+                next_token_logits = outputs.logits[:, :, :]
+            else:
+                next_token_logits = outputs.logits[:, -1, :]
+
+            next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
+            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+            if eos_token_id is not None:
+                # False
+                if pad_token_id is None:
+                    raise ValueError(
+                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+                    )
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )
+            if use_parallel:
+                input_ids = torch.concat([input_ids, next_tokens], dim=-1)
+            else:
+                input_ids = torch.concat([input_ids, next_tokens[:, None]], dim=-1)
+
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config_decoder.is_encoder_decoder,
+            )
+            if use_parallel:
+                unfinished_sequences = (
+                    unfinished_sequences
+                    & ~self.stopping_criteria_parallel(input_ids).to(torch.int64)
+                )
+            else:
+                unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
+                    input_ids
+                ).to(torch.int64)
+
+            if (
+                eos_token is not None
+                and (
+                    torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
+                    >= 1
+                ).all()
+            ):
+                break
+        return input_ids
+
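The stopping test at the end of each decoding iteration above reads the cumulative EOS count at the last position of every row, so the loop breaks only once every sequence in the batch contains at least one EOS token. A small, runnable illustration with hypothetical token ids (EOS assumed to be 2 purely for this example):

```python
import torch

eos_token = 2  # hypothetical EOS id for illustration only
input_ids = torch.tensor([[5, 7, 2, 0],
                          [4, 9, 8, 2]])

# The loop's stop test: the running count of EOS tokens, read at the last position,
# is >= 1 for every row, i.e. each sequence already contains at least one EOS.
seen_eos = torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1] >= 1
print(seen_eos)        # tensor([True, True])
print(seen_eos.all())  # tensor(True) -> generation would break here
```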
+    def forwad_train(
+        self,
+        encoder_outputs,
+        decoder_input_ids,
+        decoder_attention_mask,
+        past_key_values=None,
+        decoder_inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        **kwargs,
+    ):
+        """
+        Forward pass for training the model.
+
+        Args:
+            encoder_outputs: The outputs from the encoder, typically including hidden states.
+            decoder_input_ids: Input IDs for the decoder.
+            decoder_attention_mask: Attention mask for the decoder inputs to avoid attending to padding tokens.
+            past_key_values: Previously computed key and value states for the decoder, used for fast generation.
+            decoder_inputs_embeds: Optional embeddings for decoder inputs, used instead of decoder_input_ids if provided.
+            labels: Labels for computing the training loss.
+            use_cache: Whether to use a cache of past key values for faster generation.
+            output_attentions: Whether to output attention weights.
+            output_hidden_states: Whether to output hidden states of all layers.
+            return_dict: Whether to return the output as a dictionary.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            Depending on the `return_dict` flag, returns either a dictionary of model outputs or a tuple.
+        """
+        if self.config_decoder.use_parallel:
+            batch = decoder_input_ids.shape[0]
+            add_sos_token = self.config_decoder.parallel_step - 1
+            start_token = torch.zeros([batch, add_sos_token]).to(torch.int64)
+            start_mask = torch.ones([batch, add_sos_token]).to(torch.int64)
+            decoder_input_ids = torch.concat([start_token, decoder_input_ids], dim=1)
+            decoder_attention_mask = torch.concat(
+                [start_mask, decoder_attention_mask], dim=1
+            )
+
+        labels = decoder_input_ids * 1
+        labels = labels.masked_fill_(labels == self.pad_token_id, -100)
+        if self.config_decoder.use_parallel:
+            input_decoder_input_ids = decoder_input_ids[
+                :, : -self.config_decoder.parallel_step
+            ]
+            input_decoder_attention_mask = decoder_attention_mask[
+                :, : -self.config_decoder.parallel_step
+            ]
+        else:
+            input_decoder_input_ids = decoder_input_ids[:, :-1]
+            input_decoder_attention_mask = decoder_attention_mask[:, :-1]
+
+        encoder_hidden_states = encoder_outputs[0]
+        kwargs_decoder = {}
+        if self.config_decoder.hidden_size != self.encoder_hidden_size:
+            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+        decoder_outputs = self.decoder(
+            input_ids=input_decoder_input_ids,
+            attention_mask=input_decoder_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=None,
+            inputs_embeds=None,
+            output_attentions=False,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            return_dict=return_dict,
+            **kwargs_decoder,
+        )
+
+        logits = decoder_outputs.logits
+        return logits, labels
+
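In the training path above (`forwad_train`), labels are a pad-masked copy of the target ids and the decoder input is the target sequence truncated by one token (or by `parallel_step` tokens in parallel mode); the method returns the logits/labels pair for the loss to consume. A minimal sketch of the non-parallel preparation, using a hypothetical `pad_token_id` and toy tensors; any label shifting is assumed to happen in the loss, since it is not done here:

```python
import torch

pad_token_id = 0  # hypothetical pad id for illustration
decoder_input_ids = torch.tensor([[1, 5, 7, 2, 0],
                                  [1, 4, 2, 0, 0]])

# Labels: a copy of the targets with padding masked to -100 (ignored by cross-entropy).
labels = decoder_input_ids.clone()
labels = labels.masked_fill_(labels == pad_token_id, -100)

# Non-parallel teacher forcing: the decoder is fed everything except the last token.
input_decoder_input_ids = decoder_input_ids[:, :-1]
print(input_decoder_input_ids)
print(labels)
```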
+    # forward for export
+    def forward(self, inputs, targets=None):
+        self.is_export = False if self.training else True
+        if not self.training:
+            encoder_outputs = inputs
+            model_kwargs = {
+                "output_attentions": False,
+                "output_hidden_states": False,
+                "use_cache": True,
+            }
+            if self.is_export:
+                word_pred = self.generate_export(encoder_outputs, model_kwargs)
+            else:
+                word_pred = self.generate(encoder_outputs, model_kwargs)
+
+            return word_pred
+        encoder_outputs, tgt_seq, mask = inputs
+        logits, masked_labels = self.forwad_train(encoder_outputs, tgt_seq, mask)
+
+        return logits, masked_labels
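`forward` above dispatches on `self.training`: in eval/export mode it takes encoder outputs and returns decoded token ids, while in training mode it unpacks `(encoder_outputs, tgt_seq, mask)` and returns `(logits, masked_labels)`. A toy stand-in module (not the package's class) showing the same two-contract dispatch:

```python
import torch
from torch import nn


class TinyHead(nn.Module):
    """Toy stand-in mirroring the eval/train dispatch of the head's forward()."""

    def forward(self, inputs, targets=None):
        self.is_export = False if self.training else True
        if not self.training:
            encoder_outputs = inputs
            # stand-in for generate_export()/generate(): return fixed token ids
            return torch.zeros(encoder_outputs.shape[0], 4, dtype=torch.int64)
        encoder_outputs, tgt_seq, mask = inputs
        # stand-in for forwad_train(): random logits plus pad-masked labels
        logits = torch.randn(tgt_seq.shape[0], tgt_seq.shape[1], 10)
        labels = tgt_seq.masked_fill(mask == 0, -100)
        return logits, labels


head = TinyHead()
head.eval()
print(head(torch.randn(2, 8)).shape)   # torch.Size([2, 4])

head.train()
logits, labels = head((torch.randn(2, 8), torch.tensor([[1, 2, 0]]), torch.tensor([[1, 1, 0]])))
print(logits.shape, labels)            # torch.Size([1, 3, 10]) tensor([[1, 2, -100]])
```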