mineru 2.5.3__py3-none-any.whl → 2.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mineru/backend/pipeline/model_init.py +25 -3
- mineru/backend/pipeline/model_json_to_middle_json.py +2 -2
- mineru/backend/pipeline/model_list.py +0 -1
- mineru/backend/utils.py +24 -0
- mineru/backend/vlm/model_output_to_middle_json.py +2 -2
- mineru/backend/vlm/{custom_logits_processors.py → utils.py} +36 -2
- mineru/backend/vlm/vlm_analyze.py +43 -50
- mineru/backend/vlm/vlm_magic_model.py +155 -1
- mineru/cli/common.py +26 -23
- mineru/cli/fast_api.py +2 -8
- mineru/cli/gradio_app.py +104 -13
- mineru/cli/models_download.py +1 -0
- mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py +152 -0
- mineru/model/mfr/pp_formulanet_plus_m/processors.py +657 -0
- mineru/model/mfr/unimernet/unimernet_hf/modeling_unimernet.py +1 -326
- mineru/model/mfr/utils.py +338 -0
- mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py +103 -16
- mineru/model/table/rec/unet_table/main.py +1 -1
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/imaug/operators.py +5 -5
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/__init__.py +2 -1
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_lcnetv3.py +7 -7
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_pphgnetv2.py +2 -2
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/__init__.py +2 -0
- mineru/model/utils/pytorchocr/modeling/heads/rec_ppformulanet_head.py +1383 -0
- mineru/model/utils/pytorchocr/modeling/heads/rec_unimernet_head.py +2631 -0
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/rec_postprocess.py +25 -28
- mineru/model/utils/pytorchocr/utils/__init__.py +0 -0
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/arch_config.yaml +130 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_arabic_dict.txt +747 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_cyrillic_dict.txt +850 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_devanagari_dict.txt +568 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_ta_dict.txt +513 -0
- mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_te_dict.txt +540 -0
- mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/models_config.yml +15 -15
- mineru/model/utils/pytorchocr/utils/resources/pp_formulanet_arch_config.yaml +24 -0
- mineru/model/utils/tools/infer/__init__.py +1 -0
- mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_det.py +6 -3
- mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_rec.py +16 -25
- mineru/model/vlm_vllm_model/server.py +4 -1
- mineru/resources/header.html +2 -2
- mineru/utils/enum_class.py +1 -0
- mineru/utils/guess_suffix_or_lang.py +9 -1
- mineru/utils/llm_aided.py +4 -2
- mineru/utils/ocr_utils.py +16 -0
- mineru/utils/table_merge.py +102 -13
- mineru/version.py +1 -1
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/METADATA +33 -6
- mineru-2.6.0.dist-info/RECORD +195 -0
- mineru-2.5.3.dist-info/RECORD +0 -181
- /mineru/model/{ocr/paddleocr2pytorch/pytorchocr → mfr/pp_formulanet_plus_m}/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch/tools/infer → utils}/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch/pytorchocr/modeling → utils/pytorchocr}/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/base_ocr_v20.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/imaug/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch/pytorchocr/utils → utils/pytorchocr/modeling}/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/architectures/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/architectures/base_model.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/det_mobilenet_v3.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_donut_swin.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_hgnet.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_mv1_enhance.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_svtrnet.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/common.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/cls_head.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/det_db_head.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/rec_ctc_head.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/rec_multi_head.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/db_fpn.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/intracl.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/rnn.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/cls_postprocess.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/db_postprocess.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/arabic_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/cyrillic_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/devanagari_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/en_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/japan_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ka_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/korean_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/latin_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv4_doc_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_el_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_en_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_eslav_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_korean_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_latin_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_th_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ta_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/te_dict.txt +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/__init__.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_cls.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_system.py +0 -0
- /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/pytorchocr_utility.py +0 -0
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/WHEEL +0 -0
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/entry_points.txt +0 -0
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/licenses/LICENSE.md +0 -0
- {mineru-2.5.3.dist-info → mineru-2.6.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,2631 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import math
|
|
3
|
+
import re
|
|
4
|
+
import numpy as np
|
|
5
|
+
import inspect
|
|
6
|
+
import warnings
|
|
7
|
+
from collections import OrderedDict
|
|
8
|
+
from typing import Optional, Tuple, Union, List, Dict, Any
|
|
9
|
+
from dataclasses import dataclass, fields, is_dataclass
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn as nn
|
|
13
|
+
from torch import Tensor
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
from torch.nn import CrossEntropyLoss
|
|
16
|
+
|
|
17
|
+
from mineru.utils.config_reader import get_device
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ModelOutput(OrderedDict):
    """Ordered-dict container for model outputs that also exposes attribute access.

    Every item written through ``__setitem__`` (including the keyword arguments
    given to the constructor) is stored both as a dict entry and as an
    instance attribute, so ``out["logits"]`` and ``out.logits`` agree.
    Non-string keys index positionally into ``to_tuple()``, e.g. ``out[0]``
    or ``out[1:]``.

    ``__delitem__``, ``setdefault``, ``pop`` and ``update`` are disabled so the
    dict view and the attribute view cannot drift apart.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __post_init__(self):
        """Populate the mapping from dataclass fields.

        Runs only when called explicitly or from a dataclass-generated
        ``__init__``; ``ModelOutput.__init__`` itself does not call it.

        Raises:
            ValueError: if there are no fields, if any field after the first
                has a non-``None`` default, or if the first field is an
                iterable whose later elements are not ``(str, value)`` pairs.
        """
        class_fields = fields(self)

        if not len(class_fields):
            raise ValueError(f"{self.__class__.__name__} has no fields.")
        if not all(field.default is None for field in class_fields[1:]):
            raise ValueError(
                f"{self.__class__.__name__} should not have more than one required field."
            )

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(
            getattr(self, field.name) is None for field in class_fields[1:]
        )
        if other_fields_are_none:
            # The first field may itself carry (key, value) pairs (a dict or an
            # iterable of pairs); if so, unpack them into the mapping.
            if isinstance(first_field, dict):
                iterator = first_field.items()
                first_field_iterator = True
            else:
                try:
                    iterator = iter(first_field)
                    first_field_iterator = True
                except TypeError:
                    first_field_iterator = False

            if first_field_iterator:
                for idx, element in enumerate(iterator):
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        if idx == 0:
                            # Not a pair sequence: store the value as-is.
                            self[class_fields[0].name] = first_field
                        else:
                            raise ValueError(
                                f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
                            )
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                self[class_fields[0].name] = first_field
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        raise Exception(
            f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
        )

    def setdefault(self, *args, **kwargs):
        raise Exception(
            f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
        )

    def pop(self, *args, **kwargs):
        raise Exception(
            f"You cannot use ``pop`` on a {self.__class__.__name__} instance."
        )

    def update(self, *args, **kwargs):
        raise Exception(
            f"You cannot use ``update`` on a {self.__class__.__name__} instance."
        )

    def __getitem__(self, k):
        """Look up by name (``str``) or by position/slice over ``to_tuple()``.

        Raises:
            KeyError: for an unknown string key.
            IndexError: for an out-of-range integer index.
        """
        if isinstance(k, str):
            # Direct O(1) lookup. The previous ``dict(self.items())[k]`` copied
            # the whole mapping per access, making ``to_tuple`` quadratic.
            return super().__getitem__(k)
        return self.to_tuple()[k]

    def __setattr__(self, name, value):
        # Keep an existing dict entry in sync when its attribute is reassigned
        # to a non-None value.
        if name in self.keys() and value is not None:
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Mirror every item as an attribute.
        super().__setitem__(key, value)
        super().__setattr__(key, value)

    def __reduce__(self):
        if not is_dataclass(self):
            return super().__reduce__()
        rebuild, _args, *remaining = super().__reduce__()
        args = tuple(getattr(self, field.name) for field in fields(self))
        return rebuild, args, *remaining

    def to_tuple(self):
        """Return the stored values in insertion order."""
        return tuple(self[k] for k in self.keys())
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@dataclass
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
    """Output container holding a last hidden state plus optional extras.

    The names below are plain, un-annotated class attributes rather than
    dataclass fields, so ``dataclasses.fields()`` on this class is empty.
    Reading an entry that was never supplied to the constructor therefore
    falls back to the class-level ``None``.
    """

    # Class-level fall-backs returned for entries that were never set.
    last_hidden_state = None
    past_key_values = None
    hidden_states = None
    attentions = None
    cross_attentions = None

    def __init__(self, *args, **kwargs):
        # Defining __init__ explicitly stops @dataclass from generating one
        # (which, with no fields, would accept no arguments at all).
        super().__init__(*args, **kwargs)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@dataclass
class Seq2SeqLMOutput(ModelOutput):
    """Output container for a sequence-to-sequence language-model step.

    As with the other outputs in this module, the names below are
    un-annotated class attributes, not dataclass fields. Any entry not passed
    to the constructor reads back as the class-level ``None``.
    """

    # Class-level fall-backs returned for entries that were never set.
    loss = None
    logits = None
    past_key_values = None
    decoder_hidden_states = None
    decoder_attentions = None
    cross_attentions = None
    encoder_last_hidden_state = None
    encoder_hidden_states = None
    encoder_attentions = None

    def __init__(self, *args, **kwargs):
        # Explicit __init__ keeps @dataclass from replacing it with a
        # field-based (here argument-less) constructor.
        super().__init__(*args, **kwargs)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class MBartConfig(object):
    """Plain hyper-parameter holder for the MBart-style decoder.

    Each constructor argument is stored on an attribute of the same name.
    ``num_hidden_layers`` additionally mirrors ``encoder_layers``. Extra
    keyword arguments are accepted and discarded.
    """

    model_type = "mbart"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_attention_heads": "encoder_attention_heads",
        "hidden_size": "d_model",
    }

    def __init__(
        self,
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        output_hidden_states=False,
        use_return_dict=True,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        forced_eos_token_id=2,
        _attn_implementation="eager",
        hidden_size=1024,
        use_parallel=False,
        parallel_step=2,
        is_export=False,
        **kwargs,
    ):
        # Vocabulary and model width.
        self.vocab_size, self.hidden_size = vocab_size, hidden_size
        self.max_position_embeddings, self.d_model = max_position_embeddings, d_model

        # Encoder geometry.
        self.encoder_ffn_dim, self.encoder_layers, self.encoder_attention_heads = (
            encoder_ffn_dim,
            encoder_layers,
            encoder_attention_heads,
        )

        # Decoder geometry.
        self.decoder_ffn_dim, self.decoder_layers, self.decoder_attention_heads = (
            decoder_ffn_dim,
            decoder_layers,
            decoder_attention_heads,
        )

        # Regularisation, activation and output switches.
        self.dropout, self.output_hidden_states = dropout, output_hidden_states
        self.use_return_dict = use_return_dict
        self.attention_dropout, self.activation_dropout = attention_dropout, activation_dropout
        self.activation_function, self.init_std = activation_function, init_std
        self.encoder_layerdrop, self.decoder_layerdrop = encoder_layerdrop, decoder_layerdrop
        self.classifier_dropout, self.use_cache = classifier_dropout, use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        # Special token ids.
        self.pad_token_id, self.bos_token_id, self.eos_token_id = (
            pad_token_id,
            bos_token_id,
            eos_token_id,
        )
        self.is_encoder_decoder = is_encoder_decoder
        self.forced_eos_token_id = forced_eos_token_id

        # Implementation / runtime switches.
        self._attn_implementation = _attn_implementation
        self.use_parallel, self.parallel_step = use_parallel, parallel_step
        self.is_export = is_export
        super().__init__()
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
@dataclass
class AttentionMaskConverter:
    """
    A utility class for converting attention masks used in transformer models.

    This class handles the conversion of attention masks based on whether the
    attention mechanism is causal (i.e., preventing information flow from future
    tokens to past tokens) and whether a sliding window approach is used.

    All masks produced here are additive: 0 where attention is allowed and
    ``torch.finfo(dtype).min`` where it is blocked.

    Attributes:
        is_causal (bool): Indicates if the attention mechanism is causal.
        sliding_window (Optional[int]): Specifies the size of the sliding window
            for local attention, if applicable. It is validated but not yet
            applied when building causal masks.

    Args:
        is_causal (bool): Determines if the attention mask should enforce causality.
        sliding_window (Optional[int], optional): The size of the sliding window
            for local attention. Default is None.

    Raises:
        ValueError: if ``sliding_window`` is given and is not strictly positive.
    """

    is_causal: bool
    sliding_window: int

    def __init__(self, is_causal: bool, sliding_window=None):
        self.is_causal = is_causal
        self.sliding_window = sliding_window

        if self.sliding_window is not None and self.sliding_window <= 0:
            raise ValueError(
                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
            )

    @staticmethod
    def _make_causal_mask(
        input_ids_shape,
        dtype,
        past_key_values_length=0,
        sliding_window=None,
        is_export=False,
        device=None,
    ):
        """Build an additive causal mask of shape (bsz, 1, tgt_len, tgt_len + past).

        Args:
            input_ids_shape: ``(bsz, tgt_len)``.
            dtype: dtype whose ``finfo(dtype).min`` marks blocked positions. The
                mask is built in ``dtype``, except in export mode, where it is
                built in float64.
            past_key_values_length: number of cached key positions. They are
                prepended as fully visible columns.
            sliding_window: accepted for interface compatibility; not applied.
            is_export: build the mask in float64.
            device: device on which to create the mask (default: PyTorch default).

        Returns:
            An expanded (non-contiguous) view; do not write to it in place.
        """
        bsz, tgt_len = input_ids_shape
        min_value = torch.finfo(dtype).min
        if is_export:
            mask = torch.full(
                [tgt_len, tgt_len], fill_value=min_value, dtype=torch.float64, device=device
            )
        else:
            # Build directly in `dtype` so the mask matches the attention
            # scores (a float32 mask would up-cast fp16/bf16 scores).
            mask = torch.full((tgt_len, tgt_len), min_value, dtype=dtype, device=device)
        mask_cond = torch.arange(mask.shape[-1], device=mask.device)
        # Query i may attend to current keys j <= i (lower triangle incl. diagonal).
        mask = mask.masked_fill_(
            mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
        )
        if past_key_values_length > 0:
            # Cached keys precede the current ones and are always visible.
            # Without these columns the expand below fails when tgt_len > 1.
            past = torch.zeros(
                [tgt_len, past_key_values_length], dtype=mask.dtype, device=mask.device
            )
            mask = torch.cat([past, mask], dim=-1)
        return mask[None, None, :, :].expand(
            [bsz, 1, tgt_len, tgt_len + past_key_values_length]
        )

    def to_4d_export(
        self,
        attention_mask_2d,
        query_length,
        dtype,
        key_value_length,
        is_export=False,
    ):
        """Export-path conversion: expand the 2D padding mask only (no causal part).

        ``key_value_length`` and ``is_export`` are accepted for signature
        parity with ``to_4d`` but are not used.
        """
        input_shape = (attention_mask_2d.shape[0], query_length)
        expanded_attn_mask = self._expand_mask(
            attention_mask_2d, dtype, tgt_len=input_shape[-1]
        )
        expanded_4d_mask = expanded_attn_mask

        return expanded_4d_mask

    def to_4d(
        self,
        attention_mask_2d,
        query_length,
        dtype,
        key_value_length,
        is_export=False,
    ):
        """Combine a 2D padding mask with a causal mask into a 4D additive mask.

        Args:
            attention_mask_2d: padding mask of shape (bsz, key_value_length);
                1 marks keys that may be attended to, 0 marks padding.
            query_length: number of query positions (tgt_len).
            dtype: dtype of the returned mask.
            key_value_length: total key length (cached + current). Required
                when a causal mask is built.
            is_export: if set, return the causal mask alone (padding not merged).

        Returns:
            Mask of shape (bsz, 1, query_length, key_value_length).

        Raises:
            ValueError: causal mode without ``key_value_length``.
            NotImplementedError: sliding window with a non-causal converter.
        """
        input_shape = (attention_mask_2d.shape[0], query_length)
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            if key_value_length is None:
                raise ValueError(
                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
                )

            past_key_values_length = key_value_length - query_length

            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
                is_export=is_export,
                device=attention_mask_2d.device,
            )
        elif self.sliding_window is not None:
            raise NotImplementedError(
                "Sliding window is currently only implemented for causal masking"
            )

        expanded_attn_mask = self._expand_mask(
            attention_mask_2d, dtype, tgt_len=input_shape[-1]
        )

        if causal_4d_mask is None:
            return expanded_attn_mask
        if is_export:
            return causal_4d_mask
        # Out-of-place fill: `causal_4d_mask` is an expanded view (stride 0 on
        # the batch dim), so an in-place masked_fill_ fails once bsz > 1.
        return causal_4d_mask.masked_fill(
            expanded_attn_mask.to(torch.bool), torch.finfo(dtype).min
        )

    @staticmethod
    def _expand_mask(mask, dtype, tgt_len=None):
        """Turn a (bsz, src_len) 1/0 mask into an additive (bsz, 1, tgt_len, src_len) mask.

        Positions where ``mask`` is not 1 receive ``finfo(dtype).min``; the
        rest are 0. ``tgt_len`` defaults to ``src_len``. Declared static so
        it can be called either on an instance or on the class.
        """
        bsz, src_len = mask.shape
        tgt_len = tgt_len if tgt_len is not None else src_len
        expanded_mask = (
            mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).to(dtype)
        )
        # `1.0 - ...` allocates a fresh tensor, so the in-place fill is safe.
        inverted_mask = 1.0 - expanded_mask
        return inverted_mask.masked_fill_(
            inverted_mask.to(torch.bool), torch.finfo(dtype).min
        )
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
    """Expand a (bsz, src_len) padding mask into an additive (bsz, 1, tgt_len, src_len) mask.

    Args:
        mask: 2D mask; 1 marks attendable positions, 0 marks padding.
        dtype: dtype of the result; blocked positions get ``finfo(dtype).min``.
        tgt_len: number of query positions; defaults to ``src_len``.
    """
    # `_expand_mask` uses no converter state. Calling it through an instance
    # keeps the call valid whether it is declared as an instance method or a
    # static method; an unbound class-level call without `self` raises
    # TypeError for the instance-method form.
    converter = AttentionMaskConverter(is_causal=False)
    return converter._expand_mask(mask, dtype, tgt_len=tgt_len)
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def _prepare_4d_causal_attention_mask_export(
    attention_mask,
    input_shape,
    inputs_embeds,
    past_key_values_length,
    sliding_window=None,
    is_export=False,
):
    """Export-path mask preparation: delegate to ``AttentionMaskConverter.to_4d_export``.

    Args:
        attention_mask: 2D padding mask (bsz, key_length).
        input_shape: (bsz, query_length); only ``input_shape[-1]`` is used here.
        inputs_embeds: tensor whose dtype the mask takes.
        past_key_values_length: number of cached key positions.
        sliding_window: forwarded to the converter.
        is_export: forwarded to ``to_4d_export``.
    """
    attn_mask_converter = AttentionMaskConverter(
        is_causal=True, sliding_window=sliding_window
    )
    key_value_length = input_shape[-1] + past_key_values_length

    # (Removed the unused `shape` / `len_shape` locals the original computed.)
    return attn_mask_converter.to_4d_export(
        attention_mask,
        input_shape[-1],
        key_value_length=key_value_length,
        dtype=inputs_embeds.dtype,
        is_export=is_export,
    )
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape,
    inputs_embeds,
    past_key_values_length,
    sliding_window=None,
    is_export=False,
):
    """Build the additive 4D attention mask for the decoder.

    Args:
        attention_mask: ``None``; a 2D padding mask (bsz, key_value_length)
            with 1 = attend / 0 = padding; or a 4D mask of shape
            (bsz, 1, query_length, key_value_length) with the same convention.
        input_shape: (bsz, query_length).
        inputs_embeds: tensor whose dtype and device the mask should match.
        past_key_values_length: number of cached key positions.
        sliding_window: forwarded to ``AttentionMaskConverter``.
        is_export: forwarded to ``to_4d`` for 2D masks.

    Returns:
        An additive mask (0 = attend, ``finfo(dtype).min`` = blocked). Returns
        ``None`` when no usable mask is given and a single query token (with no
        sliding window) may attend to every key.

    Raises:
        ValueError: if a 4D mask does not have the expected shape.
    """
    attn_mask_converter = AttentionMaskConverter(
        is_causal=True, sliding_window=sliding_window
    )
    key_value_length = input_shape[-1] + past_key_values_length
    dtype = inputs_embeds.dtype

    # Check for None *before* reading `.shape` (the original dereferenced
    # `attention_mask.shape` first and crashed on None).
    if attention_mask is not None and len(attention_mask.shape) == 2:
        return attn_mask_converter.to_4d(
            attention_mask,
            input_shape[-1],
            key_value_length=key_value_length,
            dtype=dtype,
            is_export=is_export,
        )

    if attention_mask is not None and len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        inverted_mask = 1.0 - attention_mask
        return inverted_mask.masked_fill_(
            inverted_mask.to(torch.bool), torch.finfo(dtype).min
        )

    # No usable padding mask (None, or a rank other than 2/4, which is
    # ignored): build a purely causal mask. This replaces a call to the
    # non-existent `to_causal_4d`, keeping its semantics: no mask is needed
    # for a single query token without a sliding window.
    if input_shape[-1] <= 1 and sliding_window is None:
        return None
    causal_mask = attn_mask_converter._make_causal_mask(
        (input_shape[0], input_shape[-1]),
        dtype,
        past_key_values_length=past_key_values_length,
        sliding_window=sliding_window,
    )
    return causal_mask.to(device=inputs_embeds.device, dtype=dtype)
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
class MBartLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.

    The table has ``num_embeddings + 2`` rows. Position ``p`` is looked up at
    row ``p + offset`` (``offset == 2``).
    """

    def __init__(self, num_embeddings, embedding_dim):
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)
        # Kept for backward compatibility; forward() follows the weight's
        # current device instead of this construction-time value.
        self.device = torch.device(get_device())

    def forward(self, input_ids, past_key_values_length=0):
        """Return position embeddings of shape (bsz, seq_len, embedding_dim).

        Args:
            input_ids: tensor whose first two dims are (bsz, seq_len); only
                its shape is used.
            past_key_values_length: first position index (number of cached steps).
        """
        bsz, seq_len = input_ids.shape[:2]
        # Create positions on the embedding table's own device so lookups keep
        # working after the module is moved with .to(...).
        positions = torch.arange(
            past_key_values_length,
            past_key_values_length + seq_len,
            dtype=torch.int64,
            device=self.weight.device,
        ).expand([bsz, -1])
        return nn.Embedding.forward(self, positions + self.offset)
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
class MBartPreTrainedModel(nn.Module):
    """Shared base module: stores ``config`` and provides weight-initialisation helpers.

    Linear and Embedding weights are drawn from N(0, ``config.init_std``).
    Linear biases and the Embedding padding row are zeroed.
    """

    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
    _supports_flash_attn_2 = True

    def __init__(self, config):
        super().__init__()
        self.config = config

    def _initialize_weights(self, module):
        """
        Initialize the weights if they are not already initialized.
        """
        already_initialised = getattr(module, "_is_hf_initialized", False)
        if not already_initialised:
            self._init_weights(module)

    def post_init(self):
        # nn.Module.apply visits every submodule and then this module itself.
        self.apply(self._initialize_weights)

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, val=0.0)
            return
        if not isinstance(module, nn.Embedding):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        pad_row = module.padding_idx
        if pad_row is not None:
            torch.nn.init.constant_(module.weight[pad_row], val=0.0)

    @property
    def dummy_inputs(self):
        pad = self.config.pad_token_id
        ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad]])
        return {"attention_mask": ids.ne(pad), "input_ids": ids}
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
class MBartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config=None,
    ):
        """
        Args:
            embed_dim: model width; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            dropout: dropout probability applied to the attention probabilities.
            is_decoder: if True, ``forward`` returns the (key, value) pair for caching.
            bias: whether the projections have bias terms.
            is_causal: stored on the module; not used by ``forward``.
            config: stored on the module as-is.

        Raises:
            ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.embed_dim, self.num_heads, self.dropout = embed_dim, num_heads, dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if self.head_dim * num_heads != embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim ** -0.5
        self.is_decoder, self.is_causal = is_decoder, is_causal

        # Creation order (k, v, q, out) fixes parameter/state_dict order and
        # the order in which random initial weights are drawn.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor, seq_len, bsz):
        # (bsz, seq, embed) -> (bsz, heads, seq, head_dim)
        split = tensor.reshape([bsz, seq_len, self.num_heads, self.head_dim])
        return split.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        key_value_states=None,
        past_key_value=None,
        attention_mask=None,
        layer_head_mask=None,
        output_attentions=False,
    ):
        """Run attention.

        Returns ``(attn_output, attn_weights_or_None, past_key_value)``. For
        a decoder module, the third element is the (key, value) pair used in
        this call. Otherwise it is the ``past_key_value`` argument, unchanged.
        """
        bsz, tgt_len, _ = hidden_states.shape
        heads, head_dim = self.num_heads, self.head_dim
        query = self.q_proj(hidden_states) * self.scaling

        if key_value_states is not None:
            # Cross-attention: reuse cached projections while their length
            # still matches the provided key/value states.
            cache_matches = (
                past_key_value is not None
                and past_key_value[0].shape[2] == key_value_states.shape[1]
            )
            if cache_matches:
                key, value = past_key_value[0], past_key_value[1]
            else:
                key = self._shape(self.k_proj(key_value_states), -1, bsz)
                value = self._shape(self.v_proj(key_value_states), -1, bsz)
        else:
            key = self._shape(self.k_proj(hidden_states), -1, bsz)
            value = self._shape(self.v_proj(hidden_states), -1, bsz)
            if past_key_value is not None:
                # Self-attention with cache: append the new step to the history.
                key = torch.concat([past_key_value[0], key], dim=2)
                value = torch.concat([past_key_value[1], value], dim=2)

        if self.is_decoder:
            past_key_value = (key, value)

        flat = (bsz * heads, -1, head_dim)
        query = self._shape(query, tgt_len, bsz).reshape(flat)
        key = key.reshape(flat)
        value = value.reshape(flat)
        src_len = key.shape[1]

        per_head = [bsz, heads, tgt_len, src_len]
        merged = [bsz * heads, tgt_len, src_len]

        scores = torch.bmm(query, key.permute([0, 2, 1]))
        if attention_mask is not None:
            scores = (scores.reshape(per_head) + attention_mask).reshape(merged)

        probs = nn.functional.softmax(scores, dim=-1)
        if layer_head_mask is not None:
            if tuple(layer_head_mask.shape) != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of shape {(self.num_heads,)}, but is"
                    f" {layer_head_mask.shape}"
                )
            probs = (
                layer_head_mask.reshape([1, -1, 1, 1]) * probs.reshape(per_head)
            ).reshape(merged)

        attn_weights_reshaped = probs.reshape(per_head) if output_attentions else None

        dropped = nn.functional.dropout(probs, p=self.dropout, training=self.training)
        context = torch.bmm(dropped, value)
        context = (
            context.reshape([bsz, heads, tgt_len, head_dim])
            .permute([0, 2, 1, 3])
            .reshape([bsz, tgt_len, self.embed_dim])
        )
        return self.out_proj(context), attn_weights_reshaped, past_key_value
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
# Registry keyed by ``config._attn_implementation``; only the plain "eager"
# attention implementation is provided here.
MBART_ATTENTION_CLASSES = dict(eager=MBartAttention)
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
class MBartDecoderLayer(nn.Module):
    """A single MBart decoder block.

    The block applies self-attention, optional cross-attention over encoder
    states, and a two-layer GELU feed-forward network. Each sub-layer normalises
    its input with LayerNorm first and adds the (dropout-regularised) result back
    onto a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.d_model
        attention_cls = MBART_ATTENTION_CLASSES[config._attn_implementation]
        self.embed_dim = d_model

        # Submodules are created in a fixed order; keep it stable so parameter
        # registration (and random initialisation) order does not change.
        self.self_attn = attention_cls(
            embed_dim=d_model,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.is_export = config.is_export
        self.dropout = config.dropout
        self.activation_fn = F.gelu
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.encoder_attn = attention_cls(
            d_model,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(d_model)
        self.fc1 = nn.Linear(d_model, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, d_model)
        self.final_layer_norm = nn.LayerNorm(d_model)
        self.device = torch.device(get_device())

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> torch.Tensor:
        """Run the decoder block.

        Args:
            hidden_states: decoder input.
            attention_mask: mask forwarded to self-attention.
            encoder_hidden_states: if given, enables the cross-attention sub-layer.
            encoder_attention_mask: mask forwarded to cross-attention.
            layer_head_mask: per-head mask for self-attention.
            cross_attn_layer_head_mask: per-head mask for cross-attention.
            past_key_value: cached states. The first two entries feed
                self-attention and the last two feed cross-attention.
            output_attentions: also return attention weights.
            use_cache: also return the present key/value cache.

        Returns:
            A tuple starting with the new hidden states. When
            ``output_attentions`` is true, ``(self_attn_weights,
            cross_attn_weights)`` follows. The present key/value tuple is
            appended last when ``use_cache`` is true or the layer is in export
            mode.
        """
        # ---- self-attention sub-layer ----
        shortcut = hidden_states
        normed = self.self_attn_layer_norm(hidden_states)

        if past_key_value is None:
            cached_self_kv = None
        else:
            # Tensors in the cache are moved to this layer's device before reuse.
            cached_self_kv = tuple(
                entry.to(self.device) if isinstance(entry, torch.Tensor) else entry
                for entry in past_key_value[:2]
            )

        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=normed,
            past_key_value=cached_self_kv,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = shortcut + nn.functional.dropout(
            attn_out, p=self.dropout, training=self.training
        )

        # ---- optional cross-attention sub-layer ----
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            shortcut = hidden_states
            normed = self.encoder_attn_layer_norm(hidden_states)
            cached_cross_kv = None if past_key_value is None else past_key_value[-2:]
            attn_out, cross_attn_weights, cross_present = self.encoder_attn(
                hidden_states=normed,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cached_cross_kv,
                output_attentions=output_attentions,
            )
            hidden_states = shortcut + nn.functional.dropout(
                attn_out, p=self.dropout, training=self.training
            )
            # Cross-attention cache entries are appended after the self-attention ones.
            present_key_value = present_key_value + cross_present

        # ---- feed-forward sub-layer ----
        shortcut = hidden_states
        ffn = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        ffn = nn.functional.dropout(
            ffn, p=self.activation_dropout, training=self.training
        )
        ffn = nn.functional.dropout(
            self.fc2(ffn), p=self.dropout, training=self.training
        )
        hidden_states = shortcut + ffn

        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (self_attn_weights, cross_attn_weights)
        # Export mode always emits the cache, regardless of ``use_cache``.
        if self.is_export or use_cache:
            outputs = outputs + (present_key_value,)
        return outputs
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
class MBartForCausalLM(MBartPreTrainedModel):
    """Decoder-only MBart with a bias-free linear language-modelling head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Mutate a private copy so the decoder-only flags below do not leak back
        # into the caller's config object.
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.model = MBartDecoderWrapper(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    # -- embedding / decoder accessors --------------------------------------

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    # -- forward / generation ------------------------------------------------

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the decoder and project its last hidden state to vocabulary logits.

        Unset ``output_attentions``, ``output_hidden_states`` and ``return_dict``
        fall back to the config. When ``labels`` is given, a cross-entropy loss
        over the flattened logits is also computed.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` when ``return_dict`` is true.
            Otherwise a tuple ``(logits, *decoder_extras)``, prefixed with the
            loss when ``labels`` were given.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        decoder_out = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(decoder_out[0])

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(
                logits.reshape([-1, cfg.vocab_size]), labels.reshape([-1])
            )

        if not return_dict:
            flat = (logits,) + decoder_out[1:]
            return flat if loss is None else (loss,) + flat

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=decoder_out.past_key_values,
            hidden_states=decoder_out.hidden_states,
            attentions=decoder_out.attentions,
            cross_attentions=decoder_out.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        **kwargs,
    ):
        """Build the model inputs for one generation step.

        A missing attention mask becomes all ones. When a cache is present,
        tokens it already covers are stripped from ``input_ids``. If no more
        tokens than the cache length were passed, only the final token is kept.
        """
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        if past_key_values:
            cached_len = past_key_values[0][0].shape[2]
            seq_len = input_ids.shape[1]
            keep_from = cached_len if seq_len > cached_len else seq_len - 1
            input_ids = input_ids[:, keep_from:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Select the batch rows given by ``beam_idx`` from every cached tensor."""
        return tuple(
            tuple(state.index_select(0, beam_idx) for state in layer_past)
            for layer_past in past_key_values
        )
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
class myLayerNorm(nn.LayerNorm):
    """
    Layer normalisation over the last ``num_channels`` features with its own
    parameter setup.

    ``nn.LayerNorm.__init__`` is deliberately bypassed. The public LayerNorm
    attributes (``normalized_shape``, ``eps``, ``elementwise_affine``) are set
    by hand so inherited helpers such as ``extra_repr`` keep working.

    Args:
        num_channels (int): The number of features or channels in the input.
        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.
        affine (bool, optional): If True, this module has learnable affine parameters
            (weight initialised to ones, bias to zeros). If False, ``weight`` and
            ``bias`` are None. Default is True.
        drop_block (optional): Accepted for interface compatibility; it is not used. Default is None.
    """

    def __init__(
        self,
        num_channels,
        eps=1e-5,
        affine=True,
        drop_block=None,
    ):
        # Skip nn.LayerNorm.__init__ and initialise the bare nn.Module instead.
        super(nn.LayerNorm, self).__init__()
        self._epsilon = eps
        self.num_channels = num_channels
        # Mirror nn.LayerNorm's attributes; without them repr(module) raises AttributeError.
        self.normalized_shape = (num_channels,)
        self.eps = eps
        self.elementwise_affine = affine
        if affine:
            # The randn draw is kept as in the original so the global RNG stream is
            # unchanged. The final values are ones / zeros.
            self.weight = torch.nn.Parameter(torch.randn([num_channels]) * 0.01)
            self.bias = torch.nn.Parameter(torch.randn([num_channels]) * 0.01)
            torch.nn.init.ones_(self.weight)
            torch.nn.init.zeros_(self.bias)
        else:
            # Previously these were never defined, so forward() failed with AttributeError.
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Normalise ``x`` over its last dimension (size ``num_channels``)."""
        return F.layer_norm(
            x,
            [self.num_channels],
            weight=self.weight,
            bias=self.bias,
            eps=self._epsilon,
        )
|
|
927
|
+
|
|
928
|
+
|
|
929
|
+
class MBartDecoder(MBartPreTrainedModel):
    """
    A stack of ``config.decoder_layers`` :class:`MBartDecoderLayer` blocks,
    followed by a final LayerNorm.

    Args:
        config: model configuration.
        embed_tokens (nn.Embedding, optional): if given, its weight is shared with
            this decoder's token embedding.
    """

    def __init__(self, config, embed_tokens=None):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.d_model, self.padding_idx
        )
        if embed_tokens is not None:
            # Share (not copy) the caller's embedding weight.
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = MBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )
        self.layers = nn.ModuleList(
            MBartDecoderLayer(config) for _ in range(config.decoder_layers)
        )
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.layernorm_embedding = myLayerNorm(config.d_model, affine=True)
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
        self.is_export = config.is_export

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Embed the inputs and run them through every decoder layer.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. Unset
        output flags fall back to the config.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` when ``return_dict`` is
            true. Otherwise a tuple of the non-None items among ``(hidden_states,
            cache, all_hidden_states, all_self_attns, all_cross_attentions)``.

        Raises:
            ValueError: if both or neither input kinds are given, or if a head
                mask does not have one entry per layer.
        """
        cfg = self.config
        output_attentions = (
            cfg.output_attentions if output_attentions is None else output_attentions
        )
        output_hidden_states = (
            cfg.output_hidden_states
            if output_hidden_states is None
            else output_hidden_states
        )
        use_cache = cfg.use_cache if use_cache is None else use_cache
        return_dict = cfg.use_return_dict if return_dict is None else return_dict

        # ---- resolve the input and the tensor used for position lookup ----
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        if input_ids is not None:
            position_source = input_ids
            input_shape = position_source.shape
            input_ids = input_ids.reshape([-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.shape[:-1]
            position_source = inputs_embeds[:, :, -1]
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        past_key_values_length = (
            0 if past_key_values is None else past_key_values[0][0].shape[2]
        )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        # ---- attention masks ----
        if self._use_flash_attention_2:
            # Keep the mask only when it actually masks something out.
            if attention_mask is None or 0 not in attention_mask:
                attention_mask = None
        else:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                input_shape,
                inputs_embeds,
                past_key_values_length,
                is_export=self.is_export,
            )

        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self._use_flash_attention_2:
                if 0 not in encoder_attention_mask:
                    encoder_attention_mask = None
            else:
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        # ---- embeddings ----
        positions = self.embed_positions(position_source, past_key_values_length)
        hidden_states = self.layernorm_embedding(inputs_embeds + positions)
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        if self.gradient_checkpointing and self.training and use_cache:
            print(
                "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
            )
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = (
            () if (output_attentions and encoder_hidden_states is not None) else None
        )
        next_decoder_cache = () if use_cache else None

        num_layers = len(self.layers)
        for mask, mask_name in (
            (head_mask, "head_mask"),
            (cross_attn_head_mask, "cross_attn_head_mask"),
        ):
            if mask is not None and mask.shape[0] != num_layers:
                raise ValueError(
                    f"The `{mask_name}` should be specified for {num_layers} layers, but it is for"
                    f" {mask.shape[0]}."
                )

        # The layer's cache entry follows the two attention outputs when those are requested.
        cache_slot = 3 if output_attentions else 1

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # LayerDrop: randomly skip whole layers during training.
            if self.training and torch.rand([]) < self.layerdrop:
                continue

            layer_past = None if past_key_values is None else past_key_values[idx]
            layer_head = None if head_mask is None else head_mask[idx]
            layer_cross_head = (
                None if cross_attn_head_mask is None else cross_attn_head_mask[idx]
            )

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    layer_head,
                    layer_cross_head,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=layer_head,
                    cross_attn_layer_head_mask=layer_cross_head,
                    past_key_value=layer_past,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[cache_slot],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            candidates = (
                hidden_states,
                next_cache,
                all_hidden_states,
                all_self_attns,
                all_cross_attentions,
            )
            return tuple(item for item in candidates if item is not None)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
|
|
1169
|
+
|
|
1170
|
+
|
|
1171
|
+
class MBartDecoderWrapper(MBartPreTrainedModel):
    """
    Holds an :class:`MBartDecoder` under the ``decoder`` attribute.

    This exists so that pretrained checkpoints load correctly when the causal
    language model is used together with the [`EncoderDecoderModel`] framework.
    Calling the wrapper simply forwards to the decoder.
    """

    def __init__(self, config):
        super().__init__(config)
        self.decoder = MBartDecoder(config)

    def forward(self, *args, **kwargs):
        decoder = self.decoder
        return decoder(*args, **kwargs)
|
|
1183
|
+
|
|
1184
|
+
|
|
1185
|
+
def _in_projection(
|
|
1186
|
+
q: torch.Tensor,
|
|
1187
|
+
k: torch.Tensor,
|
|
1188
|
+
v: torch.Tensor,
|
|
1189
|
+
w_q: torch.Tensor,
|
|
1190
|
+
w_k: torch.Tensor,
|
|
1191
|
+
w_v: torch.Tensor,
|
|
1192
|
+
b_q: Optional[torch.Tensor] = None,
|
|
1193
|
+
b_k: Optional[torch.Tensor] = None,
|
|
1194
|
+
b_v: Optional[torch.Tensor] = None,
|
|
1195
|
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
1196
|
+
Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1]
|
|
1197
|
+
assert w_q.shape == (
|
|
1198
|
+
Eq,
|
|
1199
|
+
Eq,
|
|
1200
|
+
), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
|
|
1201
|
+
assert w_k.shape == (
|
|
1202
|
+
Eq,
|
|
1203
|
+
Ek,
|
|
1204
|
+
), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
|
|
1205
|
+
assert w_v.shape == (
|
|
1206
|
+
Eq,
|
|
1207
|
+
Ev,
|
|
1208
|
+
), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
|
|
1209
|
+
assert b_q is None or b_q.shape == (
|
|
1210
|
+
Eq,
|
|
1211
|
+
), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
|
|
1212
|
+
assert b_k is None or b_k.shape == (
|
|
1213
|
+
Eq,
|
|
1214
|
+
), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
|
|
1215
|
+
assert b_v is None or b_v.shape == (
|
|
1216
|
+
Eq,
|
|
1217
|
+
), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
|
|
1218
|
+
return linear(q, w_q.T, b_q), linear(k, w_k.T, b_k), linear(v, w_v.T, b_v)
|
|
1219
|
+
|
|
1220
|
+
|
|
1221
|
+
def _scaled_dot_product_attention(
|
|
1222
|
+
q: torch.Tensor,
|
|
1223
|
+
k: torch.Tensor,
|
|
1224
|
+
v: torch.Tensor,
|
|
1225
|
+
attn_mask: Optional[torch.Tensor] = None,
|
|
1226
|
+
dropout_p: float = 0.0,
|
|
1227
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
1228
|
+
B, Nt, E = q.shape
|
|
1229
|
+
q = q / math.sqrt(E)
|
|
1230
|
+
attn = torch.bmm(q, k.permute([0, 2, 1]))
|
|
1231
|
+
if attn_mask is not None:
|
|
1232
|
+
attn += attn_mask
|
|
1233
|
+
attn = F.softmax(attn, dim=-1)
|
|
1234
|
+
if dropout_p > 0.0:
|
|
1235
|
+
attn = F.dropout(attn, p=dropout_p)
|
|
1236
|
+
output = torch.bmm(attn, v)
|
|
1237
|
+
return output, attn
|
|
1238
|
+
|
|
1239
|
+
|
|
1240
|
+
def linear(x, w, b, is_transpose):
    """Affine map ``x @ W (+ b)`` where ``W`` is ``w.T`` if ``is_transpose`` else ``w``.

    ``b`` may be None, in which case no bias is added.
    """
    weight = w.T if is_transpose else w
    product = torch.matmul(x, weight)
    return product if b is None else product + b
|
|
1247
|
+
|
|
1248
|
+
|
|
1249
|
+
def _in_projection_packed(
|
|
1250
|
+
q: Tensor,
|
|
1251
|
+
k: Tensor,
|
|
1252
|
+
v: Tensor,
|
|
1253
|
+
w: Tensor,
|
|
1254
|
+
b: Optional[Tensor] = None,
|
|
1255
|
+
is_export=False,
|
|
1256
|
+
) -> List[Tensor]:
|
|
1257
|
+
E = q.shape[-1]
|
|
1258
|
+
if k is v:
|
|
1259
|
+
if q is k:
|
|
1260
|
+
proj = linear(q, w, b, is_transpose=True)
|
|
1261
|
+
if is_export:
|
|
1262
|
+
B, D, L = proj.shape
|
|
1263
|
+
proj = proj.reshape([B, D, 3, E])
|
|
1264
|
+
proj = (
|
|
1265
|
+
proj.unsqueeze(0)
|
|
1266
|
+
.permute([3, 1, 2, 0, 4])
|
|
1267
|
+
.squeeze(-2)
|
|
1268
|
+
.contiguous()
|
|
1269
|
+
)
|
|
1270
|
+
else:
|
|
1271
|
+
proj = (
|
|
1272
|
+
proj.unflatten(-1, (3, E))
|
|
1273
|
+
.unsqueeze(0)
|
|
1274
|
+
.permute([3, 1, 2, 0, 4])
|
|
1275
|
+
.squeeze(-2)
|
|
1276
|
+
.contiguous()
|
|
1277
|
+
)
|
|
1278
|
+
return proj[0], proj[1], proj[2]
|
|
1279
|
+
else:
|
|
1280
|
+
w_q, w_k, w_v = w.chunk(3)
|
|
1281
|
+
if b is None:
|
|
1282
|
+
b_q = b_k = b_v = None
|
|
1283
|
+
else:
|
|
1284
|
+
b_q, b_k, b_v = b.chunk(3)
|
|
1285
|
+
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
|
|
1286
|
+
|
|
1287
|
+
|
|
1288
|
+
def multi_head_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: torch.Tensor,
    in_proj_bias: Optional[torch.Tensor],
    bias_k: Optional[torch.Tensor],
    bias_v: Optional[torch.Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: torch.Tensor,
    out_proj_bias: Optional[torch.Tensor],
    training: bool = True,
    key_padding_mask: Optional[torch.Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[torch.Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[torch.Tensor] = None,
    k_proj_weight: Optional[torch.Tensor] = None,
    v_proj_weight: Optional[torch.Tensor] = None,
    static_k: Optional[torch.Tensor] = None,
    static_v: Optional[torch.Tensor] = None,
    is_export=False,
):
    """Multi-head attention on sequence-first inputs.

    Args:
        query: ``(tgt_len, bsz, embed_dim)``.
        key, value: ``(src_len, bsz, ...)``, projected with the packed
            ``in_proj_weight`` / ``in_proj_bias``.
        num_heads: number of heads. ``embed_dim`` is split evenly among them.
        attn_mask: additive mask added to the attention scores.
        dropout_p: attention dropout, forced to 0 when ``training`` is false.
        out_proj_weight, out_proj_bias: output projection, applied as
            ``x @ out_proj_weight + out_proj_bias`` (no transpose).
        static_k, static_v: precomputed per-head key/value of shape
            ``(bsz * num_heads, src_len, head_dim)`` used instead of the projections.
        need_weights: also return attention weights averaged over heads.
        is_export: forwarded to ``_in_projection_packed``.

    Returns:
        ``(attn_output, weights)``. ``attn_output`` is ``(tgt_len, bsz, embed_dim)``.
        ``weights`` is ``(bsz, tgt_len, src_len)`` or None when ``need_weights`` is false.

    NOTE(review): ``key_padding_mask`` is only converted to bool and is never
    applied to the scores. ``add_zero_attn``, ``embed_dim_to_check``,
    ``use_separate_proj_weight`` and ``q/k/v_proj_weight`` are accepted but unused.
    """
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    if isinstance(embed_dim, torch.Tensor):
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    q, k, v = _in_projection_packed(
        query, key, value, in_proj_weight, in_proj_bias, is_export
    )

    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn(
            "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
        )
        key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.concat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.concat([v, bias_v.repeat(1, bsz, 1)])
    else:
        assert bias_k is None
        assert bias_v is None

    # Fold heads into the batch dimension: (len, bsz, E) -> (bsz * heads, len, head_dim).
    q = q.reshape([tgt_len, bsz * num_heads, head_dim]).permute([1, 0, 2])
    if static_k is None:
        k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).permute([1, 0, 2])
    else:
        assert (
            static_k.shape[0] == bsz * num_heads
        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.shape[0]}"
        assert (
            static_k.shape[2] == head_dim
        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.shape[2]}"
        k = static_k
    if static_v is None:
        # torch's Tensor.transpose swaps exactly two dims and rejects a list
        # (the former `.transpose([1, 0, 2])` raised TypeError); permute takes
        # the full dimension order.
        v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).permute([1, 0, 2])
    else:
        assert (
            static_v.shape[0] == bsz * num_heads
        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.shape[0]}"
        assert (
            static_v.shape[2] == head_dim
        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.shape[2]}"
        v = static_v

    src_len = k.shape[1]

    if not training:
        dropout_p = 0.0

    attn_output, attn_output_weights = _scaled_dot_product_attention(
        q, k, v, attn_mask, dropout_p
    )

    # Unfold heads back: (bsz * heads, tgt_len, head_dim) -> (tgt_len, bsz, embed_dim).
    attn_output = attn_output.permute([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
    # NOTE(review): out_proj_weight is applied without transposition; this
    # presumably matches how the checkpoint weights are stored — confirm.
    attn_output = linear(
        attn_output, out_proj_weight, out_proj_bias, is_transpose=False
    )

    if need_weights:
        attn_output_weights = attn_output_weights.reshape(
            [bsz, num_heads, tgt_len, src_len]
        )
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
|
|
1383
|
+
|
|
1384
|
+
|
|
1385
|
+
class MyMultiheadAttention(nn.Module):
    """
    Custom implementation of a multi-head attention layer.

    Only the packed-projection configuration is implemented: keys and values
    must share ``embed_dim`` (``kdim``/``vdim`` left as None or equal to
    ``embed_dim``), and ``add_bias_kv`` must be False. Unsupported
    configurations are rejected at construction time. Previously they were
    accepted and then failed later with an ``AttributeError``, because the
    corresponding weights were never created.

    Attributes:
        __constants__ (list): List of constant attributes.
        bias_k (Optional[torch.Tensor]): Key bias; always None in this implementation.
        bias_v (Optional[torch.Tensor]): Value bias; always None in this implementation.

    Args:
        embed_dim (int): Total dimension of the model. This is the size of the input feature vectors.
        num_heads (int): Number of parallel attention heads. ``embed_dim`` must be divisible by it.
        dropout (float, optional): Dropout probability on the attention weights. Default is 0.0.
        bias (bool, optional): If True, adds learnable biases to the input and output projections. Default is True.
        add_bias_kv (bool, optional): Not supported; must be False. Default is False.
        add_zero_attn (bool, optional): Forwarded to ``multi_head_attention_forward``. Default is False.
        kdim (int, optional): Total number of features for keys. Must be None or equal to ``embed_dim``.
        vdim (int, optional): Total number of features for values. Must be None or equal to ``embed_dim``.
        batch_first (bool, optional): Stored on the module but not forwarded to
            ``multi_head_attention_forward`` by ``forward``. Default is False.
        device (optional): Device on which the parameters are created. Default is None (torch default).
        dtype (optional): Data type of the parameters. Default is None (torch default dtype).
        is_export (bool, optional): Forwarded to ``multi_head_attention_forward``. Default is False.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
        NotImplementedError: If ``kdim``/``vdim`` differ from ``embed_dim`` or ``add_bias_kv`` is True.
    """

    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        device=None,
        dtype=None,
        is_export=False,
    ) -> None:
        super(MyMultiheadAttention, self).__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        self.is_export = is_export
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")

        if not self._qkv_same_embed_dim:
            # Separate q/k/v projection weights were never implemented.
            raise NotImplementedError(
                "MyMultiheadAttention only supports kdim == vdim == embed_dim"
            )
        if add_bias_kv:
            raise NotImplementedError(
                "MyMultiheadAttention does not support add_bias_kv=True"
            )

        # Packed input projection of shape (3 * embed_dim, embed_dim). The
        # contents are placeholders that _reset_parameters() overwrites below.
        self.in_proj_weight = nn.Parameter(
            torch.empty(3 * embed_dim, embed_dim, **factory_kwargs)
        )
        self.q_proj_weight = None
        self.k_proj_weight = None
        self.v_proj_weight = None

        if bias:
            self.in_proj_bias = nn.Parameter(
                torch.empty(3 * embed_dim, **factory_kwargs)
            )
        else:
            self.in_proj_bias = None
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        self.bias_k = self.bias_v = None
        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-normal init for projection weights; zero init for biases."""
        if self._qkv_same_embed_dim:
            torch.nn.init.xavier_normal_(self.in_proj_weight)
        else:
            torch.nn.init.xavier_normal_(self.q_proj_weight)
            torch.nn.init.xavier_normal_(self.k_proj_weight)
            torch.nn.init.xavier_normal_(self.v_proj_weight)

        # in_proj_bias and out_proj.bias are both controlled by the `bias` flag.
        if self.in_proj_bias is not None:
            torch.nn.init.zeros_(self.in_proj_bias)
            torch.nn.init.zeros_(self.out_proj.bias)
        if self.bias_k is not None:
            torch.nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            torch.nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run attention by delegating to ``multi_head_attention_forward``.

        ``multi_head_attention_forward`` is defined outside this view.

        Returns:
            Tuple of (attention output, attention weights or None).
        """
        attn_output, attn_output_weights = multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            is_export=self.is_export,
        )

        return attn_output, attn_output_weights
|
|
1518
|
+
|
|
1519
|
+
|
|
1520
|
+
class LogitsProcessorList(list):
    """
    Ordered container of logits processors that are applied one after another.

    Every processor receives ``(input_ids, scores)``. A processor whose
    ``__call__`` declares parameters beyond those two is also given
    ``**kwargs``. Each extra parameter name must be present in ``kwargs``,
    otherwise a ``ValueError`` is raised.

    Methods:
        __call__(input_ids, scores, **kwargs): Apply all processors to the given inputs.
    """

    def __call__(self, input_ids, scores, **kwargs):
        for proc in self:
            param_names = list(inspect.signature(proc.__call__).parameters.keys())
            extra_names = param_names[2:]
            if not extra_names:
                # Plain two-argument processor: kwargs are not forwarded.
                scores = proc(input_ids, scores)
                continue
            if any(name not in kwargs for name in extra_names):
                raise ValueError(
                    f"Make sure that all the required parameters: {param_names} for "
                    f"{proc.__class__} are passed to the logits processor."
                )
            scores = proc(input_ids, scores, **kwargs)
        return scores
|
|
1541
|
+
|
|
1542
|
+
|
|
1543
|
+
class ForcedEOSTokenLogitsProcessor(object):
    """
    Force the end-of-sequence (EOS) token at one fixed sequence position.

    When the current sequence length equals ``max_length - 1``, every logit
    is replaced by ``-inf`` except the EOS id(s), which are set to 0, so the
    next generated token must be EOS. At any other length, ``scores`` is
    returned unchanged (the same tensor object).

    Args:
        max_length (int): The maximum length of the sequence. Forces EOS when this length is reached.
        eos_token_id (Union[int, List[int]]): The ID(s) of the EOS token(s) to be forced in the sequence.
    """

    def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
        self.max_length = max_length
        # Normalise a single id into a one-element list.
        self.eos_token_id = (
            [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
        )

    def __call__(self, input_ids, scores):
        if input_ids.shape[-1] != self.max_length - 1:
            return scores
        forced = torch.full_like(scores, -math.inf)
        forced[:, self.eos_token_id] = 0
        return forced
|
|
1569
|
+
|
|
1570
|
+
|
|
1571
|
+
@dataclass
class CausalLMOutputWithCrossAttentions(ModelOutput):
    """
    Output container for causal language-model decoding with cross-attention.

    All entries default to None at class level. Constructor arguments are
    forwarded unchanged to ``ModelOutput``.
    """

    # NOTE(review): these names carry no type annotations, so ``@dataclass``
    # does not turn them into dataclass fields; they are plain class-level
    # defaults.
    loss = None
    logits = None
    past_key_values = None
    hidden_states = None
    attentions = None
    cross_attentions = None

    def __init__(self, *args, **kwargs):
        super(CausalLMOutputWithCrossAttentions, self).__init__(*args, **kwargs)
|
|
1582
|
+
|
|
1583
|
+
|
|
1584
|
+
@dataclass
class CausalLMOutputWithCrossAttentionsAndCounting(ModelOutput):
    """
    Causal language-model output that also carries a ``counting`` entry.

    ``CustomMBartForCausalLM.forward`` fills ``counting`` with the counting
    decoder's prediction (or None when length awareness is disabled).
    Constructor arguments are forwarded unchanged to ``ModelOutput``.
    """

    # NOTE(review): unannotated names, so ``@dataclass`` creates no fields;
    # these act as class-level defaults only.
    logits = None
    counting = None
    past_key_values = None
    hidden_states = None
    attentions = None
    cross_attentions = None

    def __init__(self, *args, **kwargs):
        super(CausalLMOutputWithCrossAttentionsAndCounting, self).__init__(
            *args, **kwargs
        )
|
|
1599
|
+
|
|
1600
|
+
|
|
1601
|
+
class CustomMBartDecoder(MBartDecoder):
    """
    MBartDecoder variant that can inject a counting/length prediction.

    On top of the inherited decoder, ``counting_context_weight`` maps a
    ``vocab_size``-wide counting prediction to ``d_model`` channels, using
    Linear -> ReLU -> Linear -> ReLU -> Linear. When ``count_pred`` is passed
    to ``forward``, half of that projection is added to every position of
    the embedded input.

    Args:
        config: The configuration object containing model parameters
            (reads ``d_model``, ``vocab_size`` and ``is_export``).
    """

    def __init__(self, config):
        super().__init__(config)
        hidden_size = config.d_model
        self.is_export = config.is_export
        self.counting_context_weight = nn.Sequential(
            nn.Linear(config.vocab_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, config.d_model),
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        count_pred=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Decode one step (or a full sequence) with optional count conditioning.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.

        In export mode the per-layer cache is always collected. A
        ``BaseModelOutputWithPastAndCrossAttentions`` is returned regardless
        of ``return_dict``. Note that export mode is re-derived from
        ``self.training`` on every call (eval => export).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given, or if a head mask has the wrong number of layers.
        """
        # Overrides the value read from config.is_export in __init__.
        self.is_export = False if self.training else True
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # `position_source` is what embed_positions receives (was named
        # `input`, which shadowed the builtin).
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            position_source = input_ids
            input_shape = position_source.shape
            input_ids = input_ids.reshape([-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.shape[:-1]
            position_source = inputs_embeds[:, :, -1]
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        past_key_values_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        if self._use_flash_attention_2:
            attention_mask = (
                attention_mask
                if (attention_mask is not None and 0 in attention_mask)
                else None
            )
        elif self.is_export:
            attention_mask = _prepare_4d_causal_attention_mask_export(
                attention_mask,
                input_shape,
                inputs_embeds,
                past_key_values_length,
                is_export=self.is_export,
            ).to(torch.float32)
        else:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                input_shape,
                inputs_embeds,
                past_key_values_length,
                is_export=self.is_export,
            )

        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self._use_flash_attention_2:
                encoder_attention_mask = (
                    encoder_attention_mask if 0 in encoder_attention_mask else None
                )
            else:
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        # embed positions
        positions = self.embed_positions(position_source, past_key_values_length)

        hidden_states = inputs_embeds + positions

        # Inject the counting context: half of the projected count prediction
        # is broadcast over the sequence dimension.
        if count_pred is not None:
            count_context_weight = self.counting_context_weight(count_pred)
            hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)

        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                print(
                    "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = (
            () if (output_attentions and encoder_hidden_states is not None) else None
        )
        # NOTE(review): in export mode the cache is appended even when
        # use_cache is False, which would fail on None; this relies on
        # use_cache being True (config default) during export.
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip(
            [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
        ):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                # LayerDrop. torch.rand() with no size raises TypeError;
                # an explicit empty size yields a scalar sample.
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    (
                        cross_attn_head_mask[idx]
                        if cross_attn_head_mask is not None
                        else None
                    ),
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx]
                        if cross_attn_head_mask is not None
                        else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]
            # The present key/value sits after the attention outputs when
            # those are requested.
            if self.is_export or use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if (self.is_export or use_cache) else None

        if not self.is_export and not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
|
|
1844
|
+
|
|
1845
|
+
|
|
1846
|
+
class SelfAttentionBlock(nn.Module):
    """
    Self-attention layer with a residual connection and post-LayerNorm.

    Computes ``LayerNorm(x + SelfAttention(x, x, x))``. Unlike a full
    transformer encoder layer, there is no feed-forward sub-layer here.

    Args:
        embed_size (int): The size of the embedding vector.
        num_heads (int): The number of attention heads.
        is_export (bool): Flag forwarded to ``MyMultiheadAttention``.
    """

    def __init__(self, embed_size, num_heads, is_export):
        super(SelfAttentionBlock, self).__init__()
        self.self_attention = MyMultiheadAttention(
            embed_dim=embed_size, num_heads=num_heads, is_export=is_export
        )
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, x):
        """Attend ``x`` to itself, add the residual, then layer-normalize."""
        attn_output, _ = self.self_attention(x, x, x)
        x = self.norm(attn_output + x)
        return x
|
|
1868
|
+
|
|
1869
|
+
|
|
1870
|
+
class SeqCountingDecoder(nn.Module):
    """
    Predict a per-class count vector from a feature sequence.

    The sequence is processed in four steps:
    1. It is refined by ``num_layers`` self-attention blocks.
    2. It is projected to ``in_features // 2`` channels and passed through a ReLU.
    3. It is averaged over the sequence dimension.
    4. The result is projected to ``out_features``.

    Expected shapes (the pooling treats dim 1 as the sequence axis):
    ``(B, L, in_features) -> (B, out_features)``. The meaning of B/L is
    assumed; confirm against callers.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        num_heads (int): The number of attention heads. Defaults to 8.
        num_layers (int): The number of attention layers. Defaults to 4.
        is_export (bool): Flag indicating whether to configure the layer for export.
    """

    def __init__(
        self, in_features, out_features, num_heads=8, num_layers=4, is_export=False
    ):
        super(SeqCountingDecoder, self).__init__()

        self.attention_blocks = nn.ModuleList(
            [
                SelfAttentionBlock(
                    embed_size=in_features, num_heads=num_heads, is_export=is_export
                )
                for _ in range(num_layers)
            ]
        )
        self.fc1 = nn.Linear(in_features, in_features // 2)
        self.relu = nn.ReLU()
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc2 = nn.Linear(in_features // 2, out_features)

    def forward(self, x):
        for block in self.attention_blocks:
            x = block(x)
        x = self.fc1(x)
        x = self.relu(x)
        # Move channels before the sequence axis so AdaptiveAvgPool1d pools
        # over the sequence. torch's Tensor.transpose swaps two dims; the
        # Paddle-style permutation list ``transpose([0, 2, 1])`` is rejected.
        x = x.transpose(1, 2)
        x = self.global_avg_pool(x)
        x = x.squeeze(-1)
        x = self.fc2(x)
        return x
|
|
1911
|
+
|
|
1912
|
+
|
|
1913
|
+
class CustomMBartForCausalLM(MBartForCausalLM):
    """
    MBart causal LM whose decoder is replaced by ``CustomMBartDecoder``.

    When ``length_aware`` is True, a ``SeqCountingDecoder`` predicts a
    ``vocab_size``-wide count vector from the encoder hidden states. That
    prediction is fed to the decoder and also returned as ``counting``.

    Args:
        config: The configuration object containing model parameters.
        length_aware (bool): Enable the counting (length-aware) branch.
    """

    def __init__(self, config, length_aware=True):
        super().__init__(config)
        self.model.decoder = CustomMBartDecoder(config)
        self.counting_decoder = SeqCountingDecoder(
            config.d_model, config.vocab_size, is_export=config.is_export
        )
        self.length_aware = length_aware

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        count_gt=None,
    ):
        """
        Run the decoder and LM head.

        The result is always a ``CausalLMOutputWithCrossAttentionsAndCounting``.

        ``labels`` and ``count_gt`` are accepted for API compatibility but
        are not used here; no loss is computed. ``return_dict`` is ignored.
        The decoder is always asked for a ModelOutput, because this method
        reads ``outputs.past_key_values`` and friends. Previously a falsy
        ``return_dict`` in training mode made the decoder return a plain
        tuple and crashed.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )

        count_pred = (
            self.counting_decoder(encoder_hidden_states) if self.length_aware else None
        )

        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            count_pred=count_pred,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        logits = self.lm_head(outputs[0])

        return CausalLMOutputWithCrossAttentionsAndCounting(
            logits=logits,
            counting=count_pred,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
|
|
1995
|
+
|
|
1996
|
+
|
|
1997
|
+
class UniMERNetHead(nn.Module):
|
|
1998
|
+
"""Implementation of UniMERNetHead decoder.
|
|
1999
|
+
|
|
2000
|
+
Args:
|
|
2001
|
+
max_new_tokens (int): Maximum number of new tokens to generate.
|
|
2002
|
+
decoder_start_token_id (int): ID of the token that starts the decoding.
|
|
2003
|
+
temperature (float): Sampling temperature for generation.
|
|
2004
|
+
do_sample (bool): Whether to use sampling; if False, uses greedy decoding.
|
|
2005
|
+
top_p (float): Top-p (nucleus) sampling parameter.
|
|
2006
|
+
in_channels (int): Number of input channels/features.
|
|
2007
|
+
encoder_hidden_size (int): Hidden size of the encoder.
|
|
2008
|
+
decoder_hidden_size (int): Hidden size of the decoder.
|
|
2009
|
+
decoder_ffn_dim (int): Dimension of the decoder's feed-forward network.
|
|
2010
|
+
decoder_layers (int): Number of layers in the decoder.
|
|
2011
|
+
is_export (bool): Flag indicating if the model is being prepared for export.
|
|
2012
|
+
length_aware (bool): Flag to enable length-aware mechanisms.
|
|
2013
|
+
"""
|
|
2014
|
+
|
|
2015
|
+
def __init__(
    self,
    max_new_tokens=1536,
    decoder_start_token_id=0,
    temperature=0.2,
    do_sample=False,
    top_p=0.95,
    in_channels=1024,
    encoder_hidden_size=1024,
    decoder_hidden_size=1024,
    decoder_ffn_dim=4096,
    decoder_layers=8,
    is_export=False,
    length_aware=True,
    generation_max_length=1537,
):
    """
    Build the MBart decoder, the optional encoder->decoder projection and
    the forced-EOS logits processor.

    Args:
        max_new_tokens (int): Maximum number of new tokens to generate. Also
            used as ``max_position_embeddings`` and ``max_seq_len``.
        decoder_start_token_id (int): ID of the token that starts decoding.
        temperature (float): Sampling temperature (stored only).
        do_sample (bool): Whether to sample (stored only).
        top_p (float): Nucleus sampling parameter (stored only).
        in_channels (int): Number of input channels (unused here).
        encoder_hidden_size (int): Hidden size of the encoder. A linear
            projection is added when it differs from ``decoder_hidden_size``.
        decoder_hidden_size (int): Hidden size of the decoder.
        decoder_ffn_dim (int): Dimension of the decoder's feed-forward network.
        decoder_layers (int): Number of decoder layers.
        is_export (bool): Flag indicating if the model is being prepared for export.
        length_aware (bool): Enable the counting (length-aware) branch.
        generation_max_length (int): Length used by
            ``ForcedEOSTokenLogitsProcessor``. EOS is forced when the current
            sequence length equals ``generation_max_length - 1``. Defaults to
            1537, the value previously hard-coded here.
    """
    super().__init__()
    mbart_config_dict = {
        "activation_dropout": 0.0,
        "activation_function": "gelu",
        "add_cross_attention": True,
        "add_final_layer_norm": True,
        "attention_dropout": 0.0,
        "bos_token_id": 0,
        "classifier_dropout": 0.0,
        "d_model": decoder_hidden_size,
        "decoder_attention_heads": 16,
        "decoder_ffn_dim": decoder_ffn_dim,
        "decoder_layerdrop": 0.0,
        "decoder_layers": decoder_layers,
        "dropout": 0.1,
        "encoder_attention_heads": 16,
        "encoder_ffn_dim": 4096,
        "encoder_layerdrop": 0.0,
        "encoder_layers": 12,
        "eos_token_id": 2,
        "forced_eos_token_id": 2,
        "init_std": 0.02,
        "is_decoder": True,
        "is_encoder_decoder": False,
        "output_hidden_states": False,
        "max_position_embeddings": max_new_tokens,
        "model_type": "mbart",
        "num_hidden_layers": 12,
        "pad_token_id": 1,
        "scale_embedding": True,
        "tie_word_embeddings": False,
        "transformers_version": "4.40.0",
        "use_cache": True,
        "use_return_dict": True,
        "vocab_size": 50000,
        "_attn_implementation": "eager",
        "hidden_size": decoder_hidden_size,
        "is_export": is_export,
    }

    self.max_new_tokens = max_new_tokens
    self.decoder_start_token_id = decoder_start_token_id
    self.temperature = temperature
    self.do_sample = do_sample
    self.top_p = top_p
    self.max_seq_len = max_new_tokens
    self.config_decoder = MBartConfig(**mbart_config_dict)
    self.encoder_hidden_size = encoder_hidden_size
    self.is_export = self.config_decoder.is_export
    self.decoder = CustomMBartForCausalLM(
        self.config_decoder, length_aware=length_aware
    )
    if self.config_decoder.hidden_size != self.encoder_hidden_size:
        self.enc_to_dec_proj = nn.Linear(
            self.encoder_hidden_size, self.config_decoder.hidden_size
        )
    # Local dict only: it is NOT stored as self.generation_config.
    generation_config = {
        "max_length": generation_max_length,
        "forced_eos_token_id": 2,
    }
    self.eos_token_id = generation_config["forced_eos_token_id"]
    self.pad_token_id = self.config_decoder.pad_token_id
    self.logits_processor = LogitsProcessorList()
    self.logits_processor.append(
        ForcedEOSTokenLogitsProcessor(
            generation_config["max_length"],
            generation_config["forced_eos_token_id"],
        )
    )
|
|
2099
|
+
|
|
2100
|
+
def _get_decoder_start_token_id(
|
|
2101
|
+
self, decoder_start_token_id=None, bos_token_id=None
|
|
2102
|
+
) -> int:
|
|
2103
|
+
decoder_start_token_id = (
|
|
2104
|
+
decoder_start_token_id
|
|
2105
|
+
if decoder_start_token_id is not None
|
|
2106
|
+
else self.generation_config.decoder_start_token_id
|
|
2107
|
+
)
|
|
2108
|
+
bos_token_id = (
|
|
2109
|
+
bos_token_id
|
|
2110
|
+
if bos_token_id is not None
|
|
2111
|
+
else self.generation_config.bos_token_id
|
|
2112
|
+
)
|
|
2113
|
+
if decoder_start_token_id is not None:
|
|
2114
|
+
return decoder_start_token_id
|
|
2115
|
+
elif bos_token_id is not None:
|
|
2116
|
+
return bos_token_id
|
|
2117
|
+
raise ValueError(
|
|
2118
|
+
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
|
|
2119
|
+
)
|
|
2120
|
+
|
|
2121
|
+
def _prepare_decoder_input_ids_for_generation(
|
|
2122
|
+
self,
|
|
2123
|
+
batch_size,
|
|
2124
|
+
model_kwargs,
|
|
2125
|
+
decoder_start_token_id=None,
|
|
2126
|
+
bos_token_id=None,
|
|
2127
|
+
):
|
|
2128
|
+
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
|
|
2129
|
+
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
|
|
2130
|
+
elif "input_ids" in model_kwargs:
|
|
2131
|
+
decoder_input_ids = model_kwargs.pop("input_ids")
|
|
2132
|
+
else:
|
|
2133
|
+
decoder_input_ids = None
|
|
2134
|
+
|
|
2135
|
+
decoder_start_token_id = self._get_decoder_start_token_id(
|
|
2136
|
+
decoder_start_token_id, bos_token_id
|
|
2137
|
+
)
|
|
2138
|
+
|
|
2139
|
+
if isinstance(decoder_start_token_id, list):
|
|
2140
|
+
if len(decoder_start_token_id) != batch_size:
|
|
2141
|
+
raise ValueError(
|
|
2142
|
+
f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
|
|
2143
|
+
)
|
|
2144
|
+
decoder_input_ids_start = torch.LongTensor(decoder_start_token_id)
|
|
2145
|
+
decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
|
|
2146
|
+
else:
|
|
2147
|
+
decoder_input_ids_start = (
|
|
2148
|
+
torch.ones(
|
|
2149
|
+
(batch_size, 1),
|
|
2150
|
+
dtype=torch.int64,
|
|
2151
|
+
)
|
|
2152
|
+
* decoder_start_token_id
|
|
2153
|
+
)
|
|
2154
|
+
|
|
2155
|
+
if decoder_input_ids is None:
|
|
2156
|
+
decoder_input_ids = decoder_input_ids_start
|
|
2157
|
+
elif (
|
|
2158
|
+
self.config.model_type == "vision-encoder-decoder"
|
|
2159
|
+
and "donut" in self.name_or_path.lower()
|
|
2160
|
+
):
|
|
2161
|
+
pass
|
|
2162
|
+
elif self.config.model_type in ["whisper"]:
|
|
2163
|
+
pass
|
|
2164
|
+
elif (
|
|
2165
|
+
isinstance(decoder_start_token_id, int)
|
|
2166
|
+
and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
|
|
2167
|
+
) or (
|
|
2168
|
+
isinstance(decoder_start_token_id, torch.Tensor)
|
|
2169
|
+
and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
|
|
2170
|
+
):
|
|
2171
|
+
decoder_input_ids = torch.concat(
|
|
2172
|
+
[decoder_input_ids_start, decoder_input_ids], dim=-1
|
|
2173
|
+
)
|
|
2174
|
+
if "decoder_attention_mask" in model_kwargs:
|
|
2175
|
+
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
|
2176
|
+
decoder_attention_mask = torch.cat(
|
|
2177
|
+
(
|
|
2178
|
+
torch.ones_like(decoder_attention_mask)[:, :1],
|
|
2179
|
+
decoder_attention_mask,
|
|
2180
|
+
),
|
|
2181
|
+
dim=-1,
|
|
2182
|
+
)
|
|
2183
|
+
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
|
|
2184
|
+
|
|
2185
|
+
return decoder_input_ids, model_kwargs
|
|
2186
|
+
|
|
2187
|
+
def prepare_inputs_for_generation_mbart(
|
|
2188
|
+
self,
|
|
2189
|
+
input_ids,
|
|
2190
|
+
past_key_values=None,
|
|
2191
|
+
attention_mask=None,
|
|
2192
|
+
use_cache=None,
|
|
2193
|
+
**kwargs,
|
|
2194
|
+
):
|
|
2195
|
+
|
|
2196
|
+
if attention_mask is None:
|
|
2197
|
+
attention_mask = torch.ones(input_ids.shape)
|
|
2198
|
+
|
|
2199
|
+
if past_key_values:
|
|
2200
|
+
past_length = past_key_values[0][0].shape[2]
|
|
2201
|
+
|
|
2202
|
+
if input_ids.shape[1] > past_length:
|
|
2203
|
+
remove_prefix_length = past_length
|
|
2204
|
+
else:
|
|
2205
|
+
remove_prefix_length = input_ids.shape[1] - 1
|
|
2206
|
+
|
|
2207
|
+
input_ids = input_ids[:, remove_prefix_length:]
|
|
2208
|
+
return {
|
|
2209
|
+
"input_ids": input_ids,
|
|
2210
|
+
"attention_mask": attention_mask,
|
|
2211
|
+
"past_key_values": past_key_values,
|
|
2212
|
+
"use_cache": use_cache,
|
|
2213
|
+
}
|
|
2214
|
+
|
|
2215
|
+
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    use_cache=None,
    encoder_outputs=None,
    **kwargs,
):
    """Assemble the keyword arguments for one ``generate_single_iter`` call.

    Decoder-side tensors come from ``self.prepare_inputs_for_generation_mbart``,
    which receives only ``input_ids`` and ``past_key_values``. ``attention_mask``,
    ``encoder_outputs`` and ``use_cache`` are passed through unchanged. Extra
    keyword arguments are ignored.

    Returns:
        dict with keys ``attention_mask``, ``decoder_attention_mask`` (``None``
        when the helper provides no mask), ``decoder_input_ids``,
        ``encoder_outputs``, ``past_key_values`` and ``use_cache``.
    """
    mbart_inputs = self.prepare_inputs_for_generation_mbart(
        input_ids, past_key_values=past_key_values
    )
    return {
        "attention_mask": attention_mask,
        "decoder_attention_mask": mbart_inputs.get("attention_mask"),
        "decoder_input_ids": mbart_inputs["input_ids"],
        "encoder_outputs": encoder_outputs,
        "past_key_values": mbart_inputs["past_key_values"],
        "use_cache": use_cache,
    }
|
|
2241
|
+
|
|
2242
|
+
def prepare_inputs_for_generation_export(
    self,
    past_key_values=None,
    attention_mask=None,
    use_cache=None,
    encoder_outputs=None,
    **kwargs,
):
    """Return the minimal per-step inputs for the export decoding loop.

    Only ``use_cache`` is propagated. The decoder attention mask is always
    ``None``. Every other argument, including extra keyword arguments, is
    accepted for signature compatibility and ignored.
    """
    return dict(decoder_attention_mask=None, use_cache=use_cache)
|
|
2256
|
+
|
|
2257
|
+
def _extract_past_from_model_output(
    self, outputs: ModelOutput, standardize_cache_format: bool = False
):
    """Pull the decoding cache out of a model output.

    Looks for the fields ``past_key_values``, ``mems`` and
    ``past_buckets_states``, in that order, and returns the attribute of the
    first one present in ``outputs``. Returns ``None`` when none is present.
    ``standardize_cache_format`` is accepted for API compatibility but unused.
    """
    for cache_field in ("past_key_values", "mems", "past_buckets_states"):
        if cache_field in outputs:
            return getattr(outputs, cache_field)
    return None
|
|
2269
|
+
|
|
2270
|
+
def _update_model_kwargs_for_generation(
    self,
    outputs: ModelOutput,
    model_kwargs: Dict[str, Any],
    is_encoder_decoder: bool = False,
    standardize_cache_format: bool = False,
) -> Dict[str, Any]:
    """Advance ``model_kwargs`` by one generated token (mutates and returns it).

    - ``past_key_values`` is replaced by the cache extracted from ``outputs``.
    - ``state`` is copied from ``outputs`` when it is not ``None``.
    - ``token_type_ids``, if present, gets its last column repeated once.
    - The attention mask (``attention_mask`` for decoder-only models,
      ``decoder_attention_mask`` otherwise), if present, gets one extra
      column of ones.
    - ``cache_position``, if not ``None``, becomes its last entry plus one.
    """
    model_kwargs["past_key_values"] = self._extract_past_from_model_output(
        outputs, standardize_cache_format=standardize_cache_format
    )

    new_state = getattr(outputs, "state", None)
    if new_state is not None:
        model_kwargs["state"] = new_state

    if "token_type_ids" in model_kwargs:
        type_ids = model_kwargs["token_type_ids"]
        repeated_last = type_ids[:, -1].unsqueeze(-1)
        model_kwargs["token_type_ids"] = torch.concat([type_ids, repeated_last], dim=-1)

    mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    if mask_key in model_kwargs:
        mask = model_kwargs[mask_key]
        extra_column = mask.new_ones((mask.shape[0], 1))
        model_kwargs[mask_key] = torch.concat([mask, extra_column], dim=-1)

    position = model_kwargs.get("cache_position")
    if position is not None:
        model_kwargs["cache_position"] = position[-1:] + 1

    return model_kwargs
|
|
2319
|
+
|
|
2320
|
+
def stopping_criteria(self, input_ids):
    """Flag sequences whose most recent token is the EOS token.

    Args:
        input_ids: Generated ids, indexed as ``[:, -1]`` (batch first).

    Returns:
        Bool tensor with one entry per sequence, ``True`` where the last
        token equals ``self.eos_token_id``.
    """
    # Build the EOS reference with the ids' own dtype and device. A
    # ``torch.Tensor([...])`` literal is a float CPU tensor, and a 1-element
    # CPU tensor cannot be compared with (or isin-tested against) ids that
    # live on another device.
    eos = input_ids.new_tensor([self.eos_token_id])
    last_tokens = input_ids[:, -1]
    if self.is_export:
        return last_tokens == eos
    return torch.isin(last_tokens, eos)
|
|
2325
|
+
|
|
2326
|
+
def generate_single_iter(
    self,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    **kwargs,
):
    """Run the decoder once and wrap its result in a ``Seq2SeqLMOutput``.

    ``encoder_outputs[0]`` is used as the cross-attention memory. It is first
    passed through ``self.enc_to_dec_proj`` when the decoder hidden size
    differs from ``self.encoder_hidden_size``. Decoder attentions are always
    disabled, so ``output_attentions`` is ignored, as are
    ``decoder_inputs_embeds``, ``labels`` and extra keyword arguments.
    ``encoder_outputs`` must also expose ``last_hidden_state``,
    ``hidden_states`` and ``attentions``, which are copied into the result.
    """
    memory = encoder_outputs[0]
    if self.config_decoder.hidden_size != self.encoder_hidden_size:
        # Bring encoder features to the decoder's width.
        memory = self.enc_to_dec_proj(memory)

    step = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=memory,
        encoder_attention_mask=None,
        inputs_embeds=None,
        output_attentions=False,
        output_hidden_states=output_hidden_states,
        use_cache=use_cache,
        past_key_values=past_key_values,
        return_dict=return_dict,
    )

    return Seq2SeqLMOutput(
        loss=None,
        logits=step.logits,
        past_key_values=step.past_key_values,
        decoder_hidden_states=step.hidden_states,
        decoder_attentions=step.attentions,
        cross_attentions=step.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )
|
|
2370
|
+
|
|
2371
|
+
@torch.no_grad()
def generate(
    self,
    model_kwargs,
):
    """
    Generate sequences using the UniMERNetHead for inference tasks.

    Greedy decoding. Each step re-prepares inputs from the full id sequence,
    takes the argmax of the processed logits, and appends it. Rows that have
    already emitted ``self.eos_token_id`` keep receiving
    ``self.pad_token_id``. The loop stops once every row contains EOS, or
    after ``self.max_seq_len`` steps.

    Args:
        model_kwargs (dict): Model configurations and inputs. It must contain
            ``encoder_outputs`` with a ``last_hidden_state`` entry, whose first
            dimension is the batch size. It typically also includes:
            - use_cache: Boolean flag to indicate if caching should be used
              (forced to True here).
            - output_attentions: Boolean flag for outputting attention scores.
            - output_hidden_states: Boolean flag for outputting hidden states.

    Returns:
        A tensor containing the generated sequences, starting with the decoder
        start token.

    Raises:
        ValueError: If ``self.pad_token_id`` is ``None``.
    """
    last_hidden_state = model_kwargs["encoder_outputs"]["last_hidden_state"]
    batch_size = last_hidden_state.shape[0]
    device = last_hidden_state.device
    generation_config = {
        "decoder_start_token_id": 0,
        "bos_token_id": 0,
    }
    input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
        batch_size=batch_size,
        model_kwargs=model_kwargs,
        decoder_start_token_id=generation_config["decoder_start_token_id"],
        bos_token_id=generation_config["bos_token_id"],
    )
    # The flag was previously stored under the typo key "key use_cache", so it
    # never reached the ``use_cache`` parameter.
    model_kwargs["use_cache"] = True
    batch_size, cur_len = input_ids.shape

    if "inputs_embeds" in model_kwargs:
        cur_len = model_kwargs["inputs_embeds"].shape[1]
    model_kwargs["cache_position"] = torch.arange(cur_len, device=device)

    pad_token_id = self.pad_token_id
    eos_token = self.eos_token_id
    # Loop-invariant: finished rows are filled with the pad token.
    if pad_token_id is None:
        raise ValueError(
            "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
        )

    unfinished_sequences = torch.ones(batch_size, dtype=torch.int64, device=device)
    for _ in range(self.max_seq_len):
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self.generate_single_iter(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        next_token_logits = outputs.logits[:, -1, :]

        next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
        next_tokens = torch.argmax(next_tokens_scores, dim=-1)
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
            1 - unfinished_sequences
        )
        input_ids = torch.concat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config_decoder.is_encoder_decoder,
        )
        # With the stop flags as int64 0/1, ``~`` yields -1/-2, so ``&`` keeps
        # unfinished rows (x & -1) and clears finished ones (1 & -2 == 0).
        finished_now = self.stopping_criteria(input_ids).to(torch.int64)
        unfinished_sequences = unfinished_sequences & ~finished_now

        if eos_token is not None and (input_ids == eos_token).any(dim=-1).all():
            break

    return input_ids
|
|
2450
|
+
|
|
2451
|
+
@torch.no_grad()
def generate_export(
    self,
    encoder_outputs,
    model_kwargs,
):
    """Greedily decode token ids from encoder features (inference path).

    Each step feeds only the newest token to the decoder and threads the
    key/value cache through ``past_key_values``. Rows that have produced
    ``self.eos_token_id`` are padded with ``self.pad_token_id`` from then on.
    The loop stops when every row contains EOS, or after
    ``self.max_seq_len`` steps.

    Args:
        encoder_outputs: Encoder result. ``encoder_outputs["last_hidden_state"]``
            supplies the batch size and device. The object is otherwise passed
            unchanged to ``generate_single_iter``.
        model_kwargs (dict): Generation options forwarded to
            ``_prepare_decoder_input_ids_for_generation``.

    Returns:
        Tensor of shape ``(batch, 1 + steps)``: the decoder start token
        followed by the generated ids.

    Raises:
        ValueError: If ``self.pad_token_id`` is ``None``.
    """
    last_hidden_state = encoder_outputs["last_hidden_state"]
    device = last_hidden_state.device
    batch_size = last_hidden_state.shape[0]
    generation_config = {
        "decoder_start_token_id": 0,
        "bos_token_id": 0,
    }
    input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
        batch_size=batch_size,
        model_kwargs=model_kwargs,
        decoder_start_token_id=generation_config["decoder_start_token_id"],
        bos_token_id=generation_config["bos_token_id"],
    )
    input_ids = input_ids.reshape([-1, 1])
    batch_size = input_ids.shape[0]
    decoder_input_ids = input_ids

    pad_token_id = self.pad_token_id
    eos_token = self.eos_token_id
    # Finished rows are filled with the pad token, so it must exist.
    if pad_token_id is None:
        raise ValueError(
            "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
        )

    unfinished_sequences = torch.ones([batch_size], dtype=torch.int64, device=device)

    # Empty starting cache: 8 entries of 4 zero-length tensors shaped
    # [batch, 16, 0, 64]. These hard-coded numbers presumably mirror the
    # decoder's layer count / head count / head size; TODO confirm against
    # the decoder config.
    past_key_values = []
    for _ in range(8):
        empty = torch.zeros([batch_size, 16, 0, 64], device=device)
        past_key_values.append((empty, empty, empty, empty))

    # Use a plain integer bound: ``torch.Tensor(n)`` builds an uninitialised
    # tensor of length n (not a scalar), which made the loop test either
    # raise ("Boolean value of Tensor ... is ambiguous") or read garbage.
    for _ in range(self.max_seq_len):
        # The mask spans the whole sequence generated so far.
        decoder_attention_mask = torch.ones(input_ids.shape, device=device)

        outputs = self.generate_single_iter(
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]

        next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
        next_tokens = torch.argmax(next_tokens_scores, dim=-1)
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
            1 - unfinished_sequences
        )
        input_ids = torch.concat([input_ids, next_tokens.unsqueeze(1)], dim=-1)
        # Only the newest token is fed next; earlier ones live in the cache.
        decoder_input_ids = next_tokens.unsqueeze(1)
        past_key_values = outputs.past_key_values
        # Int64 0/1 flags: ``~`` gives -1/-2, so ``&`` clears finished rows.
        finished_now = self.stopping_criteria(input_ids).to(torch.int64)
        unfinished_sequences = unfinished_sequences & ~finished_now
        if eos_token is not None and (input_ids == eos_token).any(dim=-1).all():
            break

    return input_ids
