mineru 2.5.4__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103) hide show
  1. mineru/backend/pipeline/model_init.py +25 -3
  2. mineru/backend/pipeline/model_json_to_middle_json.py +2 -2
  3. mineru/backend/pipeline/model_list.py +0 -1
  4. mineru/backend/utils.py +24 -0
  5. mineru/backend/vlm/model_output_to_middle_json.py +2 -2
  6. mineru/backend/vlm/{custom_logits_processors.py → utils.py} +36 -2
  7. mineru/backend/vlm/vlm_analyze.py +43 -50
  8. mineru/backend/vlm/vlm_magic_model.py +155 -1
  9. mineru/cli/common.py +25 -22
  10. mineru/cli/fast_api.py +2 -8
  11. mineru/cli/gradio_app.py +96 -9
  12. mineru/cli/models_download.py +1 -0
  13. mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py +152 -0
  14. mineru/model/mfr/pp_formulanet_plus_m/processors.py +657 -0
  15. mineru/model/mfr/unimernet/unimernet_hf/modeling_unimernet.py +1 -326
  16. mineru/model/mfr/utils.py +338 -0
  17. mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py +103 -16
  18. mineru/model/table/rec/unet_table/main.py +1 -1
  19. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/imaug/operators.py +5 -5
  20. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/__init__.py +2 -1
  21. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_lcnetv3.py +7 -7
  22. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_pphgnetv2.py +2 -2
  23. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/__init__.py +2 -0
  24. mineru/model/utils/pytorchocr/modeling/heads/rec_ppformulanet_head.py +1383 -0
  25. mineru/model/utils/pytorchocr/modeling/heads/rec_unimernet_head.py +2631 -0
  26. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/rec_postprocess.py +25 -28
  27. mineru/model/utils/pytorchocr/utils/__init__.py +0 -0
  28. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/arch_config.yaml +130 -0
  29. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_arabic_dict.txt +747 -0
  30. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_cyrillic_dict.txt +850 -0
  31. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_devanagari_dict.txt +568 -0
  32. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_ta_dict.txt +513 -0
  33. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_te_dict.txt +540 -0
  34. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/models_config.yml +15 -15
  35. mineru/model/utils/pytorchocr/utils/resources/pp_formulanet_arch_config.yaml +24 -0
  36. mineru/model/utils/tools/infer/__init__.py +1 -0
  37. mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_det.py +6 -3
  38. mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_rec.py +16 -25
  39. mineru/model/vlm_vllm_model/server.py +4 -1
  40. mineru/resources/header.html +2 -2
  41. mineru/utils/enum_class.py +1 -0
  42. mineru/utils/llm_aided.py +4 -2
  43. mineru/utils/ocr_utils.py +16 -0
  44. mineru/utils/table_merge.py +102 -13
  45. mineru/version.py +1 -1
  46. {mineru-2.5.4.dist-info → mineru-2.6.0.dist-info}/METADATA +32 -8
  47. mineru-2.6.0.dist-info/RECORD +195 -0
  48. mineru-2.5.4.dist-info/RECORD +0 -181
  49. /mineru/model/{ocr/paddleocr2pytorch/pytorchocr → mfr/pp_formulanet_plus_m}/__init__.py +0 -0
  50. /mineru/model/{ocr/paddleocr2pytorch/tools/infer → utils}/__init__.py +0 -0
  51. /mineru/model/{ocr/paddleocr2pytorch/pytorchocr/modeling → utils/pytorchocr}/__init__.py +0 -0
  52. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/base_ocr_v20.py +0 -0
  53. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/__init__.py +0 -0
  54. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/imaug/__init__.py +0 -0
  55. /mineru/model/{ocr/paddleocr2pytorch/pytorchocr/utils → utils/pytorchocr/modeling}/__init__.py +0 -0
  56. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/architectures/__init__.py +0 -0
  57. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/architectures/base_model.py +0 -0
  58. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/det_mobilenet_v3.py +0 -0
  59. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_donut_swin.py +0 -0
  60. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_hgnet.py +0 -0
  61. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +0 -0
  62. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_mv1_enhance.py +0 -0
  63. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_svtrnet.py +0 -0
  64. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/common.py +0 -0
  65. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/cls_head.py +0 -0
  66. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/det_db_head.py +0 -0
  67. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/rec_ctc_head.py +0 -0
  68. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/rec_multi_head.py +0 -0
  69. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/__init__.py +0 -0
  70. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/db_fpn.py +0 -0
  71. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/intracl.py +0 -0
  72. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/rnn.py +0 -0
  73. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/__init__.py +0 -0
  74. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/cls_postprocess.py +0 -0
  75. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/db_postprocess.py +0 -0
  76. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/arabic_dict.txt +0 -0
  77. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +0 -0
  78. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/cyrillic_dict.txt +0 -0
  79. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/devanagari_dict.txt +0 -0
  80. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/en_dict.txt +0 -0
  81. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/japan_dict.txt +0 -0
  82. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ka_dict.txt +0 -0
  83. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/korean_dict.txt +0 -0
  84. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/latin_dict.txt +0 -0
  85. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +0 -0
  86. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv4_doc_dict.txt +0 -0
  87. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_dict.txt +0 -0
  88. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_el_dict.txt +0 -0
  89. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_en_dict.txt +0 -0
  90. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_eslav_dict.txt +0 -0
  91. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_korean_dict.txt +0 -0
  92. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_latin_dict.txt +0 -0
  93. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_th_dict.txt +0 -0
  94. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ta_dict.txt +0 -0
  95. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/te_dict.txt +0 -0
  96. /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/__init__.py +0 -0
  97. /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_cls.py +0 -0
  98. /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_system.py +0 -0
  99. /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/pytorchocr_utility.py +0 -0
  100. {mineru-2.5.4.dist-info → mineru-2.6.0.dist-info}/WHEEL +0 -0
  101. {mineru-2.5.4.dist-info → mineru-2.6.0.dist-info}/entry_points.txt +0 -0
  102. {mineru-2.5.4.dist-info → mineru-2.6.0.dist-info}/licenses/LICENSE.md +0 -0
  103. {mineru-2.5.4.dist-info → mineru-2.6.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2631 @@
1
+ import copy
2
+ import math
3
+ import re
4
+ import numpy as np
5
+ import inspect
6
+ import warnings
7
+ from collections import OrderedDict
8
+ from typing import Optional, Tuple, Union, List, Dict, Any
9
+ from dataclasses import dataclass, fields, is_dataclass
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch import Tensor
14
+ import torch.nn.functional as F
15
+ from torch.nn import CrossEntropyLoss
16
+
17
+ from mineru.utils.config_reader import get_device
18
+
19
+
20
+ class ModelOutput(OrderedDict):
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+
25
+ def __post_init__(self):
26
+ class_fields = fields(self)
27
+
28
+ if not len(class_fields):
29
+ raise ValueError(f"{self.__class__.__name__} has no fields.")
30
+ if not all(field.default is None for field in class_fields[1:]):
31
+ raise ValueError(
32
+ f"{self.__class__.__name__} should not have more than one required field."
33
+ )
34
+
35
+ first_field = getattr(self, class_fields[0].name)
36
+ other_fields_are_none = all(
37
+ getattr(self, field.name) is None for field in class_fields[1:]
38
+ )
39
+ if other_fields_are_none:
40
+ if isinstance(first_field, dict):
41
+ iterator = first_field.items()
42
+ first_field_iterator = True
43
+ else:
44
+ try:
45
+ iterator = iter(first_field)
46
+ first_field_iterator = True
47
+ except TypeError:
48
+ first_field_iterator = False
49
+
50
+ if first_field_iterator:
51
+ for idx, element in enumerate(iterator):
52
+ if (
53
+ not isinstance(element, (list, tuple))
54
+ or not len(element) == 2
55
+ or not isinstance(element[0], str)
56
+ ):
57
+ if idx == 0:
58
+ self[class_fields[0].name] = first_field
59
+ else:
60
+ raise ValueError(
61
+ f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
62
+ )
63
+ break
64
+ setattr(self, element[0], element[1])
65
+ if element[1] is not None:
66
+ self[element[0]] = element[1]
67
+ elif first_field is not None:
68
+ self[class_fields[0].name] = first_field
69
+ else:
70
+ for field in class_fields:
71
+ v = getattr(self, field.name)
72
+ if v is not None:
73
+ self[field.name] = v
74
+
75
+ def __delitem__(self, *args, **kwargs):
76
+ raise Exception(
77
+ f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
78
+ )
79
+
80
+ def setdefault(self, *args, **kwargs):
81
+ raise Exception(
82
+ f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
83
+ )
84
+
85
+ def pop(self, *args, **kwargs):
86
+ raise Exception(
87
+ f"You cannot use ``pop`` on a {self.__class__.__name__} instance."
88
+ )
89
+
90
+ def update(self, *args, **kwargs):
91
+ raise Exception(
92
+ f"You cannot use ``update`` on a {self.__class__.__name__} instance."
93
+ )
94
+
95
+ def __getitem__(self, k):
96
+ if isinstance(k, str):
97
+ inner_dict = dict(self.items())
98
+ return inner_dict[k]
99
+ else:
100
+ return self.to_tuple()[k]
101
+
102
+ def __setattr__(self, name, value):
103
+ if name in self.keys() and value is not None:
104
+ super().__setitem__(name, value)
105
+ super().__setattr__(name, value)
106
+
107
+ def __setitem__(self, key, value):
108
+ super().__setitem__(key, value)
109
+ super().__setattr__(key, value)
110
+
111
+ def __reduce__(self):
112
+ if not is_dataclass(self):
113
+ return super().__reduce__()
114
+ callable, _args, *remaining = super().__reduce__()
115
+ args = tuple(getattr(self, field.name) for field in fields(self))
116
+ return callable, args, *remaining
117
+
118
+ def to_tuple(self):
119
+ return tuple(self[k] for k in self.keys())
120
+
121
+
122
+ @dataclass
123
+ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
124
+ last_hidden_state = None
125
+ past_key_values = None
126
+ hidden_states = None
127
+ attentions = None
128
+ cross_attentions = None
129
+
130
+ def __init__(self, *args, **kwargs):
131
+ super().__init__(*args, **kwargs)
132
+
133
+
134
+ @dataclass
135
+ class Seq2SeqLMOutput(ModelOutput):
136
+ loss = None
137
+ logits = None
138
+ past_key_values = None
139
+ decoder_hidden_states = None
140
+ decoder_attentions = None
141
+ cross_attentions = None
142
+ encoder_last_hidden_state = None
143
+ encoder_hidden_states = None
144
+ encoder_attentions = None
145
+
146
+ def __init__(self, *args, **kwargs):
147
+ super().__init__(*args, **kwargs)
148
+
149
+
150
+ class MBartConfig(object):
151
+ model_type = "mbart"
152
+ keys_to_ignore_at_inference = ["past_key_values"]
153
+ attribute_map = {
154
+ "num_attention_heads": "encoder_attention_heads",
155
+ "hidden_size": "d_model",
156
+ }
157
+
158
+ def __init__(
159
+ self,
160
+ vocab_size=50265,
161
+ max_position_embeddings=1024,
162
+ encoder_layers=12,
163
+ encoder_ffn_dim=4096,
164
+ encoder_attention_heads=16,
165
+ decoder_layers=12,
166
+ decoder_ffn_dim=4096,
167
+ decoder_attention_heads=16,
168
+ encoder_layerdrop=0.0,
169
+ decoder_layerdrop=0.0,
170
+ use_cache=True,
171
+ is_encoder_decoder=True,
172
+ activation_function="gelu",
173
+ d_model=1024,
174
+ dropout=0.1,
175
+ output_hidden_states=False,
176
+ use_return_dict=True,
177
+ attention_dropout=0.0,
178
+ activation_dropout=0.0,
179
+ init_std=0.02,
180
+ classifier_dropout=0.0,
181
+ scale_embedding=False,
182
+ pad_token_id=1,
183
+ bos_token_id=0,
184
+ eos_token_id=2,
185
+ forced_eos_token_id=2,
186
+ _attn_implementation="eager",
187
+ hidden_size=1024,
188
+ use_parallel=False,
189
+ parallel_step=2,
190
+ is_export=False,
191
+ **kwargs,
192
+ ):
193
+ self.vocab_size = vocab_size
194
+ self.hidden_size = hidden_size
195
+ self.max_position_embeddings = max_position_embeddings
196
+ self.d_model = d_model
197
+ self.encoder_ffn_dim = encoder_ffn_dim
198
+ self.encoder_layers = encoder_layers
199
+ self.encoder_attention_heads = encoder_attention_heads
200
+ self.decoder_ffn_dim = decoder_ffn_dim
201
+ self.decoder_layers = decoder_layers
202
+ self.decoder_attention_heads = decoder_attention_heads
203
+ self.dropout = dropout
204
+ self.output_hidden_states = output_hidden_states
205
+ self.use_return_dict = use_return_dict
206
+ self.attention_dropout = attention_dropout
207
+ self.activation_dropout = activation_dropout
208
+ self.activation_function = activation_function
209
+ self.init_std = init_std
210
+ self.encoder_layerdrop = encoder_layerdrop
211
+ self.decoder_layerdrop = decoder_layerdrop
212
+ self.classifier_dropout = classifier_dropout
213
+ self.use_cache = use_cache
214
+ self.num_hidden_layers = encoder_layers
215
+ self.scale_embedding = (
216
+ scale_embedding # scale factor will be sqrt(d_model) if True
217
+ )
218
+ self.pad_token_id = pad_token_id
219
+ self.bos_token_id = bos_token_id
220
+ self.eos_token_id = eos_token_id
221
+ self.is_encoder_decoder = is_encoder_decoder
222
+ self.forced_eos_token_id = forced_eos_token_id
223
+ self._attn_implementation = _attn_implementation
224
+ self.use_parallel = use_parallel
225
+ self.parallel_step = parallel_step
226
+ self.is_export = is_export
227
+ super().__init__()
228
+
229
+
230
+ @dataclass
231
+ class AttentionMaskConverter:
232
+ """
233
+ A utility class for converting attention masks used in transformer models.
234
+
235
+ This class handles the conversion of attention masks based on whether the
236
+ attention mechanism is causal (i.e., preventing information flow from future
237
+ tokens to past tokens) and whether a sliding window approach is used.
238
+
239
+ Attributes:
240
+ is_causal (bool): Indicates if the attention mechanism is causal.
241
+ sliding_window (Optional[int]): Specifies the size of the sliding window
242
+ for local attention, if applicable.
243
+
244
+ Args:
245
+ is_causal (bool): Determines if the attention mask should enforce causality.
246
+ sliding_window (Optional[int], optional): The size of the sliding window
247
+ for local attention. Default is None.
248
+ """
249
+
250
+ is_causal: bool
251
+ sliding_window: int
252
+
253
+ def __init__(self, is_causal: bool, sliding_window=None):
254
+ self.is_causal = is_causal
255
+ self.sliding_window = sliding_window
256
+
257
+ if self.sliding_window is not None and self.sliding_window <= 0:
258
+ raise ValueError(
259
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
260
+ )
261
+
262
+ @staticmethod
263
+ def _make_causal_mask(
264
+ input_ids_shape,
265
+ dtype,
266
+ past_key_values_length=0,
267
+ sliding_window=None,
268
+ is_export=False,
269
+ ):
270
+ bsz, tgt_len = input_ids_shape
271
+ if is_export:
272
+ mask = torch.full(
273
+ [tgt_len, tgt_len], fill_value=torch.finfo(dtype).min, dtype=torch.float64
274
+ )
275
+ else:
276
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
277
+ mask_cond = torch.arange(mask.shape[-1])
278
+ mask = mask.masked_fill_(
279
+ mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
280
+ )
281
+ return mask[None, None, :, :].expand(
282
+ [bsz, 1, tgt_len, tgt_len + past_key_values_length]
283
+ )
284
+
285
+ def to_4d_export(
286
+ self,
287
+ attention_mask_2d,
288
+ query_length,
289
+ dtype,
290
+ key_value_length,
291
+ is_export=False,
292
+ ):
293
+ input_shape = (attention_mask_2d.shape[0], query_length)
294
+ expanded_attn_mask = self._expand_mask(
295
+ attention_mask_2d, dtype, tgt_len=input_shape[-1]
296
+ )
297
+ expanded_4d_mask = expanded_attn_mask
298
+
299
+ return expanded_4d_mask
300
+
301
+ def to_4d(
302
+ self,
303
+ attention_mask_2d,
304
+ query_length,
305
+ dtype,
306
+ key_value_length,
307
+ is_export=False,
308
+ ):
309
+
310
+ input_shape = (attention_mask_2d.shape[0], query_length)
311
+ causal_4d_mask = None
312
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
313
+ if key_value_length is None:
314
+ raise ValueError(
315
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
316
+ )
317
+
318
+ past_key_values_length = key_value_length - query_length
319
+
320
+ causal_4d_mask = self._make_causal_mask(
321
+ input_shape,
322
+ dtype,
323
+ past_key_values_length=past_key_values_length,
324
+ sliding_window=self.sliding_window,
325
+ is_export=is_export,
326
+ )
327
+ elif self.sliding_window is not None:
328
+ raise NotImplementedError(
329
+ "Sliding window is currently only implemented for causal masking"
330
+ )
331
+
332
+ expanded_attn_mask = self._expand_mask(
333
+ attention_mask_2d, dtype, tgt_len=input_shape[-1]
334
+ )
335
+
336
+ if causal_4d_mask is not None:
337
+ if is_export:
338
+ expanded_attn_mask = causal_4d_mask
339
+ return expanded_attn_mask
340
+ else:
341
+ expanded_attn_mask = causal_4d_mask.masked_fill_(
342
+ expanded_attn_mask.to(torch.bool), torch.finfo(dtype).min
343
+ )
344
+
345
+ expanded_4d_mask = expanded_attn_mask
346
+
347
+ return expanded_4d_mask
348
+
349
+ def _expand_mask(self, mask, dtype, tgt_len=None):
350
+ bsz, src_len = mask.shape
351
+ tgt_len = tgt_len if tgt_len is not None else src_len
352
+ expanded_mask = (
353
+ mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).to(dtype)
354
+ )
355
+ inverted_mask = 1.0 - expanded_mask
356
+ return inverted_mask.masked_fill_(
357
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
358
+ )
359
+
360
+
361
+ def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
362
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
363
+
364
+
365
+ def _prepare_4d_causal_attention_mask_export(
366
+ attention_mask,
367
+ input_shape,
368
+ inputs_embeds,
369
+ past_key_values_length,
370
+ sliding_window=None,
371
+ is_export=False,
372
+ ):
373
+ attn_mask_converter = AttentionMaskConverter(
374
+ is_causal=True, sliding_window=sliding_window
375
+ )
376
+ key_value_length = input_shape[-1] + past_key_values_length
377
+
378
+ shape = attention_mask.shape
379
+ len_shape = len(shape)
380
+
381
+ attention_mask = attn_mask_converter.to_4d_export(
382
+ attention_mask,
383
+ input_shape[-1],
384
+ key_value_length=key_value_length,
385
+ dtype=inputs_embeds.dtype,
386
+ is_export=is_export,
387
+ )
388
+ return attention_mask
389
+
390
+
391
+ def _prepare_4d_causal_attention_mask(
392
+ attention_mask,
393
+ input_shape,
394
+ inputs_embeds,
395
+ past_key_values_length,
396
+ sliding_window=None,
397
+ is_export=False,
398
+ ):
399
+ attn_mask_converter = AttentionMaskConverter(
400
+ is_causal=True, sliding_window=sliding_window
401
+ )
402
+ key_value_length = input_shape[-1] + past_key_values_length
403
+
404
+ shape = attention_mask.shape
405
+ len_shape = len(shape)
406
+ if (attention_mask is not None) and (len_shape == 2):
407
+ attention_mask = attn_mask_converter.to_4d(
408
+ attention_mask,
409
+ input_shape[-1],
410
+ key_value_length=key_value_length,
411
+ dtype=inputs_embeds.dtype,
412
+ is_export=is_export,
413
+ )
414
+
415
+ return attention_mask
416
+ elif attention_mask is not None and len(attention_mask.shape) == 4:
417
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
418
+ if tuple(attention_mask.shape) != expected_shape:
419
+ raise ValueError(
420
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
421
+ )
422
+ else:
423
+ inverted_mask = 1.0 - attention_mask
424
+ attention_mask = inverted_mask.masked_fill_(
425
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
426
+ )
427
+ else:
428
+ attention_mask = attn_mask_converter.to_causal_4d(
429
+ input_shape[0],
430
+ input_shape[-1],
431
+ key_value_length,
432
+ dtype=inputs_embeds.dtype,
433
+ )
434
+
435
+ return attention_mask
436
+
437
+
438
+ class MBartLearnedPositionalEmbedding(nn.Embedding):
439
+ """
440
+ This module learns positional embeddings up to a fixed maximum size.
441
+ """
442
+
443
+ def __init__(self, num_embeddings, embedding_dim):
444
+ self.offset = 2
445
+ super().__init__(num_embeddings + self.offset, embedding_dim)
446
+ self.device = torch.device(get_device())
447
+
448
+ def forward(self, input_ids, past_key_values_length=0):
449
+ """`input_ids' shape is expected to be [bsz x seqlen]."""
450
+ bsz, seq_len = input_ids.shape[:2]
451
+ positions = torch.arange(
452
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.int64
453
+ ).expand([bsz, -1]).to(self.device)
454
+ return nn.Embedding.forward(self, positions + self.offset)
455
+
456
+
457
+ class MBartPreTrainedModel(nn.Module):
458
+ base_model_prefix = "model"
459
+ supports_gradient_checkpointing = True
460
+ _no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
461
+ _supports_flash_attn_2 = True
462
+
463
+ def __init__(self, config):
464
+ super().__init__()
465
+ self.config = config
466
+
467
+ def _initialize_weights(self, module):
468
+ """
469
+ Initialize the weights if they are not already initialized.
470
+ """
471
+ if getattr(module, "_is_hf_initialized", False):
472
+ return
473
+ self._init_weights(module)
474
+
475
+ def post_init(self):
476
+ self.apply(self._initialize_weights)
477
+
478
+ def _init_weights(self, module):
479
+ std = self.config.init_std
480
+ if isinstance(module, nn.Linear):
481
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
482
+ if module.bias is not None:
483
+ torch.nn.init.constant_(module.bias, val=0.0)
484
+ elif isinstance(module, nn.Embedding):
485
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
486
+ if module.padding_idx is not None:
487
+ torch.nn.init.constant_(module.weight[module.padding_idx], val=0.0)
488
+
489
+ @property
490
+ def dummy_inputs(self):
491
+ pad_token = self.config.pad_token_id
492
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
493
+ dummy_inputs = {
494
+ "attention_mask": input_ids.ne(pad_token),
495
+ "input_ids": input_ids,
496
+ }
497
+ return dummy_inputs
498
+
499
+
500
+ class MBartAttention(nn.Module):
501
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
502
+
503
+ def __init__(
504
+ self,
505
+ embed_dim,
506
+ num_heads,
507
+ dropout: float = 0.0,
508
+ is_decoder: bool = False,
509
+ bias: bool = True,
510
+ is_causal: bool = False,
511
+ config=None,
512
+ ):
513
+ super().__init__()
514
+ self.embed_dim = embed_dim
515
+ self.num_heads = num_heads
516
+ self.dropout = dropout
517
+ self.head_dim = embed_dim // num_heads
518
+ self.config = config
519
+
520
+ if (self.head_dim * num_heads) != self.embed_dim:
521
+ raise ValueError(
522
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
523
+ f" and `num_heads`: {num_heads})."
524
+ )
525
+ self.scaling = self.head_dim ** -0.5
526
+ self.is_decoder = is_decoder
527
+ self.is_causal = is_causal
528
+
529
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
530
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
531
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
532
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
533
+
534
+ def _shape(self, tensor, seq_len, bsz):
535
+ return tensor.reshape([bsz, seq_len, self.num_heads, self.head_dim]).permute(
536
+ 0, 2, 1, 3
537
+ )
538
+
539
+ def forward(
540
+ self,
541
+ hidden_states,
542
+ key_value_states=None,
543
+ past_key_value=None,
544
+ attention_mask=None,
545
+ layer_head_mask=None,
546
+ output_attentions=False,
547
+ ):
548
+
549
+ is_cross_attention = key_value_states is not None
550
+
551
+ bsz, tgt_len, _ = hidden_states.shape
552
+ query_states = self.q_proj(hidden_states) * self.scaling
553
+ if (
554
+ is_cross_attention
555
+ and past_key_value is not None
556
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
557
+ ):
558
+ key_states = past_key_value[0]
559
+ value_states = past_key_value[1]
560
+ elif is_cross_attention:
561
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
562
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
563
+ elif past_key_value is not None:
564
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
565
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
566
+ key_states = torch.concat([past_key_value[0], key_states], dim=2)
567
+ value_states = torch.concat([past_key_value[1], value_states], dim=2)
568
+ else:
569
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
570
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
571
+
572
+ if self.is_decoder:
573
+ past_key_value = (key_states, value_states)
574
+
575
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
576
+ query_states = self._shape(query_states, tgt_len, bsz).reshape(proj_shape)
577
+ key_states = key_states.reshape(proj_shape)
578
+ value_states = value_states.reshape(proj_shape)
579
+
580
+ src_len = key_states.shape[1]
581
+ attn_weights = torch.bmm(query_states, key_states.permute([0, 2, 1]))
582
+
583
+ if attention_mask is not None:
584
+ attn_weights = (
585
+ attn_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
586
+ + attention_mask
587
+ )
588
+ attn_weights = attn_weights.reshape(
589
+ [bsz * self.num_heads, tgt_len, src_len]
590
+ )
591
+
592
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
593
+ if layer_head_mask is not None:
594
+ if tuple(layer_head_mask.shape) != (self.num_heads,):
595
+ raise ValueError(
596
+ f"Head mask for a single layer should be of shape {(self.num_heads,)}, but is"
597
+ f" {layer_head_mask.shape}"
598
+ )
599
+ attn_weights = layer_head_mask.reshape(
600
+ [1, -1, 1, 1]
601
+ ) * attn_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
602
+ attn_weights = attn_weights.reshape(
603
+ [bsz * self.num_heads, tgt_len, src_len]
604
+ )
605
+
606
+ if output_attentions:
607
+ attn_weights_reshaped = attn_weights.reshape(
608
+ [bsz, self.num_heads, tgt_len, src_len]
609
+ )
610
+ attn_weights = attn_weights_reshaped.reshape(
611
+ [bsz * self.num_heads, tgt_len, src_len]
612
+ )
613
+ else:
614
+ attn_weights_reshaped = None
615
+ attn_probs = nn.functional.dropout(
616
+ attn_weights, p=self.dropout, training=self.training
617
+ )
618
+ attn_output = torch.bmm(attn_probs, value_states)
619
+
620
+ attn_output = attn_output.reshape([bsz, self.num_heads, tgt_len, self.head_dim])
621
+ attn_output = attn_output.permute([0, 2, 1, 3])
622
+
623
+ attn_output = attn_output.reshape([bsz, tgt_len, self.embed_dim])
624
+ attn_output = self.out_proj(attn_output)
625
+ return attn_output, attn_weights_reshaped, past_key_value
626
+
627
+
628
+ MBART_ATTENTION_CLASSES = {
629
+ "eager": MBartAttention,
630
+ }
631
+
632
+
633
+ class MBartDecoderLayer(nn.Module):
634
+ def __init__(self, config):
635
+ super().__init__()
636
+ self.embed_dim = config.d_model
637
+ self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
638
+ embed_dim=self.embed_dim,
639
+ num_heads=config.decoder_attention_heads,
640
+ dropout=config.attention_dropout,
641
+ is_decoder=True,
642
+ is_causal=True,
643
+ config=config,
644
+ )
645
+ self.is_export = config.is_export
646
+ self.dropout = config.dropout
647
+ self.activation_fn = F.gelu
648
+ self.activation_dropout = config.activation_dropout
649
+
650
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
651
+ self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
652
+ self.embed_dim,
653
+ config.decoder_attention_heads,
654
+ dropout=config.attention_dropout,
655
+ is_decoder=True,
656
+ config=config,
657
+ )
658
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
659
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
660
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
661
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
662
+ self.device = torch.device(get_device())
663
+
664
+ def forward(
665
+ self,
666
+ hidden_states,
667
+ attention_mask=None,
668
+ encoder_hidden_states=None,
669
+ encoder_attention_mask=None,
670
+ layer_head_mask=None,
671
+ cross_attn_layer_head_mask=None,
672
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
673
+ output_attentions: Optional[bool] = False,
674
+ use_cache: Optional[bool] = True,
675
+ ) -> torch.Tensor:
676
+
677
+ residual = hidden_states
678
+ hidden_states = self.self_attn_layer_norm(hidden_states)
679
+
680
+ self_attn_past_key_value = None
681
+ if past_key_value is not None:
682
+ self_attn_past_key_value = tuple(
683
+ t.to(self.device) if isinstance(t, torch.Tensor) else t for t in past_key_value[:2]
684
+ )
685
+
686
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
687
+ hidden_states=hidden_states,
688
+ past_key_value=self_attn_past_key_value,
689
+ attention_mask=attention_mask,
690
+ layer_head_mask=layer_head_mask,
691
+ output_attentions=output_attentions,
692
+ )
693
+ hidden_states = nn.functional.dropout(
694
+ hidden_states, p=self.dropout, training=self.training
695
+ )
696
+ hidden_states = residual + hidden_states
697
+
698
+ cross_attn_present_key_value = None
699
+ cross_attn_weights = None
700
+ if encoder_hidden_states is not None:
701
+ residual = hidden_states
702
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
703
+ cross_attn_past_key_value = (
704
+ past_key_value[-2:] if past_key_value is not None else None
705
+ )
706
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = (
707
+ self.encoder_attn(
708
+ hidden_states=hidden_states,
709
+ key_value_states=encoder_hidden_states,
710
+ attention_mask=encoder_attention_mask,
711
+ layer_head_mask=cross_attn_layer_head_mask,
712
+ past_key_value=cross_attn_past_key_value,
713
+ output_attentions=output_attentions,
714
+ )
715
+ )
716
+ hidden_states = nn.functional.dropout(
717
+ hidden_states, p=self.dropout, training=self.training
718
+ )
719
+ hidden_states = residual + hidden_states
720
+
721
+ present_key_value = present_key_value + cross_attn_present_key_value
722
+
723
+ residual = hidden_states
724
+ hidden_states = self.final_layer_norm(hidden_states)
725
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
726
+ hidden_states = nn.functional.dropout(
727
+ hidden_states, p=self.activation_dropout, training=self.training
728
+ )
729
+ hidden_states = self.fc2(hidden_states)
730
+ hidden_states = nn.functional.dropout(
731
+ hidden_states, p=self.dropout, training=self.training
732
+ )
733
+ hidden_states = residual + hidden_states
734
+ outputs = (hidden_states,)
735
+
736
+ if output_attentions:
737
+ outputs += (self_attn_weights, cross_attn_weights)
738
+
739
+ if self.is_export:
740
+ outputs += (present_key_value,)
741
+ else:
742
+ if use_cache:
743
+ outputs += (present_key_value,)
744
+ return outputs
745
+
746
+
747
+ class MBartForCausalLM(MBartPreTrainedModel):
748
+ _tied_weights_keys = ["lm_head.weight"]
749
+
750
+ def __init__(self, config):
751
+ config = copy.deepcopy(config)
752
+ config.is_decoder = True
753
+ config.is_encoder_decoder = False
754
+ super().__init__(config)
755
+ self.model = MBartDecoderWrapper(config)
756
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
757
+
758
+ self.post_init()
759
+
760
+ def get_input_embeddings(self):
761
+ return self.model.decoder.embed_tokens
762
+
763
+ def set_input_embeddings(self, value):
764
+ self.model.decoder.embed_tokens = value
765
+
766
+ def get_output_embeddings(self):
767
+ return self.lm_head
768
+
769
+ def set_output_embeddings(self, new_embeddings):
770
+ self.lm_head = new_embeddings
771
+
772
+ def set_decoder(self, decoder):
773
+ self.model.decoder = decoder
774
+
775
+ def get_decoder(self):
776
+ return self.model.decoder
777
+
778
+ def forward(
779
+ self,
780
+ input_ids=None,
781
+ attention_mask=None,
782
+ encoder_hidden_states=None,
783
+ encoder_attention_mask=None,
784
+ head_mask=None,
785
+ cross_attn_head_mask=None,
786
+ past_key_values=None,
787
+ inputs_embeds=None,
788
+ labels=None,
789
+ use_cache=None,
790
+ output_attentions=None,
791
+ output_hidden_states=None,
792
+ return_dict=None,
793
+ ):
794
+
795
+ output_attentions = (
796
+ output_attentions
797
+ if output_attentions is not None
798
+ else self.config.output_attentions
799
+ )
800
+ output_hidden_states = (
801
+ output_hidden_states
802
+ if output_hidden_states is not None
803
+ else self.config.output_hidden_states
804
+ )
805
+ return_dict = (
806
+ return_dict if return_dict is not None else self.config.use_return_dict
807
+ )
808
+
809
+ outputs = self.model.decoder(
810
+ input_ids=input_ids,
811
+ attention_mask=attention_mask,
812
+ encoder_hidden_states=encoder_hidden_states,
813
+ encoder_attention_mask=encoder_attention_mask,
814
+ head_mask=head_mask,
815
+ cross_attn_head_mask=cross_attn_head_mask,
816
+ past_key_values=past_key_values,
817
+ inputs_embeds=inputs_embeds,
818
+ use_cache=use_cache,
819
+ output_attentions=output_attentions,
820
+ output_hidden_states=output_hidden_states,
821
+ return_dict=return_dict,
822
+ )
823
+
824
+ logits = self.lm_head(outputs[0])
825
+
826
+ loss = None
827
+ if labels is not None:
828
+ labels = labels
829
+ loss_fct = CrossEntropyLoss()
830
+ loss = loss_fct(
831
+ logits.reshape([-1, self.config.vocab_size]), labels.reshape([-1])
832
+ )
833
+
834
+ if not return_dict:
835
+ output = (logits,) + outputs[1:]
836
+ return (loss,) + output if loss is not None else output
837
+
838
+ return CausalLMOutputWithCrossAttentions(
839
+ loss=loss,
840
+ logits=logits,
841
+ past_key_values=outputs.past_key_values,
842
+ hidden_states=outputs.hidden_states,
843
+ attentions=outputs.attentions,
844
+ cross_attentions=outputs.cross_attentions,
845
+ )
846
+
847
+ def prepare_inputs_for_generation(
848
+ self,
849
+ input_ids,
850
+ past_key_values=None,
851
+ attention_mask=None,
852
+ use_cache=None,
853
+ **kwargs,
854
+ ):
855
+ if attention_mask is None:
856
+ attention_mask = input_ids.new_ones(input_ids.shape)
857
+
858
+ if past_key_values:
859
+ past_length = past_key_values[0][0].shape[2]
860
+
861
+ if input_ids.shape[1] > past_length:
862
+ remove_prefix_length = past_length
863
+ else:
864
+ remove_prefix_length = input_ids.shape[1] - 1
865
+
866
+ input_ids = input_ids[:, remove_prefix_length:]
867
+ return {
868
+ "input_ids": input_ids,
869
+ "attention_mask": attention_mask,
870
+ "past_key_values": past_key_values,
871
+ "use_cache": use_cache,
872
+ }
873
+
874
+ @staticmethod
875
+ def _reorder_cache(past_key_values, beam_idx):
876
+ reordered_past = ()
877
+ for layer_past in past_key_values:
878
+ reordered_past += (
879
+ tuple(
880
+ past_state.index_select(0, beam_idx) for past_state in layer_past
881
+ ),
882
+ )
883
+ return reordered_past
884
+
885
+
886
+ class myLayerNorm(nn.LayerNorm):
887
+ """
888
+ Custom implementation of Layer Normalization, with additional options.
889
+
890
+ This class extends the standard LayerNorm to include optional features,
891
+ such as drop block regularization, which might be used for improving
892
+ model generalization.
893
+
894
+ Args:
895
+ num_channels (int): The number of features or channels in the input.
896
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.
897
+ affine (bool, optional): If True, this module has learnable affine parameters (gamma and beta). Default is True.
898
+ drop_block (optional): Additional regularization technique that might be applied. Default is None.
899
+
900
+ """
901
+
902
+ def __init__(
903
+ self,
904
+ num_channels,
905
+ eps=1e-5,
906
+ affine=True,
907
+ drop_block=None,
908
+ ):
909
+ super(nn.LayerNorm, self).__init__()
910
+ self._epsilon = eps
911
+ self.num_channels = num_channels
912
+ if affine:
913
+ self.weight = torch.nn.Parameter(torch.randn([num_channels]) * 0.01)
914
+ self.bias = torch.nn.Parameter(torch.randn([num_channels]) * 0.01)
915
+ torch.nn.init.ones_(self.weight)
916
+ torch.nn.init.zeros_(self.bias)
917
+
918
+ def forward(self, x):
919
+ x = F.layer_norm(
920
+ x,
921
+ [self.num_channels],
922
+ weight=self.weight,
923
+ bias=self.bias,
924
+ eps=self._epsilon,
925
+ )
926
+ return x
927
+
928
+
929
+ class MBartDecoder(MBartPreTrainedModel):
930
+ """
931
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]
932
+
933
+ Args:
934
+ config
935
+ embed_tokens (nn.Embedding): output embedding
936
+ """
937
+
938
+ def __init__(self, config, embed_tokens=None):
939
+ super().__init__(config)
940
+ self.dropout = config.dropout
941
+ self.layerdrop = config.decoder_layerdrop
942
+ self.padding_idx = config.pad_token_id
943
+ self.max_target_positions = config.max_position_embeddings
944
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
945
+
946
+ self.embed_tokens = nn.Embedding(
947
+ config.vocab_size, config.d_model, self.padding_idx
948
+ )
949
+
950
+ if embed_tokens is not None:
951
+ self.embed_tokens.weight = embed_tokens.weight
952
+
953
+ self.embed_positions = MBartLearnedPositionalEmbedding(
954
+ config.max_position_embeddings,
955
+ config.d_model,
956
+ )
957
+ self.layers = nn.ModuleList(
958
+ [MBartDecoderLayer(config) for _ in range(config.decoder_layers)]
959
+ )
960
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
961
+ self.layernorm_embedding = myLayerNorm(config.d_model, affine=True)
962
+ self.layer_norm = nn.LayerNorm(config.d_model)
963
+
964
+ self.gradient_checkpointing = False
965
+ # Initialize weights and apply final processing
966
+ self.post_init()
967
+ self.is_export = config.is_export
968
+
969
+ def get_input_embeddings(self):
970
+ return self.embed_tokens
971
+
972
+ def set_input_embeddings(self, value):
973
+ self.embed_tokens = value
974
+
975
+ def forward(
976
+ self,
977
+ input_ids=None,
978
+ attention_mask=None,
979
+ encoder_hidden_states=None,
980
+ encoder_attention_mask=None,
981
+ head_mask=None,
982
+ cross_attn_head_mask=None,
983
+ past_key_values=None,
984
+ inputs_embeds=None,
985
+ use_cache=None,
986
+ output_attentions=None,
987
+ output_hidden_states=None,
988
+ return_dict=None,
989
+ ):
990
+
991
+ output_attentions = (
992
+ output_attentions
993
+ if output_attentions is not None
994
+ else self.config.output_attentions
995
+ )
996
+ output_hidden_states = (
997
+ output_hidden_states
998
+ if output_hidden_states is not None
999
+ else self.config.output_hidden_states
1000
+ )
1001
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1002
+ return_dict = (
1003
+ return_dict if return_dict is not None else self.config.use_return_dict
1004
+ )
1005
+
1006
+ if input_ids is not None and inputs_embeds is not None:
1007
+ raise ValueError(
1008
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
1009
+ )
1010
+ elif input_ids is not None:
1011
+ input = input_ids
1012
+ input_shape = input.shape
1013
+ input_ids = input_ids.reshape([-1, input_shape[-1]])
1014
+ elif inputs_embeds is not None:
1015
+ input_shape = inputs_embeds.shape[:-1]
1016
+ input = inputs_embeds[:, :, -1]
1017
+ else:
1018
+ raise ValueError(
1019
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1020
+ )
1021
+
1022
+ past_key_values_length = (
1023
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
1024
+ )
1025
+
1026
+ if inputs_embeds is None:
1027
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1028
+
1029
+ if self._use_flash_attention_2:
1030
+ attention_mask = (
1031
+ attention_mask
1032
+ if (attention_mask is not None and 0 in attention_mask)
1033
+ else None
1034
+ )
1035
+ else:
1036
+ attention_mask = _prepare_4d_causal_attention_mask(
1037
+ attention_mask,
1038
+ input_shape,
1039
+ inputs_embeds,
1040
+ past_key_values_length,
1041
+ is_export=self.is_export,
1042
+ )
1043
+
1044
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1045
+ if self._use_flash_attention_2:
1046
+ encoder_attention_mask = (
1047
+ encoder_attention_mask if 0 in encoder_attention_mask else None
1048
+ )
1049
+ else:
1050
+ encoder_attention_mask = _prepare_4d_attention_mask(
1051
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1052
+ )
1053
+
1054
+ # embed positions
1055
+ positions = self.embed_positions(input, past_key_values_length)
1056
+
1057
+ hidden_states = inputs_embeds + positions
1058
+ hidden_states = self.layernorm_embedding(hidden_states)
1059
+
1060
+ hidden_states = nn.functional.dropout(
1061
+ hidden_states, p=self.dropout, training=self.training
1062
+ )
1063
+
1064
+ if self.gradient_checkpointing and self.training:
1065
+ if use_cache:
1066
+ print(
1067
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
1068
+ )
1069
+ use_cache = False
1070
+
1071
+ all_hidden_states = () if output_hidden_states else None
1072
+ all_self_attns = () if output_attentions else None
1073
+ all_cross_attentions = (
1074
+ () if (output_attentions and encoder_hidden_states is not None) else None
1075
+ )
1076
+ next_decoder_cache = () if use_cache else None
1077
+
1078
+ for attn_mask, mask_name in zip(
1079
+ [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
1080
+ ):
1081
+ if attn_mask is not None:
1082
+ if attn_mask.shape[0] != len(self.layers):
1083
+ raise ValueError(
1084
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
1085
+ f" {attn_mask.shape[0]}."
1086
+ )
1087
+
1088
+ for idx, decoder_layer in enumerate(self.layers):
1089
+ if output_hidden_states:
1090
+ all_hidden_states += (hidden_states,)
1091
+ if self.training:
1092
+ dropout_probability = torch.rand([])
1093
+ if dropout_probability < self.layerdrop:
1094
+ continue
1095
+
1096
+ past_key_value = (
1097
+ past_key_values[idx] if past_key_values is not None else None
1098
+ )
1099
+
1100
+ if self.gradient_checkpointing and self.training:
1101
+ layer_outputs = self._gradient_checkpointing_func(
1102
+ decoder_layer.__call__,
1103
+ hidden_states,
1104
+ attention_mask,
1105
+ encoder_hidden_states,
1106
+ encoder_attention_mask,
1107
+ head_mask[idx] if head_mask is not None else None,
1108
+ (
1109
+ cross_attn_head_mask[idx]
1110
+ if cross_attn_head_mask is not None
1111
+ else None
1112
+ ),
1113
+ None,
1114
+ output_attentions,
1115
+ use_cache,
1116
+ )
1117
+ else:
1118
+ layer_outputs = decoder_layer(
1119
+ hidden_states,
1120
+ attention_mask=attention_mask,
1121
+ encoder_hidden_states=encoder_hidden_states,
1122
+ encoder_attention_mask=encoder_attention_mask,
1123
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1124
+ cross_attn_layer_head_mask=(
1125
+ cross_attn_head_mask[idx]
1126
+ if cross_attn_head_mask is not None
1127
+ else None
1128
+ ),
1129
+ past_key_value=past_key_value,
1130
+ output_attentions=output_attentions,
1131
+ use_cache=use_cache,
1132
+ )
1133
+ hidden_states = layer_outputs[0]
1134
+
1135
+ if use_cache:
1136
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1137
+
1138
+ if output_attentions:
1139
+ all_self_attns += (layer_outputs[1],)
1140
+
1141
+ if encoder_hidden_states is not None:
1142
+ all_cross_attentions += (layer_outputs[2],)
1143
+
1144
+ hidden_states = self.layer_norm(hidden_states)
1145
+
1146
+ if output_hidden_states:
1147
+ all_hidden_states += (hidden_states,)
1148
+
1149
+ next_cache = next_decoder_cache if use_cache else None
1150
+ if not return_dict:
1151
+ return tuple(
1152
+ v
1153
+ for v in [
1154
+ hidden_states,
1155
+ next_cache,
1156
+ all_hidden_states,
1157
+ all_self_attns,
1158
+ all_cross_attentions,
1159
+ ]
1160
+ if v is not None
1161
+ )
1162
+ return BaseModelOutputWithPastAndCrossAttentions(
1163
+ last_hidden_state=hidden_states,
1164
+ past_key_values=next_cache,
1165
+ hidden_states=all_hidden_states,
1166
+ attentions=all_self_attns,
1167
+ cross_attentions=all_cross_attentions,
1168
+ )
1169
+
1170
+
1171
+ class MBartDecoderWrapper(MBartPreTrainedModel):
1172
+ """
1173
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
1174
+ used in combination with the [`EncoderDecoderModel`] framework.
1175
+ """
1176
+
1177
+ def __init__(self, config):
1178
+ super().__init__(config)
1179
+ self.decoder = MBartDecoder(config)
1180
+
1181
+ def forward(self, *args, **kwargs):
1182
+ return self.decoder(*args, **kwargs)
1183
+
1184
+
1185
+ def _in_projection(
1186
+ q: torch.Tensor,
1187
+ k: torch.Tensor,
1188
+ v: torch.Tensor,
1189
+ w_q: torch.Tensor,
1190
+ w_k: torch.Tensor,
1191
+ w_v: torch.Tensor,
1192
+ b_q: Optional[torch.Tensor] = None,
1193
+ b_k: Optional[torch.Tensor] = None,
1194
+ b_v: Optional[torch.Tensor] = None,
1195
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1196
+ Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1]
1197
+ assert w_q.shape == (
1198
+ Eq,
1199
+ Eq,
1200
+ ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
1201
+ assert w_k.shape == (
1202
+ Eq,
1203
+ Ek,
1204
+ ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
1205
+ assert w_v.shape == (
1206
+ Eq,
1207
+ Ev,
1208
+ ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
1209
+ assert b_q is None or b_q.shape == (
1210
+ Eq,
1211
+ ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
1212
+ assert b_k is None or b_k.shape == (
1213
+ Eq,
1214
+ ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
1215
+ assert b_v is None or b_v.shape == (
1216
+ Eq,
1217
+ ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
1218
+ return linear(q, w_q.T, b_q), linear(k, w_k.T, b_k), linear(v, w_v.T, b_v)
1219
+
1220
+
1221
+ def _scaled_dot_product_attention(
1222
+ q: torch.Tensor,
1223
+ k: torch.Tensor,
1224
+ v: torch.Tensor,
1225
+ attn_mask: Optional[torch.Tensor] = None,
1226
+ dropout_p: float = 0.0,
1227
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1228
+ B, Nt, E = q.shape
1229
+ q = q / math.sqrt(E)
1230
+ attn = torch.bmm(q, k.permute([0, 2, 1]))
1231
+ if attn_mask is not None:
1232
+ attn += attn_mask
1233
+ attn = F.softmax(attn, dim=-1)
1234
+ if dropout_p > 0.0:
1235
+ attn = F.dropout(attn, p=dropout_p)
1236
+ output = torch.bmm(attn, v)
1237
+ return output, attn
1238
+
1239
+
1240
+ def linear(x, w, b, is_transpose):
1241
+ if is_transpose:
1242
+ w = w.T
1243
+ if b is not None:
1244
+ return torch.matmul(x, w) + b
1245
+ else:
1246
+ return torch.matmul(x, w)
1247
+
1248
+
1249
+ def _in_projection_packed(
1250
+ q: Tensor,
1251
+ k: Tensor,
1252
+ v: Tensor,
1253
+ w: Tensor,
1254
+ b: Optional[Tensor] = None,
1255
+ is_export=False,
1256
+ ) -> List[Tensor]:
1257
+ E = q.shape[-1]
1258
+ if k is v:
1259
+ if q is k:
1260
+ proj = linear(q, w, b, is_transpose=True)
1261
+ if is_export:
1262
+ B, D, L = proj.shape
1263
+ proj = proj.reshape([B, D, 3, E])
1264
+ proj = (
1265
+ proj.unsqueeze(0)
1266
+ .permute([3, 1, 2, 0, 4])
1267
+ .squeeze(-2)
1268
+ .contiguous()
1269
+ )
1270
+ else:
1271
+ proj = (
1272
+ proj.unflatten(-1, (3, E))
1273
+ .unsqueeze(0)
1274
+ .permute([3, 1, 2, 0, 4])
1275
+ .squeeze(-2)
1276
+ .contiguous()
1277
+ )
1278
+ return proj[0], proj[1], proj[2]
1279
+ else:
1280
+ w_q, w_k, w_v = w.chunk(3)
1281
+ if b is None:
1282
+ b_q = b_k = b_v = None
1283
+ else:
1284
+ b_q, b_k, b_v = b.chunk(3)
1285
+ return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
1286
+
1287
+
1288
+ def multi_head_attention_forward(
1289
+ query: torch.Tensor,
1290
+ key: torch.Tensor,
1291
+ value: torch.Tensor,
1292
+ embed_dim_to_check: int,
1293
+ num_heads: int,
1294
+ in_proj_weight: torch.Tensor,
1295
+ in_proj_bias: Optional[torch.Tensor],
1296
+ bias_k: Optional[torch.Tensor],
1297
+ bias_v: Optional[torch.Tensor],
1298
+ add_zero_attn: bool,
1299
+ dropout_p: float,
1300
+ out_proj_weight: torch.Tensor,
1301
+ out_proj_bias: Optional[torch.Tensor],
1302
+ training: bool = True,
1303
+ key_padding_mask: Optional[torch.Tensor] = None,
1304
+ need_weights: bool = True,
1305
+ attn_mask: Optional[torch.Tensor] = None,
1306
+ use_separate_proj_weight: bool = False,
1307
+ q_proj_weight: Optional[torch.Tensor] = None,
1308
+ k_proj_weight: Optional[torch.Tensor] = None,
1309
+ v_proj_weight: Optional[torch.Tensor] = None,
1310
+ static_k: Optional[torch.Tensor] = None,
1311
+ static_v: Optional[torch.Tensor] = None,
1312
+ is_export=False,
1313
+ ):
1314
+ tgt_len, bsz, embed_dim = query.shape
1315
+ src_len, _, _ = key.shape
1316
+
1317
+ if isinstance(embed_dim, torch.Tensor):
1318
+ head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
1319
+ else:
1320
+ head_dim = embed_dim // num_heads
1321
+ q, k, v = _in_projection_packed(
1322
+ query, key, value, in_proj_weight, in_proj_bias, is_export
1323
+ )
1324
+
1325
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
1326
+ warnings.warn(
1327
+ "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
1328
+ )
1329
+ key_padding_mask = key_padding_mask.to(torch.bool)
1330
+
1331
+ if bias_k is not None and bias_v is not None: # False
1332
+ assert static_k is None, "bias cannot be added to static key."
1333
+ assert static_v is None, "bias cannot be added to static value."
1334
+ k = torch.concat([k, bias_k.repeat(1, bsz, 1)])
1335
+ v = torch.concat([v, bias_v.repeat(1, bsz, 1)])
1336
+ else:
1337
+ assert bias_k is None
1338
+ assert bias_v is None
1339
+
1340
+ q = q.reshape([tgt_len, bsz * num_heads, head_dim]).permute([1, 0, 2])
1341
+ if static_k is None: # True
1342
+ k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).permute([1, 0, 2])
1343
+ else:
1344
+ assert (
1345
+ static_k.shape[0] == bsz * num_heads
1346
+ ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.shape[0]}"
1347
+ assert (
1348
+ static_k.shape[2] == head_dim
1349
+ ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.shape[2]}"
1350
+ k = static_k
1351
+ if static_v is None: # True
1352
+ v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
1353
+ else:
1354
+ assert (
1355
+ static_v.shape[0] == bsz * num_heads
1356
+ ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.shape[0]}"
1357
+ assert (
1358
+ static_v.shape[2] == head_dim
1359
+ ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.shape[2]}"
1360
+ v = static_v
1361
+
1362
+ src_len = k.shape[1]
1363
+
1364
+ if not training:
1365
+ dropout_p = 0.0
1366
+
1367
+ attn_output, attn_output_weights = _scaled_dot_product_attention(
1368
+ q, k, v, attn_mask, dropout_p
1369
+ )
1370
+
1371
+ attn_output = attn_output.permute([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
1372
+ attn_output = linear(
1373
+ attn_output, out_proj_weight, out_proj_bias, is_transpose=False
1374
+ )
1375
+
1376
+ if need_weights:
1377
+ attn_output_weights = attn_output_weights.reshape(
1378
+ [bsz, num_heads, tgt_len, src_len]
1379
+ )
1380
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
1381
+ else:
1382
+ return attn_output, None
1383
+
1384
+
1385
+ class MyMultiheadAttention(nn.Module):
1386
+ """
1387
+ Custom implementation of a multi-head attention layer.
1388
+
1389
+ Attributes:
1390
+ __constants__ (list): List of constant attributes.
1391
+ bias_k (Optional[paddle.Tensor]): Optional tensor for key bias.
1392
+ bias_v (Optional[paddle.Tensor]): Optional tensor for value bias.
1393
+
1394
+ Args:
1395
+ embed_dim (int): Total dimension of the model. This is the size of the input feature vectors.
1396
+ num_heads (int): Number of parallel attention heads. The input dimension must be divisible by the number of heads.
1397
+ dropout (float, optional): Dropout probability on the attention weights. Default is 0.0.
1398
+ bias (bool, optional): If True, adds a learnable bias to the output. Default is True.
1399
+ add_bias_kv (bool, optional): If True, adds bias to the key and value sequences. Default is False.
1400
+ add_zero_attn (bool, optional): If True, adds a zero attention head. Default is False.
1401
+ kdim (int, optional): Total number of features for keys. If None, defaults to embed_dim.
1402
+ vdim (int, optional): Total number of features for values. If None, defaults to embed_dim.
1403
+ batch_first (bool, optional): If True, the input and output tensors are provided as (batch, seq, feature). Default is False.
1404
+ device (optional): The device on which the layer's parameters should be initialized. Default is None.
1405
+ dtype (optional): The data type for the parameters. Default is None.
1406
+ is_export (bool, optional): If True, the layer is set up for export, potentially changing behavior for compatibility. Default is False.
1407
+ """
1408
+
1409
+ __constants__ = ["batch_first"]
1410
+ bias_k: Optional[torch.Tensor]
1411
+ bias_v: Optional[torch.Tensor]
1412
+
1413
+ def __init__(
1414
+ self,
1415
+ embed_dim,
1416
+ num_heads,
1417
+ dropout=0.0,
1418
+ bias=True,
1419
+ add_bias_kv=False,
1420
+ add_zero_attn=False,
1421
+ kdim=None,
1422
+ vdim=None,
1423
+ batch_first=False,
1424
+ device=None,
1425
+ dtype=None,
1426
+ is_export=False,
1427
+ ) -> None:
1428
+ super(MyMultiheadAttention, self).__init__()
1429
+ self.embed_dim = embed_dim
1430
+ self.kdim = kdim if kdim is not None else embed_dim
1431
+ self.vdim = vdim if vdim is not None else embed_dim
1432
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
1433
+
1434
+ self.num_heads = num_heads
1435
+ self.dropout = dropout
1436
+ self.batch_first = batch_first
1437
+ self.head_dim = embed_dim // num_heads
1438
+ self.is_export = is_export
1439
+ assert (
1440
+ self.head_dim * num_heads == self.embed_dim
1441
+ ), "embed_dim must be divisible by num_heads"
1442
+
1443
+ if self._qkv_same_embed_dim is False:
1444
+ pass
1445
+ else:
1446
+ if dtype is None:
1447
+ dtype = torch.float32
1448
+ self.in_proj_weight = torch.nn.Parameter(torch.randn(3 * embed_dim, embed_dim) * 0.01)
1449
+ self.q_proj_weight = None
1450
+ self.k_proj_weight = None
1451
+ self.v_proj_weight = None
1452
+
1453
+ if bias:
1454
+ self.in_proj_bias = torch.nn.Parameter(torch.randn(3 * embed_dim, ) * 0.01)
1455
+ torch.nn.init.zeros_(self.in_proj_bias)
1456
+ else:
1457
+ self.in_proj_bias = None
1458
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
1459
+
1460
+ if add_bias_kv:
1461
+ pass
1462
+ else:
1463
+ self.bias_k = self.bias_v = None
1464
+
1465
+ self.add_zero_attn = add_zero_attn
1466
+
1467
+ self._reset_parameters()
1468
+
1469
+ def _reset_parameters(self):
1470
+
1471
+ if self._qkv_same_embed_dim:
1472
+ torch.nn.init.xavier_normal_(self.in_proj_weight)
1473
+ else:
1474
+ torch.nn.init.xavier_normal_(self.q_proj_weight)
1475
+ torch.nn.init.xavier_normal_(self.k_proj_weight)
1476
+ torch.nn.init.xavier_normal_(self.v_proj_weight)
1477
+
1478
+ if self.in_proj_bias is not None:
1479
+ torch.nn.init.zeros_(self.in_proj_bias)
1480
+ torch.nn.init.zeros_(self.out_proj.bias)
1481
+ if self.bias_k is not None:
1482
+ torch.nn.init.xavier_normal_(self.bias_k)
1483
+ if self.bias_v is not None:
1484
+ torch.nn.init.xavier_normal_(self.bias_v)
1485
+
1486
+ def forward(
1487
+ self,
1488
+ query: torch.Tensor,
1489
+ key: torch.Tensor,
1490
+ value: torch.Tensor,
1491
+ key_padding_mask: Optional[torch.Tensor] = None,
1492
+ need_weights: bool = True,
1493
+ attn_mask: Optional[torch.Tensor] = None,
1494
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1495
+
1496
+ attn_output, attn_output_weights = multi_head_attention_forward(
1497
+ query,
1498
+ key,
1499
+ value,
1500
+ self.embed_dim,
1501
+ self.num_heads,
1502
+ self.in_proj_weight,
1503
+ self.in_proj_bias,
1504
+ self.bias_k,
1505
+ self.bias_v,
1506
+ self.add_zero_attn,
1507
+ self.dropout,
1508
+ self.out_proj.weight,
1509
+ self.out_proj.bias,
1510
+ training=self.training,
1511
+ key_padding_mask=key_padding_mask,
1512
+ need_weights=need_weights,
1513
+ attn_mask=attn_mask,
1514
+ is_export=self.is_export,
1515
+ )
1516
+
1517
+ return attn_output, attn_output_weights
1518
+
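# --- Editor's note: illustrative usage sketch, not part of the mineru package. ---
# MyMultiheadAttention mirrors torch.nn.MultiheadAttention with batch_first=False,
# so inputs are assumed to be laid out as (seq_len, batch, embed_dim); the layer
# returns the attended output and, optionally, the head-averaged attention weights.
import torch

mha = MyMultiheadAttention(embed_dim=256, num_heads=8)
q = torch.randn(10, 2, 256)                      # (seq_len, batch, embed_dim), assumed layout
out, weights = mha(q, q, q, need_weights=True)   # expected: out (10, 2, 256), weights (2, 10, 10)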
1519
+
1520
+ class LogitsProcessorList(list):
1521
+ """
1522
+ A list of logits processors that can be applied sequentially.
1523
+
1524
+ Methods:
1525
+ __call__(input_ids, scores, **kwargs): Apply all processors to the given inputs.
1526
+ """
1527
+
1528
+ def __call__(self, input_ids, scores, **kwargs):
1529
+ for processor in self:
1530
+ function_args = inspect.signature(processor.__call__).parameters
1531
+ if len(function_args) > 2:
1532
+ if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
1533
+ raise ValueError(
1534
+ f"Make sure that all the required parameters: {list(function_args.keys())} for "
1535
+ f"{processor.__class__} are passed to the logits processor."
1536
+ )
1537
+ scores = processor(input_ids, scores, **kwargs)
1538
+ else:
1539
+ scores = processor(input_ids, scores)
1540
+ return scores
1541
+
1542
+
1543
+ class ForcedEOSTokenLogitsProcessor(object):
1544
+ """
1545
+ A processor that forces the generation of an end-of-sequence (EOS) token
1546
+ at a specified position in the sequence.
1547
+
1548
+ This is typically used in language generation tasks to ensure that the
1549
+ generated sequence ends properly when it reaches a certain length.
1550
+
1551
+ Args:
1552
+ max_length (int): The maximum length of the sequence. Forces EOS when this length is reached.
1553
+ eos_token_id (Union[int, List[int]]): The ID(s) of the EOS token(s) to be forced in the sequence.
1554
+ """
1555
+
1556
+ def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
1557
+ self.max_length = max_length
1558
+ if isinstance(eos_token_id, int):
1559
+ eos_token_id = [eos_token_id]
1560
+ self.eos_token_id = eos_token_id
1561
+
1562
+ def __call__(self, input_ids, scores):
1563
+ cur_len = input_ids.shape[-1]
1564
+ scores_processed = scores
1565
+ if cur_len == self.max_length - 1:
1566
+ scores_processed = torch.full_like(scores, -math.inf)
1567
+ scores_processed[:, self.eos_token_id] = 0
1568
+ return scores_processed
1569
+
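# --- Editor's note: illustrative sketch, not part of the mineru package. ---
# How the two classes above cooperate: once the sequence reaches max_length - 1,
# every logit except the EOS logit is pushed to -inf, so greedy argmax must emit EOS.
import torch

vocab_size, eos_id, max_len = 10, 2, 5
processors = LogitsProcessorList([ForcedEOSTokenLogitsProcessor(max_len, eos_id)])

input_ids = torch.zeros((1, max_len - 1), dtype=torch.int64)  # cur_len == max_len - 1
scores = torch.randn(1, vocab_size)
scores = processors(input_ids, scores)
assert scores.argmax(dim=-1).item() == eos_id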
1570
+
1571
+ @dataclass
1572
+ class CausalLMOutputWithCrossAttentions(ModelOutput):
1573
+ loss = None
1574
+ logits = None
1575
+ past_key_values = None
1576
+ hidden_states = None
1577
+ attentions = None
1578
+ cross_attentions = None
1579
+
1580
+ def __init__(self, *args, **kwargs):
1581
+ super().__init__(*args, **kwargs)
1582
+
1583
+
1584
+ @dataclass
1585
+ class CausalLMOutputWithCrossAttentionsAndCounting(ModelOutput):
1586
+ """
1587
+ Causal language model (autoregressive) output that additionally carries the token-counting prediction.
1588
+ """
1589
+
1590
+ logits = None
1591
+ counting = None
1592
+ past_key_values = None
1593
+ hidden_states = None
1594
+ attentions = None
1595
+ cross_attentions = None
1596
+
1597
+ def __init__(self, *args, **kwargs):
1598
+ super().__init__(*args, **kwargs)
1599
+
1600
+
1601
+ class CustomMBartDecoder(MBartDecoder):
1602
+ """
1603
+ A custom MBartDecoder that includes additional processing layers.
1604
+
1605
+ This class extends the MBartDecoder by adding a customizable neural network
1606
+ component called `counting_context_weight`, which applies a series of linear
1607
+ transformations followed by ReLU activations. This can be used to modify or
1608
+ enhance the decoder's behavior for specific tasks.
1609
+
1610
+ Args:
1611
+ config: The configuration object containing model parameters.
1612
+ """
1613
+
1614
+ def __init__(self, config):
1615
+ super().__init__(config)
1616
+ hidden_size = config.d_model
1617
+ self.is_export = config.is_export
1618
+ self.counting_context_weight = nn.Sequential(
1619
+ nn.Linear(config.vocab_size, hidden_size),
1620
+ nn.ReLU(),
1621
+ nn.Linear(hidden_size, hidden_size),
1622
+ nn.ReLU(),
1623
+ nn.Linear(hidden_size, config.d_model),
1624
+ )
1625
+
1626
+ def forward(
1627
+ self,
1628
+ input_ids=None,
1629
+ attention_mask=None,
1630
+ count_pred=None,
1631
+ encoder_hidden_states=None,
1632
+ encoder_attention_mask=None,
1633
+ head_mask=None,
1634
+ cross_attn_head_mask=None,
1635
+ past_key_values=None,
1636
+ inputs_embeds=None,
1637
+ use_cache=None,
1638
+ output_attentions=None,
1639
+ output_hidden_states=None,
1640
+ return_dict=None,
1641
+ ):
1642
+ self.is_export = not self.training
1643
+ output_attentions = (
1644
+ output_attentions
1645
+ if output_attentions is not None
1646
+ else self.config.output_attentions
1647
+ )
1648
+ output_hidden_states = (
1649
+ output_hidden_states
1650
+ if output_hidden_states is not None
1651
+ else self.config.output_hidden_states
1652
+ )
1653
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1654
+ return_dict = (
1655
+ return_dict if return_dict is not None else self.config.use_return_dict
1656
+ )
1657
+
1658
+ if input_ids is not None and inputs_embeds is not None:
1659
+ raise ValueError(
1660
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
1661
+ )
1662
+ elif input_ids is not None:
1663
+ input = input_ids
1664
+ input_shape = input.shape
1665
+ input_ids = input_ids.reshape([-1, input_shape[-1]])
1666
+ elif inputs_embeds is not None:
1667
+ input_shape = inputs_embeds.shape[:-1]
1668
+ input = inputs_embeds[:, :, -1]
1669
+ else:
1670
+ raise ValueError(
1671
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1672
+ )
1673
+
1674
+ past_key_values_length = (
1675
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
1676
+ )
1677
+
1678
+ if inputs_embeds is None:
1679
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1680
+
1681
+ if self._use_flash_attention_2:
1682
+ attention_mask = (
1683
+ attention_mask
1684
+ if (attention_mask is not None and 0 in attention_mask)
1685
+ else None
1686
+ )
1687
+ else:
1688
+ if self.is_export:
1689
+ attention_mask = _prepare_4d_causal_attention_mask_export(
1690
+ attention_mask,
1691
+ input_shape,
1692
+ inputs_embeds,
1693
+ past_key_values_length,
1694
+ is_export=self.is_export,
1695
+ ).to(torch.float32)
1696
+ else:
1697
+ attention_mask = _prepare_4d_causal_attention_mask(
1698
+ attention_mask,
1699
+ input_shape,
1700
+ inputs_embeds,
1701
+ past_key_values_length,
1702
+ is_export=self.is_export,
1703
+ )
1704
+
1705
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1706
+ if self._use_flash_attention_2:
1707
+ encoder_attention_mask = (
1708
+ encoder_attention_mask if 0 in encoder_attention_mask else None
1709
+ )
1710
+ else:
1711
+ encoder_attention_mask = _prepare_4d_attention_mask(
1712
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1713
+ )
1714
+
1715
+ # embed positions
1716
+ positions = self.embed_positions(input, past_key_values_length)
1717
+
1718
+ hidden_states = inputs_embeds + positions
1719
+
1720
+ # Add the counting context weight to hidden_states
1721
+ if count_pred is not None:
1722
+ count_context_weight = self.counting_context_weight(count_pred)
1723
+ hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
1724
+
1725
+ hidden_states = self.layernorm_embedding(hidden_states)
1726
+ hidden_states = nn.functional.dropout(
1727
+ hidden_states, p=self.dropout, training=self.training
1728
+ )
1729
+
1730
+ if self.gradient_checkpointing and self.training:
1731
+ if use_cache:
1732
+ print(
1733
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1734
+ )
1735
+ use_cache = False
1736
+
1737
+ # decoder layers
1738
+ all_hidden_states = () if output_hidden_states else None
1739
+ all_self_attns = () if output_attentions else None
1740
+ all_cross_attentions = (
1741
+ () if (output_attentions and encoder_hidden_states is not None) else None
1742
+ )
1743
+ next_decoder_cache = () if use_cache else None
1744
+
1745
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1746
+ for attn_mask, mask_name in zip(
1747
+ [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
1748
+ ):
1749
+ if attn_mask is not None:
1750
+ if attn_mask.size()[0] != len(self.layers):
1751
+ raise ValueError(
1752
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
1753
+ f" {attn_mask.size()[0]}."
1754
+ )
1755
+
1756
+ for idx, decoder_layer in enumerate(self.layers):
1757
+ if output_hidden_states:
1758
+ all_hidden_states += (hidden_states,)
1759
+ if self.training:
1760
+ dropout_probability = torch.rand([])
1761
+ if dropout_probability < self.layerdrop:
1762
+ continue
1763
+
1764
+ past_key_value = (
1765
+ past_key_values[idx] if past_key_values is not None else None
1766
+ )
1767
+
1768
+ if self.gradient_checkpointing and self.training:
1769
+ layer_outputs = self._gradient_checkpointing_func(
1770
+ decoder_layer.__call__,
1771
+ hidden_states,
1772
+ attention_mask,
1773
+ encoder_hidden_states,
1774
+ encoder_attention_mask,
1775
+ head_mask[idx] if head_mask is not None else None,
1776
+ (
1777
+ cross_attn_head_mask[idx]
1778
+ if cross_attn_head_mask is not None
1779
+ else None
1780
+ ),
1781
+ None,
1782
+ output_attentions,
1783
+ use_cache,
1784
+ )
1785
+ else:
1786
+ layer_outputs = decoder_layer(
1787
+ hidden_states,
1788
+ attention_mask=attention_mask,
1789
+ encoder_hidden_states=encoder_hidden_states,
1790
+ encoder_attention_mask=encoder_attention_mask,
1791
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1792
+ cross_attn_layer_head_mask=(
1793
+ cross_attn_head_mask[idx]
1794
+ if cross_attn_head_mask is not None
1795
+ else None
1796
+ ),
1797
+ past_key_value=past_key_value,
1798
+ output_attentions=output_attentions,
1799
+ use_cache=use_cache,
1800
+ )
1801
+ hidden_states = layer_outputs[0]
1802
+ if self.is_export:
1803
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1804
+ else:
1805
+ if use_cache:
1806
+ next_decoder_cache += (
1807
+ layer_outputs[3 if output_attentions else 1],
1808
+ )
1809
+
1810
+ if output_attentions:
1811
+ all_self_attns += (layer_outputs[1],)
1812
+
1813
+ if encoder_hidden_states is not None:
1814
+ all_cross_attentions += (layer_outputs[2],)
1815
+
1816
+ hidden_states = self.layer_norm(hidden_states)
1817
+
1818
+ if output_hidden_states:
1819
+ all_hidden_states += (hidden_states,)
1820
+ if self.is_export:
1821
+ next_cache = next_decoder_cache
1822
+ else:
1823
+ next_cache = next_decoder_cache if use_cache else None
1824
+ if not self.is_export:
1825
+ if not return_dict:
1826
+ return tuple(
1827
+ v
1828
+ for v in [
1829
+ hidden_states,
1830
+ next_cache,
1831
+ all_hidden_states,
1832
+ all_self_attns,
1833
+ all_cross_attentions,
1834
+ ]
1835
+ if v is not None
1836
+ )
1837
+ return BaseModelOutputWithPastAndCrossAttentions(
1838
+ last_hidden_state=hidden_states,
1839
+ past_key_values=next_cache,
1840
+ hidden_states=all_hidden_states,
1841
+ attentions=all_self_attns,
1842
+ cross_attentions=all_cross_attentions,
1843
+ )
1844
+
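# --- Editor's note: illustrative sketch, not part of the mineru package. ---
# The counting-context pathway above: a per-vocabulary count prediction is projected
# back to d_model and broadcast over every decoder position. The sizes below are
# illustrative assumptions, not values taken from the package.
import torch
import torch.nn as nn

vocab_size, d_model, batch, seq_len = 50000, 1024, 2, 7
counting_context_weight = nn.Sequential(
    nn.Linear(vocab_size, d_model), nn.ReLU(),
    nn.Linear(d_model, d_model), nn.ReLU(),
    nn.Linear(d_model, d_model),
)
count_pred = torch.randn(batch, vocab_size)               # output of the counting decoder
hidden_states = torch.randn(batch, seq_len, d_model)      # token embeddings + positions
hidden_states = hidden_states + 0.5 * counting_context_weight(count_pred).unsqueeze(1)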
1845
+
1846
+ class SelfAttentionBlock(nn.Module):
1847
+ """
1848
+ A self-attention block that implements multi-head self-attention
1849
+ followed by a feed-forward network, typically used in transformer architectures.
1850
+
1851
+ Args:
1852
+ embed_size (int): The size of the embedding vector.
1853
+ num_heads (int): The number of attention heads.
1854
+ is_export (bool): Flag indicating whether to configure the layer for export.
1855
+ """
1856
+
1857
+ def __init__(self, embed_size, num_heads, is_export):
1858
+ super(SelfAttentionBlock, self).__init__()
1859
+ self.self_attention = MyMultiheadAttention(
1860
+ embed_dim=embed_size, num_heads=num_heads, is_export=is_export
1861
+ )
1862
+ self.norm = nn.LayerNorm(embed_size)
1863
+
1864
+ def forward(self, x):
1865
+ attn_output, _ = self.self_attention(x, x, x)
1866
+ x = self.norm(attn_output + x)
1867
+ return x
1868
+
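# --- Editor's note: illustrative sketch, not part of the mineru package. ---
# SelfAttentionBlock is a shape-preserving residual block: self-attention over the
# input, an additive skip connection, then LayerNorm.
import torch

block = SelfAttentionBlock(embed_size=256, num_heads=8, is_export=False)
x = torch.randn(10, 2, 256)
y = block(x)          # same shape as x (the residual add requires it)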
1869
+
1870
+ class SeqCountingDecoder(nn.Module):
1871
+ """
1872
+ A custom sequence counting decoder that incorporates multi-head attention layers
1873
+ and feed-forward networks to process sequences, e.g. for LaTeX token counting.
1874
+
1875
+ Args:
1876
+ in_features (int): The number of input features.
1877
+ out_features (int): The number of output features.
1878
+ num_heads (int): The number of attention heads. Defaults to 8.
1879
+ num_layers (int): The number of attention layers. Defaults to 4.
1880
+ is_export (bool): Flag indicating whether to configure the layer for export.
1881
+ """
1882
+
1883
+ def __init__(
1884
+ self, in_features, out_features, num_heads=8, num_layers=4, is_export=False
1885
+ ):
1886
+ super(SeqCountingDecoder, self).__init__()
1887
+
1888
+ self.attention_blocks = nn.ModuleList(
1889
+ [
1890
+ SelfAttentionBlock(
1891
+ embed_size=in_features, num_heads=num_heads, is_export=is_export
1892
+ )
1893
+ for i in range(num_layers)
1894
+ ]
1895
+ )
1896
+ self.fc1 = nn.Linear(in_features, in_features // 2)
1897
+ self.relu = nn.ReLU()
1898
+ self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
1899
+ self.fc2 = nn.Linear(in_features // 2, out_features)
1900
+
1901
+ def forward(self, x):
1902
+ for block in self.attention_blocks:
1903
+ x = block(x)
1904
+ x = self.fc1(x)
1905
+ x = self.relu(x)
1906
+ x = x.permute([0, 2, 1])
1907
+ x = self.global_avg_pool(x)
1908
+ x = x.squeeze(-1)
1909
+ x = self.fc2(x)
1910
+ return x
1911
+
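# --- Editor's note: illustrative sketch, not part of the mineru package. ---
# SeqCountingDecoder reduces a sequence of encoder features to one count estimate per
# vocabulary entry: attention blocks, a projection, average pooling over the sequence
# axis (AdaptiveAvgPool1d expects (batch, channels, length)), then a final projection.
# Sizes are illustrative; the attention blocks are assumed to preserve the input shape.
import torch

counting_decoder = SeqCountingDecoder(in_features=1024, out_features=50000)
feats = torch.randn(2, 196, 1024)        # (batch, seq_len, in_features) encoder features
counts = counting_decoder(feats)         # (2, 50000)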
1912
+
1913
+ class CustomMBartForCausalLM(MBartForCausalLM):
1914
+ """
1915
+ Custom MBart model for causal language modeling with a custom decoder.
1916
+
1917
+ This class extends the MBartForCausalLM by replacing its decoder with a
1918
+ custom decoder, allowing for additional flexibility and features in the
1919
+ decoding process.
1920
+
1921
+ Args:
1922
+ config: The configuration object containing model parameters.
1923
+ length_aware (bool): A flag to enable or configure length-aware mechanisms.
1924
+ """
1925
+
1926
+ def __init__(self, config, length_aware=True):
1927
+ super().__init__(config)
1928
+ self.model.decoder = CustomMBartDecoder(config)
1929
+ self.counting_decoder = SeqCountingDecoder(
1930
+ config.d_model, config.vocab_size, is_export=config.is_export
1931
+ )
1932
+ self.length_aware = length_aware
1933
+
1934
+ def forward(
1935
+ self,
1936
+ input_ids=None,
1937
+ attention_mask=None,
1938
+ encoder_hidden_states=None,
1939
+ encoder_attention_mask=None,
1940
+ head_mask=None,
1941
+ cross_attn_head_mask=None,
1942
+ past_key_values=None,
1943
+ inputs_embeds=None,
1944
+ labels=None,
1945
+ use_cache=None,
1946
+ output_attentions=None,
1947
+ output_hidden_states=None,
1948
+ return_dict=None,
1949
+ count_gt=None,
1950
+ ):
1951
+ output_attentions = (
1952
+ output_attentions
1953
+ if output_attentions is not None
1954
+ else self.config.output_attentions
1955
+ )
1956
+ output_hidden_states = (
1957
+ output_hidden_states
1958
+ if output_hidden_states is not None
1959
+ else self.config.output_hidden_states
1960
+ )
1961
+ return_dict = (
1962
+ return_dict if return_dict is not None else self.config.use_return_dict
1963
+ )
1964
+
1965
+ if self.length_aware:
1966
+ count_pred = self.counting_decoder(encoder_hidden_states)
1967
+ else:
1968
+ count_pred = None
1969
+
1970
+ outputs = self.model.decoder(
1971
+ input_ids=input_ids,
1972
+ attention_mask=attention_mask,
1973
+ count_pred=count_pred,
1974
+ encoder_hidden_states=encoder_hidden_states,
1975
+ encoder_attention_mask=encoder_attention_mask,
1976
+ head_mask=head_mask,
1977
+ cross_attn_head_mask=cross_attn_head_mask,
1978
+ past_key_values=past_key_values,
1979
+ inputs_embeds=inputs_embeds,
1980
+ use_cache=use_cache,
1981
+ output_attentions=output_attentions,
1982
+ output_hidden_states=output_hidden_states,
1983
+ return_dict=return_dict,
1984
+ )
1985
+ logits = self.lm_head(outputs[0])
1986
+
1987
+ return CausalLMOutputWithCrossAttentionsAndCounting(
1988
+ logits=logits,
1989
+ counting=count_pred,
1990
+ past_key_values=outputs.past_key_values,
1991
+ hidden_states=outputs.hidden_states,
1992
+ attentions=outputs.attentions,
1993
+ cross_attentions=outputs.cross_attentions,
1994
+ )
1995
+
1996
+
1997
+ class UniMERNetHead(nn.Module):
1998
+ """Implementation of UniMERNetHead decoder.
1999
+
2000
+ Args:
2001
+ max_new_tokens (int): Maximum number of new tokens to generate.
2002
+ decoder_start_token_id (int): ID of the token that starts the decoding.
2003
+ temperature (float): Sampling temperature for generation.
2004
+ do_sample (bool): Whether to use sampling; if False, uses greedy decoding.
2005
+ top_p (float): Top-p (nucleus) sampling parameter.
2006
+ in_channels (int): Number of input channels/features.
2007
+ encoder_hidden_size (int): Hidden size of the encoder.
2008
+ decoder_hidden_size (int): Hidden size of the decoder.
2009
+ decoder_ffn_dim (int): Dimension of the decoder's feed-forward network.
2010
+ decoder_layers (int): Number of layers in the decoder.
2011
+ is_export (bool): Flag indicating if the model is being prepared for export.
2012
+ length_aware (bool): Flag to enable length-aware mechanisms.
2013
+ """
2014
+
2015
+ def __init__(
2016
+ self,
2017
+ max_new_tokens=1536,
2018
+ decoder_start_token_id=0,
2019
+ temperature=0.2,
2020
+ do_sample=False,
2021
+ top_p=0.95,
2022
+ in_channels=1024,
2023
+ encoder_hidden_size=1024,
2024
+ decoder_hidden_size=1024,
2025
+ decoder_ffn_dim=4096,
2026
+ decoder_layers=8,
2027
+ is_export=False,
2028
+ length_aware=True,
2029
+ ):
2030
+ super().__init__()
2031
+ mbart_config_dict = {
2032
+ "activation_dropout": 0.0,
2033
+ "activation_function": "gelu",
2034
+ "add_cross_attention": True,
2035
+ "add_final_layer_norm": True,
2036
+ "attention_dropout": 0.0,
2037
+ "bos_token_id": 0,
2038
+ "classifier_dropout": 0.0,
2039
+ "d_model": decoder_hidden_size,
2040
+ "decoder_attention_heads": 16,
2041
+ "decoder_ffn_dim": decoder_ffn_dim,
2042
+ "decoder_layerdrop": 0.0,
2043
+ "decoder_layers": decoder_layers,
2044
+ "dropout": 0.1,
2045
+ "encoder_attention_heads": 16,
2046
+ "encoder_ffn_dim": 4096,
2047
+ "encoder_layerdrop": 0.0,
2048
+ "encoder_layers": 12,
2049
+ "eos_token_id": 2,
2050
+ "forced_eos_token_id": 2,
2051
+ "init_std": 0.02,
2052
+ "is_decoder": True,
2053
+ "is_encoder_decoder": False,
2054
+ "output_hidden_states": False,
2055
+ "max_position_embeddings": max_new_tokens,
2056
+ "model_type": "mbart",
2057
+ "num_hidden_layers": 12,
2058
+ "pad_token_id": 1,
2059
+ "scale_embedding": True,
2060
+ "tie_word_embeddings": False,
2061
+ "transformers_version": "4.40.0",
2062
+ "use_cache": True,
2063
+ "use_return_dict": True,
2064
+ "vocab_size": 50000,
2065
+ "_attn_implementation": "eager",
2066
+ "hidden_size": decoder_hidden_size,
2067
+ "is_export": is_export,
2068
+ }
2069
+
2070
+ self.max_new_tokens = max_new_tokens
2071
+ self.decoder_start_token_id = decoder_start_token_id
2072
+ self.temperature = temperature
2073
+ self.do_sample = do_sample
2074
+ self.top_p = top_p
2075
+ self.max_seq_len = max_new_tokens
2076
+ self.config_decoder = MBartConfig(**mbart_config_dict)
2077
+ self.encoder_hidden_size = encoder_hidden_size
2078
+ self.is_export = self.config_decoder.is_export
2079
+ self.decoder = CustomMBartForCausalLM(
2080
+ self.config_decoder, length_aware=length_aware
2081
+ )
2082
+ if self.config_decoder.hidden_size != self.encoder_hidden_size:
2083
+ self.enc_to_dec_proj = nn.Linear(
2084
+ self.encoder_hidden_size, self.config_decoder.hidden_size
2085
+ )
2086
+ generation_config = {
2087
+ "max_length": 1537,
2088
+ "forced_eos_token_id": 2,
2089
+ }
2090
+ self.eos_token_id = generation_config["forced_eos_token_id"]
2091
+ self.pad_token_id = self.config_decoder.pad_token_id
2092
+ self.logits_processor = LogitsProcessorList()
2093
+ self.logits_processor.append(
2094
+ ForcedEOSTokenLogitsProcessor(
2095
+ generation_config["max_length"],
2096
+ generation_config["forced_eos_token_id"],
2097
+ )
2098
+ )
2099
+
2100
+ def _get_decoder_start_token_id(
2101
+ self, decoder_start_token_id=None, bos_token_id=None
2102
+ ) -> int:
2103
+ decoder_start_token_id = (
2104
+ decoder_start_token_id
2105
+ if decoder_start_token_id is not None
2106
+ else self.generation_config.decoder_start_token_id
2107
+ )
2108
+ bos_token_id = (
2109
+ bos_token_id
2110
+ if bos_token_id is not None
2111
+ else self.generation_config.bos_token_id
2112
+ )
2113
+ if decoder_start_token_id is not None:
2114
+ return decoder_start_token_id
2115
+ elif bos_token_id is not None:
2116
+ return bos_token_id
2117
+ raise ValueError(
2118
+ "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
2119
+ )
2120
+
2121
+ def _prepare_decoder_input_ids_for_generation(
2122
+ self,
2123
+ batch_size,
2124
+ model_kwargs,
2125
+ decoder_start_token_id=None,
2126
+ bos_token_id=None,
2127
+ ):
2128
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
2129
+ decoder_input_ids = model_kwargs.pop("decoder_input_ids")
2130
+ elif "input_ids" in model_kwargs:
2131
+ decoder_input_ids = model_kwargs.pop("input_ids")
2132
+ else:
2133
+ decoder_input_ids = None
2134
+
2135
+ decoder_start_token_id = self._get_decoder_start_token_id(
2136
+ decoder_start_token_id, bos_token_id
2137
+ )
2138
+
2139
+ if isinstance(decoder_start_token_id, list):
2140
+ if len(decoder_start_token_id) != batch_size:
2141
+ raise ValueError(
2142
+ f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
2143
+ )
2144
+ decoder_input_ids_start = torch.LongTensor(decoder_start_token_id)
2145
+ decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
2146
+ else:
2147
+ decoder_input_ids_start = (
2148
+ torch.ones(
2149
+ (batch_size, 1),
2150
+ dtype=torch.int64,
2151
+ )
2152
+ * decoder_start_token_id
2153
+ )
2154
+
2155
+ if decoder_input_ids is None:
2156
+ decoder_input_ids = decoder_input_ids_start
2157
+ elif (
2158
+ self.config.model_type == "vision-encoder-decoder"
2159
+ and "donut" in self.name_or_path.lower()
2160
+ ):
2161
+ pass
2162
+ elif self.config.model_type in ["whisper"]:
2163
+ pass
2164
+ elif (
2165
+ isinstance(decoder_start_token_id, int)
2166
+ and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
2167
+ ) or (
2168
+ isinstance(decoder_start_token_id, torch.Tensor)
2169
+ and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
2170
+ ):
2171
+ decoder_input_ids = torch.concat(
2172
+ [decoder_input_ids_start, decoder_input_ids], dim=-1
2173
+ )
2174
+ if "decoder_attention_mask" in model_kwargs:
2175
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
2176
+ decoder_attention_mask = torch.cat(
2177
+ (
2178
+ torch.ones_like(decoder_attention_mask)[:, :1],
2179
+ decoder_attention_mask,
2180
+ ),
2181
+ dim=-1,
2182
+ )
2183
+ model_kwargs["decoder_attention_mask"] = decoder_attention_mask
2184
+
2185
+ return decoder_input_ids, model_kwargs
2186
+
2187
+ def prepare_inputs_for_generation_mbart(
2188
+ self,
2189
+ input_ids,
2190
+ past_key_values=None,
2191
+ attention_mask=None,
2192
+ use_cache=None,
2193
+ **kwargs,
2194
+ ):
2195
+
2196
+ if attention_mask is None:
2197
+ attention_mask = torch.ones(input_ids.shape)
2198
+
2199
+ if past_key_values:
2200
+ past_length = past_key_values[0][0].shape[2]
2201
+
2202
+ if input_ids.shape[1] > past_length:
2203
+ remove_prefix_length = past_length
2204
+ else:
2205
+ remove_prefix_length = input_ids.shape[1] - 1
2206
+
2207
+ input_ids = input_ids[:, remove_prefix_length:]
2208
+ return {
2209
+ "input_ids": input_ids,
2210
+ "attention_mask": attention_mask,
2211
+ "past_key_values": past_key_values,
2212
+ "use_cache": use_cache,
2213
+ }
2214
+
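# Editor's note (illustration, not package code): with a KV cache the decoder only
# needs the newest token at each step. For example, if the cache already covers 5
# positions (past_length == 5) and input_ids has shape (1, 6), the slice
# input_ids[:, 5:] keeps just the latest token for this decoding step.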
2215
+ def prepare_inputs_for_generation(
2216
+ self,
2217
+ input_ids,
2218
+ past_key_values=None,
2219
+ attention_mask=None,
2220
+ use_cache=None,
2221
+ encoder_outputs=None,
2222
+ **kwargs,
2223
+ ):
2224
+ decoder_inputs = self.prepare_inputs_for_generation_mbart(
2225
+ input_ids, past_key_values=past_key_values
2226
+ )
2227
+ decoder_attention_mask = (
2228
+ decoder_inputs["attention_mask"]
2229
+ if "attention_mask" in decoder_inputs
2230
+ else None
2231
+ )
2232
+ input_dict = {
2233
+ "attention_mask": attention_mask,
2234
+ "decoder_attention_mask": decoder_attention_mask,
2235
+ "decoder_input_ids": decoder_inputs["input_ids"],
2236
+ "encoder_outputs": encoder_outputs,
2237
+ "past_key_values": decoder_inputs["past_key_values"],
2238
+ "use_cache": use_cache,
2239
+ }
2240
+ return input_dict
2241
+
2242
+ def prepare_inputs_for_generation_export(
2243
+ self,
2244
+ past_key_values=None,
2245
+ attention_mask=None,
2246
+ use_cache=None,
2247
+ encoder_outputs=None,
2248
+ **kwargs,
2249
+ ):
2250
+
2251
+ input_dict = {
2252
+ "decoder_attention_mask": None,
2253
+ "use_cache": use_cache,
2254
+ }
2255
+ return input_dict
2256
+
2257
+ def _extract_past_from_model_output(
2258
+ self, outputs: ModelOutput, standardize_cache_format: bool = False
2259
+ ):
2260
+ past_key_values = None
2261
+ if "past_key_values" in outputs:
2262
+ past_key_values = outputs.past_key_values
2263
+ elif "mems" in outputs:
2264
+ past_key_values = outputs.mems
2265
+ elif "past_buckets_states" in outputs:
2266
+ past_key_values = outputs.past_buckets_states
2267
+
2268
+ return past_key_values
2269
+
2270
+ def _update_model_kwargs_for_generation(
2271
+ self,
2272
+ outputs: ModelOutput,
2273
+ model_kwargs: Dict[str, Any],
2274
+ is_encoder_decoder: bool = False,
2275
+ standardize_cache_format: bool = False,
2276
+ ) -> Dict[str, Any]:
2277
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
2278
+ outputs, standardize_cache_format=standardize_cache_format
2279
+ )
2280
+ if getattr(outputs, "state", None) is not None:
2281
+ model_kwargs["state"] = outputs.state
2282
+
2283
+ if "token_type_ids" in model_kwargs:
2284
+ token_type_ids = model_kwargs["token_type_ids"]
2285
+ model_kwargs["token_type_ids"] = torch.concat(
2286
+ [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
2287
+ )
2288
+
2289
+ if not is_encoder_decoder:
2290
+ if "attention_mask" in model_kwargs:
2291
+ attention_mask = model_kwargs["attention_mask"]
2292
+ model_kwargs["attention_mask"] = torch.concat(
2293
+ [
2294
+ attention_mask,
2295
+ attention_mask.new_ones((attention_mask.shape[0], 1)),
2296
+ ],
2297
+ dim=-1,
2298
+ )
2299
+ else:
2300
+ if "decoder_attention_mask" in model_kwargs:
2301
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
2302
+ model_kwargs["decoder_attention_mask"] = torch.concat(
2303
+ [
2304
+ decoder_attention_mask,
2305
+ decoder_attention_mask.new_ones(
2306
+ (decoder_attention_mask.shape[0], 1)
2307
+ ),
2308
+ ],
2309
+ dim=-1,
2310
+ )
2311
+
2312
+ if (
2313
+ "cache_position" in model_kwargs
2314
+ and model_kwargs["cache_position"] is not None
2315
+ ):
2316
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
2317
+
2318
+ return model_kwargs
2319
+
2320
+ def stopping_criteria(self, input_ids):
2321
+ if self.is_export:
2322
+ return input_ids[:, -1] == torch.Tensor([self.eos_token_id])
2323
+ is_done = torch.isin(input_ids[:, -1], torch.Tensor([self.eos_token_id]))
2324
+ return is_done
2325
+
2326
+ def generate_single_iter(
2327
+ self,
2328
+ decoder_input_ids=None,
2329
+ decoder_attention_mask=None,
2330
+ encoder_outputs=None,
2331
+ past_key_values=None,
2332
+ decoder_inputs_embeds=None,
2333
+ labels=None,
2334
+ use_cache=None,
2335
+ output_attentions=None,
2336
+ output_hidden_states=None,
2337
+ return_dict=None,
2338
+ **kwargs,
2339
+ ):
2340
+ encoder_hidden_states = encoder_outputs[0]
2341
+ if self.config_decoder.hidden_size != self.encoder_hidden_size:
2342
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
2343
+ kwargs_decoder = {}
2344
+
2345
+ decoder_outputs = self.decoder(
2346
+ input_ids=decoder_input_ids,
2347
+ attention_mask=decoder_attention_mask,
2348
+ encoder_hidden_states=encoder_hidden_states,
2349
+ encoder_attention_mask=None,
2350
+ inputs_embeds=None,
2351
+ output_attentions=False,
2352
+ output_hidden_states=output_hidden_states,
2353
+ use_cache=use_cache,
2354
+ past_key_values=past_key_values,
2355
+ return_dict=return_dict,
2356
+ **kwargs_decoder,
2357
+ )
2358
+
2359
+ return Seq2SeqLMOutput(
2360
+ loss=None,
2361
+ logits=decoder_outputs.logits,
2362
+ past_key_values=decoder_outputs.past_key_values,
2363
+ decoder_hidden_states=decoder_outputs.hidden_states,
2364
+ decoder_attentions=decoder_outputs.attentions,
2365
+ cross_attentions=decoder_outputs.cross_attentions,
2366
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
2367
+ encoder_hidden_states=encoder_outputs.hidden_states,
2368
+ encoder_attentions=encoder_outputs.attentions,
2369
+ )
2370
+
2371
+ @torch.no_grad()
2372
+ def generate(
2373
+ self,
2374
+ model_kwargs,
2375
+ ):
2376
+ """
2377
+ Generate sequences using the UniMERNetHead for inference tasks.
2378
+
2379
+ Args:
2380
+ model_kwargs (dict): A dictionary of model configurations and inputs, which typically include:
2381
+ - encoder_outputs: Outputs from the encoder.
2382
+ - use_cache: Boolean flag to indicate if caching should be used.
2383
+ - output_attentions: Boolean flag for outputting attention scores.
2384
+ - output_hidden_states: Boolean flag for outputting hidden states.
2385
+
2386
+ Returns:
2387
+ A tensor containing the generated sequences.
2388
+ """
2389
+ batch_size = model_kwargs["encoder_outputs"]["last_hidden_state"].shape[0]
2390
+ generation_config = {
2391
+ "decoder_start_token_id": 0,
2392
+ "bos_token_id": 0,
2393
+ }
2394
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
2395
+ batch_size=batch_size,
2396
+ model_kwargs=model_kwargs,
2397
+ decoder_start_token_id=generation_config["decoder_start_token_id"],
2398
+ bos_token_id=generation_config["bos_token_id"],
2399
+ )
2400
+ model_kwargs["use_cache"] = True
2401
+ batch_size, cur_len = input_ids.shape
2402
+
2403
+ if "inputs_embeds" in model_kwargs:
2404
+ cur_len = model_kwargs["inputs_embeds"].shape[1]
2405
+ model_kwargs["cache_position"] = torch.arange(cur_len)
2406
+ pad_token_id = self.pad_token_id
2407
+ eos_token_id = [self.eos_token_id]
2408
+ eos_token = self.eos_token_id
2409
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.int64)
2410
+ for idx in range(self.max_seq_len):
2411
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
2412
+ outputs = self.generate_single_iter(
2413
+ **model_inputs,
2414
+ return_dict=True,
2415
+ output_attentions=False,
2416
+ output_hidden_states=False,
2417
+ )
2418
+ next_token_logits = outputs.logits[:, -1, :]
2419
+
2420
+ next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
2421
+ next_tokens = torch.argmax(next_tokens_scores, dim=-1)
2422
+ if eos_token_id is not None:
2423
+ if pad_token_id is None:
2424
+ raise ValueError(
2425
+ "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
2426
+ )
2427
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
2428
+ 1 - unfinished_sequences
2429
+ )
2430
+ input_ids = torch.concat([input_ids, next_tokens[:, None]], dim=-1)
2431
+ model_kwargs = self._update_model_kwargs_for_generation(
2432
+ outputs,
2433
+ model_kwargs,
2434
+ is_encoder_decoder=self.config_decoder.is_encoder_decoder,
2435
+ )
2436
+ unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
2437
+ input_ids
2438
+ ).to(torch.int64)
2439
+
2440
+ if (
2441
+ eos_token is not None
2442
+ and (
2443
+ torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
2444
+ >= 1
2445
+ ).all()
2446
+ ):
2447
+ break
2448
+
2449
+ return input_ids
2450
+
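# Editor's note (illustration, not package code): the update
#     next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# freezes finished sequences. E.g. with next_tokens = [7, 9], unfinished = [1, 0] and
# pad_token_id = 1, the emitted tokens become [7, 1]: the second sequence, which has
# already produced EOS, keeps receiving padding until every sequence is done.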
2451
+ @torch.no_grad()
2452
+ def generate_export(
2453
+ self,
2454
+ encoder_outputs,
2455
+ model_kwargs,
2456
+ ):
2457
+ batch_size = encoder_outputs["last_hidden_state"].shape[0]
2458
+ generation_config = {
2459
+ "decoder_start_token_id": 0,
2460
+ "bos_token_id": 0,
2461
+ }
2462
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
2463
+ batch_size=batch_size,
2464
+ model_kwargs=model_kwargs,
2465
+ decoder_start_token_id=generation_config["decoder_start_token_id"],
2466
+ bos_token_id=generation_config["bos_token_id"],
2467
+ )
2468
+ input_ids = input_ids.reshape([-1, 1])
2469
+ decoder_input_ids = input_ids
2470
+ model_kwargs["use_cache"] = True
2471
+ batch_size, cur_len = input_ids.shape
2472
+
2473
+ if "inputs_embeds" in model_kwargs:
2474
+ cur_len = model_kwargs["inputs_embeds"].shape[1]
2475
+ cache_position = torch.arange(cur_len)
2476
+ pad_token_id = self.pad_token_id
2477
+ eos_token_id = [self.eos_token_id]
2478
+ eos_token = self.eos_token_id
2479
+ unfinished_sequences = torch.ones([batch_size], dtype=torch.int64)
2480
+ i_idx = torch.full([], 0)
2481
+ past_key_values = []
2482
+ for i in range(8):
2483
+ init_arr = torch.zeros([batch_size, 16, 0, 64])
2484
+ cache = (init_arr, init_arr, init_arr, init_arr)
2485
+ past_key_values.append(cache)
2486
+ idx = 0
2487
+ while i_idx < torch.tensor(self.max_seq_len):
2488
+
2489
+ model_inputs = self.prepare_inputs_for_generation_export(
2490
+ past_key_values=past_key_values, **model_kwargs
2491
+ )
2492
+ decoder_attention_mask = model_inputs["decoder_attention_mask"]
2493
+ decoder_attention_mask = torch.ones(input_ids.shape)
2494
+
2495
+ outputs = self.generate_single_iter(
2496
+ decoder_input_ids=decoder_input_ids,
2497
+ decoder_attention_mask=decoder_attention_mask,
2498
+ encoder_outputs=encoder_outputs,
2499
+ past_key_values=past_key_values,
2500
+ return_dict=True,
2501
+ output_attentions=False,
2502
+ output_hidden_states=False,
2503
+ )
2504
+
2505
+ next_token_logits = outputs.logits[:, -1, :]
2506
+
2507
+ next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
2508
+ next_tokens = torch.argmax(next_tokens_scores, dim=-1)
2509
+ if eos_token_id is not None:
2510
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
2511
+ 1 - unfinished_sequences
2512
+ )
2513
+ input_ids = torch.concat([input_ids, next_tokens.unsqueeze(1)], dim=-1)
2514
+ past_length = past_key_values[0][0].shape[2]
2515
+ decoder_input_ids = next_tokens.unsqueeze(1)
2516
+ past_key_values = outputs.past_key_values
2517
+ cache_position = cache_position[-1:] + 1
2518
+ unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
2519
+ input_ids
2520
+ ).to(torch.int64)
2521
+ if (
2522
+ eos_token is not None
2523
+ and (
2524
+ torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
2525
+ >= 1
2526
+ ).all()
2527
+ ):
2528
+ break
2529
+
2530
+ i_idx += 1
2531
+ return input_ids
2532
+
2533
+ def forwad_train(
2534
+ self,
2535
+ encoder_outputs,
2536
+ decoder_input_ids,
2537
+ decoder_attention_mask,
2538
+ past_key_values=None,
2539
+ decoder_inputs_embeds=None,
2540
+ labels=None,
2541
+ use_cache=None,
2542
+ output_attentions=None,
2543
+ output_hidden_states=None,
2544
+ return_dict=None,
2545
+ **kwargs,
2546
+ ):
2547
+ """
2548
+ Training-time forward pass for the UniMERNetHead.
2549
+
2550
+ Args:
2551
+ encoder_outputs: Outputs from the encoder, used as input to the decoder.
2552
+ decoder_input_ids: Input IDs for the decoder.
2553
+ decoder_attention_mask: Attention mask for the decoder inputs.
2554
+ past_key_values: Cached key/values for faster decoding.
2555
+ decoder_inputs_embeds: Optional embeddings for the decoder inputs.
2556
+ labels: Target labels for calculating loss.
2557
+ use_cache: Whether to use cache during decoding.
2558
+ output_attentions: Whether to return attention scores.
2559
+ output_hidden_states: Whether to return hidden states.
2560
+ return_dict: Whether to return a dictionary of outputs.
2561
+ **kwargs: Additional keyword arguments.
2562
+
2563
+ Returns:
2564
+ logits: The raw, unnormalized predictions from the model.
2565
+ count_pred: Optional prediction related to sequence length or other counts.
2566
+ masked_labels: The labels used during training, possibly masked.
2567
+ """
2568
+ labels = decoder_input_ids * 1
2569
+ labels = labels.masked_fill_(labels == self.pad_token_id, -100)
2570
+ input_decoder_input_ids = decoder_input_ids[:, :-1]
2571
+ input_decoder_attention_mask = decoder_attention_mask[:, :-1]
2572
+ encoder_hidden_states = encoder_outputs[0]
2573
+ if self.config_decoder.hidden_size != self.encoder_hidden_size:
2574
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
2575
+ kwargs_decoder = {}
2576
+ decoder_outputs = self.decoder(
2577
+ input_ids=input_decoder_input_ids,
2578
+ attention_mask=input_decoder_attention_mask,
2579
+ encoder_hidden_states=encoder_hidden_states,
2580
+ encoder_attention_mask=None,
2581
+ inputs_embeds=None,
2582
+ output_attentions=False,
2583
+ output_hidden_states=output_hidden_states,
2584
+ use_cache=use_cache,
2585
+ past_key_values=past_key_values,
2586
+ return_dict=return_dict,
2587
+ **kwargs_decoder,
2588
+ )
2589
+
2590
+ logits = decoder_outputs.logits
2591
+ count_pred = decoder_outputs.counting
2592
+ return logits, count_pred, labels
2593
+
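# Editor's note (illustration, not package code): labels are a copy of the decoder
# input ids with padding replaced by -100, the conventional ignore_index of
# cross-entropy losses. E.g. with pad_token_id = 1,
#     decoder_input_ids = [[0, 5, 8, 2, 1, 1]]  ->  labels = [[0, 5, 8, 2, -100, -100]]
# while the decoder itself consumes the shifted input decoder_input_ids[:, :-1].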
2594
+ def forward(self, inputs, targets=None):
2595
+ """
2596
+ Forward pass for the UniMERNetHead, handling both training and inference.
2597
+
2598
+ Args:
2599
+ inputs: The input data, which can vary based on training or inference.
2600
+ targets: The target labels, used only during training.
2601
+
2602
+ Returns:
2603
+ During inference: Returns the predicted LaTeX token id sequence.
2604
+ During training: Returns logits, predicted counts, and masked labels.
2605
+ """
2606
+ self.is_export = not self.training
2607
+ if not self.training:
2608
+ encoder_outputs = inputs
2609
+ if self.is_export:
2610
+ model_kwargs = {
2611
+ "output_attentions": False,
2612
+ "output_hidden_states": False,
2613
+ "use_cache": True,
2614
+ }
2615
+ word_pred = self.generate_export(encoder_outputs, model_kwargs)
2616
+ else:
2617
+ model_kwargs = {
2618
+ "output_attentions": False,
2619
+ "output_hidden_states": False,
2620
+ "use_cache": True,
2621
+ "encoder_outputs": encoder_outputs,
2622
+ }
2623
+ word_pred = self.generate(model_kwargs)
2624
+
2625
+ return word_pred
2626
+
2627
+ encoder_outputs, tgt_seq, mask = inputs
2628
+ logits, count_pred, masked_labels = self.forwad_train(
2629
+ encoder_outputs, tgt_seq, mask
2630
+ )
2631
+ return logits, count_pred, masked_labels
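# --- Editor's note: hedged usage sketch, not part of the mineru package. ---
# At inference time the head receives the vision encoder's outputs and returns the
# generated LaTeX token ids, which the surrounding pipeline detokenizes. Roughly:
#
#     head = UniMERNetHead().eval()              # weights come from a trained checkpoint
#     encoder_outputs = ...                      # e.g. a BaseModelOutput whose
#                                                # last_hidden_state is (B, N, 1024)
#     with torch.no_grad():
#         token_ids = head(encoder_outputs)      # (B, <= max_new_tokens)
#
# During training, forward() instead takes (encoder_outputs, tgt_seq, mask) and returns
# (logits, count_pred, masked_labels) for the loss computation.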