|
|
2532
|
+
|
|
2533
|
+
def forwad_train(
    self,
    encoder_outputs,
    decoder_input_ids,
    decoder_attention_mask,
    past_key_values=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    **kwargs,
):
    """
    Teacher-forced training pass for the UniMERNetHead.

    The decoder sees the target ids and mask with the last position dropped.
    The returned labels are a copy of the full target ids with every
    ``self.pad_token_id`` replaced by -100 (presumably the loss's ignore
    index; confirm against the loss module).

    Args:
        encoder_outputs: Encoder result; ``encoder_outputs[0]`` is used as the
            cross-attention memory (projected when widths differ).
        decoder_input_ids: Target token ids, indexed as ``[:, :-1]``.
        decoder_attention_mask: Mask for ``decoder_input_ids``, sliced the
            same way.
        past_key_values: Forwarded to the decoder.
        decoder_inputs_embeds: Unused.
        labels: Unused; labels are derived from ``decoder_input_ids``.
        use_cache: Forwarded to the decoder.
        output_attentions: Unused; attentions are always disabled.
        output_hidden_states: Forwarded to the decoder.
        return_dict: Forwarded to the decoder.
        **kwargs: Ignored.

    Returns:
        logits: The decoder's ``logits``.
        count_pred: The decoder's ``counting`` output.
        masked_labels: Target ids with padding replaced by -100.
    """
    masked_labels = decoder_input_ids.masked_fill(
        decoder_input_ids == self.pad_token_id, -100
    )

    memory = encoder_outputs[0]
    if self.config_decoder.hidden_size != self.encoder_hidden_size:
        memory = self.enc_to_dec_proj(memory)

    # Shift: position t predicts token t+1, so the final input is dropped.
    step = self.decoder(
        input_ids=decoder_input_ids[:, :-1],
        attention_mask=decoder_attention_mask[:, :-1],
        encoder_hidden_states=memory,
        encoder_attention_mask=None,
        inputs_embeds=None,
        output_attentions=False,
        output_hidden_states=output_hidden_states,
        use_cache=use_cache,
        past_key_values=past_key_values,
        return_dict=return_dict,
    )

    return step.logits, step.counting, masked_labels
|
|
2593
|
+
|
|
2594
|
+
def forward(self, inputs, targets=None):
    """
    Entry point for both training and inference of the UniMERNetHead.

    Sets ``self.is_export`` to ``not self.training``. Because of this, eval
    mode always takes the ``generate_export`` path, and the ``generate``
    branch below is reached only if ``self.is_export`` is falsy at that point.

    Args:
        inputs: In eval mode, the encoder outputs. In training mode, a
            3-sequence ``(encoder_outputs, tgt_seq, mask)``.
        targets: Unused.

    Returns:
        Eval mode: the generated token ids.
        Training mode: ``(logits, count_pred, masked_labels)``.
    """
    self.is_export = not self.training

    if self.training:
        encoder_outputs, tgt_seq, mask = inputs
        logits, count_pred, masked_labels = self.forwad_train(
            encoder_outputs, tgt_seq, mask
        )
        return logits, count_pred, masked_labels

    generation_kwargs = {
        "output_attentions": False,
        "output_hidden_states": False,
        "use_cache": True,
    }
    if self.is_export:
        return self.generate_export(inputs, generation_kwargs)

    generation_kwargs["encoder_outputs"] = inputs
    return self.generate(generation_kwargs)
|