mineru 2.5.4__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (103)
  1. mineru/backend/pipeline/model_init.py +25 -3
  2. mineru/backend/pipeline/model_json_to_middle_json.py +2 -2
  3. mineru/backend/pipeline/model_list.py +0 -1
  4. mineru/backend/utils.py +24 -0
  5. mineru/backend/vlm/model_output_to_middle_json.py +2 -2
  6. mineru/backend/vlm/{custom_logits_processors.py → utils.py} +36 -2
  7. mineru/backend/vlm/vlm_analyze.py +43 -50
  8. mineru/backend/vlm/vlm_magic_model.py +155 -1
  9. mineru/cli/common.py +25 -22
  10. mineru/cli/fast_api.py +2 -8
  11. mineru/cli/gradio_app.py +96 -9
  12. mineru/cli/models_download.py +1 -0
  13. mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py +152 -0
  14. mineru/model/mfr/pp_formulanet_plus_m/processors.py +657 -0
  15. mineru/model/mfr/unimernet/unimernet_hf/modeling_unimernet.py +1 -326
  16. mineru/model/mfr/utils.py +338 -0
  17. mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py +103 -16
  18. mineru/model/table/rec/unet_table/main.py +1 -1
  19. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/imaug/operators.py +5 -5
  20. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/__init__.py +2 -1
  21. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_lcnetv3.py +7 -7
  22. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_pphgnetv2.py +2 -2
  23. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/__init__.py +2 -0
  24. mineru/model/utils/pytorchocr/modeling/heads/rec_ppformulanet_head.py +1383 -0
  25. mineru/model/utils/pytorchocr/modeling/heads/rec_unimernet_head.py +2631 -0
  26. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/rec_postprocess.py +25 -28
  27. mineru/model/utils/pytorchocr/utils/__init__.py +0 -0
  28. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/arch_config.yaml +130 -0
  29. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_arabic_dict.txt +747 -0
  30. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_cyrillic_dict.txt +850 -0
  31. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_devanagari_dict.txt +568 -0
  32. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_ta_dict.txt +513 -0
  33. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_te_dict.txt +540 -0
  34. mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/models_config.yml +15 -15
  35. mineru/model/utils/pytorchocr/utils/resources/pp_formulanet_arch_config.yaml +24 -0
  36. mineru/model/utils/tools/infer/__init__.py +1 -0
  37. mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_det.py +6 -3
  38. mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_rec.py +16 -25
  39. mineru/model/vlm_vllm_model/server.py +4 -1
  40. mineru/resources/header.html +2 -2
  41. mineru/utils/enum_class.py +1 -0
  42. mineru/utils/llm_aided.py +4 -2
  43. mineru/utils/ocr_utils.py +16 -0
  44. mineru/utils/table_merge.py +102 -13
  45. mineru/version.py +1 -1
  46. {mineru-2.5.4.dist-info → mineru-2.6.0.dist-info}/METADATA +32 -8
  47. mineru-2.6.0.dist-info/RECORD +195 -0
  48. mineru-2.5.4.dist-info/RECORD +0 -181
  49. /mineru/model/{ocr/paddleocr2pytorch/pytorchocr → mfr/pp_formulanet_plus_m}/__init__.py +0 -0
  50. /mineru/model/{ocr/paddleocr2pytorch/tools/infer → utils}/__init__.py +0 -0
  51. /mineru/model/{ocr/paddleocr2pytorch/pytorchocr/modeling → utils/pytorchocr}/__init__.py +0 -0
  52. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/base_ocr_v20.py +0 -0
  53. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/__init__.py +0 -0
  54. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/data/imaug/__init__.py +0 -0
  55. /mineru/model/{ocr/paddleocr2pytorch/pytorchocr/utils → utils/pytorchocr/modeling}/__init__.py +0 -0
  56. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/architectures/__init__.py +0 -0
  57. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/architectures/base_model.py +0 -0
  58. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/det_mobilenet_v3.py +0 -0
  59. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_donut_swin.py +0 -0
  60. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_hgnet.py +0 -0
  61. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +0 -0
  62. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_mv1_enhance.py +0 -0
  63. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/backbones/rec_svtrnet.py +0 -0
  64. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/common.py +0 -0
  65. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/cls_head.py +0 -0
  66. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/det_db_head.py +0 -0
  67. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/rec_ctc_head.py +0 -0
  68. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/heads/rec_multi_head.py +0 -0
  69. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/__init__.py +0 -0
  70. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/db_fpn.py +0 -0
  71. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/intracl.py +0 -0
  72. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/modeling/necks/rnn.py +0 -0
  73. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/__init__.py +0 -0
  74. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/cls_postprocess.py +0 -0
  75. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/postprocess/db_postprocess.py +0 -0
  76. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/arabic_dict.txt +0 -0
  77. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +0 -0
  78. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/cyrillic_dict.txt +0 -0
  79. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/devanagari_dict.txt +0 -0
  80. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/en_dict.txt +0 -0
  81. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/japan_dict.txt +0 -0
  82. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ka_dict.txt +0 -0
  83. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/korean_dict.txt +0 -0
  84. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/latin_dict.txt +0 -0
  85. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +0 -0
  86. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv4_doc_dict.txt +0 -0
  87. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_dict.txt +0 -0
  88. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_el_dict.txt +0 -0
  89. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_en_dict.txt +0 -0
  90. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_eslav_dict.txt +0 -0
  91. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_korean_dict.txt +0 -0
  92. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_latin_dict.txt +0 -0
  93. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ppocrv5_th_dict.txt +0 -0
  94. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/ta_dict.txt +0 -0
  95. /mineru/model/{ocr/paddleocr2pytorch → utils}/pytorchocr/utils/resources/dict/te_dict.txt +0 -0
  96. /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/__init__.py +0 -0
  97. /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_cls.py +0 -0
  98. /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/predict_system.py +0 -0
  99. /mineru/model/{ocr/paddleocr2pytorch → utils}/tools/infer/pytorchocr_utility.py +0 -0
  100. {mineru-2.5.4.dist-info → mineru-2.6.0.dist-info}/WHEEL +0 -0
  101. {mineru-2.5.4.dist-info → mineru-2.6.0.dist-info}/entry_points.txt +0 -0
  102. {mineru-2.5.4.dist-info → mineru-2.6.0.dist-info}/licenses/LICENSE.md +0 -0
  103. {mineru-2.5.4.dist-info → mineru-2.6.0.dist-info}/top_level.txt +0 -0
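
The hunk reproduced below is the new PP-FormulaNet decoding head, mineru/model/utils/pytorchocr/modeling/heads/rec_ppformulanet_head.py (+1383 lines). Before running its MBart-style decoder, the head builds 4D causal attention masks via AttentionMaskConverter. The following is a minimal, self-contained sketch of that mask-construction pattern; the function name, shapes, and the past_len parameter are illustrative stand-ins, not the package's API.

import torch

def make_causal_mask_sketch(bsz, tgt_len, past_len=0, dtype=torch.float32):
    # Square mask filled with the most negative representable value, then zeroed
    # wherever a position may attend (itself and earlier positions).
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
    cond = torch.arange(tgt_len)
    mask.masked_fill_(cond < (cond + 1).reshape(tgt_len, 1), 0)
    mask = mask.to(dtype)
    if past_len > 0:
        # Cached positions from earlier decoding steps stay fully visible.
        mask = torch.cat([torch.zeros(tgt_len, past_len, dtype=dtype), mask], dim=-1)
    # Broadcast to the (bsz, 1, tgt_len, tgt_len + past_len) shape attention expects.
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_len)

print(make_causal_mask_sketch(2, 4, past_len=3).shape)  # torch.Size([2, 1, 4, 7])
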
mineru/model/utils/pytorchocr/modeling/heads/rec_ppformulanet_head.py
@@ -0,0 +1,1383 @@
1
+ # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import re
17
+ import numpy as np
18
+ import inspect
19
+ import torch
20
+ import torch.nn as nn
21
+ from typing import Optional, Tuple, Union, List, Dict, Any
22
+ from dataclasses import dataclass, fields, is_dataclass
23
+
24
+ from sympy import totient
25
+
26
+ from mineru.utils.config_reader import get_device
27
+ from .rec_unimernet_head import (
28
+ MBartForCausalLM,
29
+ MBartDecoder,
30
+ MBartConfig,
31
+ ModelOutput,
32
+ BaseModelOutputWithPastAndCrossAttentions,
33
+ Seq2SeqLMOutput,
34
+ CausalLMOutputWithCrossAttentions,
35
+ LogitsProcessorList,
36
+ ForcedEOSTokenLogitsProcessor,
37
+ UniMERNetHead,
38
+ )
39
+
40
+
41
+ @dataclass
42
+ class AttentionMaskConverter:
43
+ """
44
+ A class to convert attention masks based on specific configurations.
45
+
46
+ This class is designed to handle the conversion of attention masks with options for causal masking
47
+ and sliding window attention, which are commonly used in transformer models.
48
+
49
+ Attributes:
50
+ is_causal (bool): Flag indicating whether the attention mask should enforce causal masking,
51
+ which ensures each position can only attend to previous positions.
52
+ sliding_window (int, optional): Size of the sliding window for local attention. If set,
53
+ attention is restricted to a local window of this size.
54
+
55
+ """
56
+
57
+ is_causal: bool
58
+ sliding_window: int
59
+
60
+ def __init__(self, is_causal: bool, sliding_window=None):
61
+ self.is_causal = is_causal
62
+ self.sliding_window = sliding_window
63
+
64
+ if self.sliding_window is not None and self.sliding_window <= 0:
65
+ raise ValueError(
66
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
67
+ )
68
+
69
+ @staticmethod
70
+ def _make_causal_mask(
71
+ input_ids_shape,
72
+ dtype,
73
+ past_key_values_length=0,
74
+ sliding_window=None,
75
+ is_export=False,
76
+ ):
77
+ """
78
+ Make causal mask used for bi-directional self-attention.
79
+ """
80
+ bsz, tgt_len = input_ids_shape
81
+ if is_export:
82
+ mask = torch.full(
83
+ (tgt_len, tgt_len), torch.finfo(dtype).min, dtype=torch.float64
84
+ )
85
+ mask_cond = torch.arange(mask.shape[-1])
86
+ mask.masked_fill_(
87
+ mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
88
+ )
89
+ else:
90
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
91
+ mask_cond = torch.arange(mask.shape[-1])
92
+ mask.masked_fill_(
93
+ mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
94
+ )
95
+ mask = mask.to(dtype)
96
+
97
+ if past_key_values_length > 0:
98
+ mask = torch.concat(
99
+ [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask],
100
+ dim=-1,
101
+ )
102
+
103
+ # add lower triangular sliding window mask if necessary
104
+ if sliding_window is not None:
105
+ diagonal = past_key_values_length - sliding_window - 1
106
+
107
+ context_mask = torch.tril(
108
+ torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
109
+ )
110
+ mask.masked_fill_(context_mask, torch.finfo(dtype).min)
111
+
112
+ return mask[None, None, :, :].expand(
113
+ [bsz, 1, tgt_len, tgt_len + past_key_values_length]
114
+ )
115
+
116
+ @staticmethod
117
+ def _make_causal_mask_parallel(
118
+ input_ids_shape,
119
+ dtype,
120
+ past_key_values_length=0,
121
+ sliding_window=None,
122
+ parallel_step=1,
123
+ is_export=False,
124
+ ):
125
+ """
126
+ Make causal mask used for bi-directional self-attention.
127
+ """
128
+ bsz, tgt_len = input_ids_shape
129
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
130
+ mask_cond = torch.arange(mask.shape[-1])
131
+ mask_cond_parallel = torch.arange(mask.shape[-1])
132
+
133
+ mask_parallel = torch.arange(0, tgt_len, step=parallel_step).reshape([1, -1])
134
+ mask_parallel = torch.repeat_interleave(mask_parallel, parallel_step, 1)[
135
+ :, :tgt_len
136
+ ]
137
+ mask.masked_fill_(
138
+ mask_cond < (mask_parallel + parallel_step).reshape([mask.shape[-1], 1]), 0
139
+ )
140
+ mask = mask.to(dtype)
141
+
142
+ if past_key_values_length > 0:
143
+ mask = torch.concat(
144
+ [torch.zeros([tgt_len, past_key_values_length], dtype=dtype), mask],
145
+ dim=-1,
146
+ )
147
+
148
+ # add lower triangular sliding window mask if necessary
149
+ if sliding_window is not None:
150
+ diagonal = past_key_values_length - sliding_window - 1
151
+
152
+ context_mask = torch.tril(
153
+ torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
154
+ )
155
+ mask.masked_fill_(context_mask, torch.finfo(dtype).min)
156
+
157
+ return mask[None, None, :, :].expand(
158
+ [bsz, 1, tgt_len, tgt_len + past_key_values_length]
159
+ )
160
+
161
+ def to_4d(
162
+ self,
163
+ attention_mask_2d,
164
+ query_length,
165
+ dtype,
166
+ key_value_length,
167
+ use_parallel=False,
168
+ parallel_step=3,
169
+ is_export=False,
170
+ ):
171
+ """
172
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
173
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
174
+ causal, a causal mask will be added.
175
+ """
176
+ input_shape = (attention_mask_2d.shape[0], query_length)
177
+
178
+ causal_4d_mask = None
179
+ if use_parallel:
180
+ step = parallel_step
181
+ else:
182
+ step = 1
183
+ if (
184
+ input_shape[-1] > step or self.sliding_window is not None
185
+ ) and self.is_causal:
186
+
187
+ if key_value_length is None:
188
+ raise ValueError(
189
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
190
+ )
191
+
192
+ past_key_values_length = key_value_length - query_length
193
+
194
+ if use_parallel:
195
+ causal_4d_mask = self._make_causal_mask_parallel(
196
+ input_shape,
197
+ dtype,
198
+ past_key_values_length=past_key_values_length,
199
+ sliding_window=self.sliding_window,
200
+ parallel_step=parallel_step,
201
+ is_export=is_export,
202
+ )
203
+ else:
204
+ causal_4d_mask = self._make_causal_mask(
205
+ input_shape,
206
+ dtype,
207
+ past_key_values_length=past_key_values_length,
208
+ sliding_window=self.sliding_window,
209
+ is_export=is_export,
210
+ )
211
+
212
+ elif self.sliding_window is not None:
213
+ raise NotImplementedError(
214
+ "Sliding window is currently only implemented for causal masking"
215
+ )
216
+
217
+ expanded_attn_mask = self._expand_mask(
218
+ attention_mask_2d, dtype, tgt_len=input_shape[-1]
219
+ )
220
+
221
+ if causal_4d_mask is not None:
222
+ expanded_attn_mask = causal_4d_mask.masked_fill_(
223
+ expanded_attn_mask.to(torch.bool), torch.finfo(dtype).min
224
+ )
225
+
226
+ expanded_4d_mask = expanded_attn_mask
227
+ return expanded_4d_mask
228
+
229
+ def to_4d_export(
230
+ self,
231
+ attention_mask_2d,
232
+ query_length,
233
+ dtype,
234
+ key_value_length,
235
+ use_parallel=False,
236
+ parallel_step=3,
237
+ is_export=False,
238
+ ):
239
+ input_shape = (attention_mask_2d.shape[0], query_length)
240
+
241
+ expanded_attn_mask = self._expand_mask_export(
242
+ attention_mask_2d, dtype, tgt_len=input_shape[-1]
243
+ )
244
+ expanded_4d_mask = expanded_attn_mask
245
+
246
+ return expanded_4d_mask
247
+
248
+ def _expand_mask(self, mask, dtype, tgt_len=None):
249
+ """
250
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
251
+ """
252
+ bsz, src_len = mask.shape
253
+ tgt_len = tgt_len if tgt_len is not None else src_len
254
+ expanded_mask = (
255
+ mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).to(dtype)
256
+ )
257
+
258
+ inverted_mask = 1.0 - expanded_mask
259
+
260
+ return inverted_mask.masked_fill_(
261
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
262
+ )
263
+
264
+ def _expand_mask_export(self, mask, dtype, tgt_len=None):
265
+ """
266
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
267
+ """
268
+ bsz, src_len = mask.shape
269
+ expanded_mask = (
270
+ mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).to(dtype)
271
+ )
272
+ inverted_mask = 1.0 - expanded_mask
273
+ return inverted_mask.masked_fill_(
274
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
275
+ )
276
+
277
+
278
+ def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
279
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
280
+
281
+
282
+ def _prepare_4d_causal_attention_mask(
283
+ attention_mask,
284
+ input_shape,
285
+ inputs_embeds,
286
+ past_key_values_length,
287
+ sliding_window=None,
288
+ use_parallel=False,
289
+ parallel_step=3,
290
+ is_export=False,
291
+ ):
292
+ """
293
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
294
+ `(batch_size, key_value_length)`
295
+
296
+ Args:
297
+ attention_mask (`paddle.Tensor` or `None`):
298
+ A 2D attention mask of shape `(batch_size, key_value_length)`
299
+ input_shape (`tuple(int)` or `list(int)` or `paddle.Size`):
300
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
301
+ inputs_embeds (`paddle.Tensor`):
302
+ The embedded inputs as a paddle Tensor.
303
+ past_key_values_length (`int`):
304
+ The length of the key value cache.
305
+ sliding_window (`int`, *optional*):
306
+ If the model uses windowed attention, a sliding window should be passed.
307
+ """
308
+ attn_mask_converter = AttentionMaskConverter(
309
+ is_causal=True, sliding_window=sliding_window
310
+ )
311
+
312
+ key_value_length = input_shape[-1] + past_key_values_length
313
+
314
+ # 4d mask is passed through the layers
315
+ if attention_mask is not None and len(attention_mask.shape) == 2:
316
+ attention_mask = attn_mask_converter.to_4d(
317
+ attention_mask,
318
+ input_shape[-1],
319
+ key_value_length=key_value_length,
320
+ dtype=inputs_embeds.dtype,
321
+ use_parallel=use_parallel,
322
+ parallel_step=parallel_step,
323
+ is_export=is_export,
324
+ )
325
+ elif attention_mask is not None and len(attention_mask.shape) == 4:
326
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
327
+ if tuple(attention_mask.shape) != expected_shape:
328
+ raise ValueError(
329
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
330
+ )
331
+ else:
332
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
333
+ inverted_mask = 1.0 - attention_mask
334
+ attention_mask = inverted_mask.masked_fill_(
335
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
336
+ )
337
+ else:
338
+ attention_mask = attn_mask_converter.to_causal_4d(
339
+ input_shape[0],
340
+ input_shape[-1],
341
+ key_value_length,
342
+ dtype=inputs_embeds.dtype,
343
+ )
344
+
345
+ return attention_mask
346
+
347
+
348
+ def _prepare_4d_causal_attention_mask_export(
349
+ attention_mask,
350
+ input_shape,
351
+ inputs_embeds,
352
+ past_key_values_length,
353
+ sliding_window=None,
354
+ use_parallel=False,
355
+ parallel_step=3,
356
+ is_export=False,
357
+ ):
358
+ """
359
+ Prepare a 4D causal attention mask for export.
360
+
361
+ This function prepares a 4-dimensional causal attention mask, which is used to ensure that each position in the
362
+ sequence can only attend to previous positions. It is specifically designed to handle scenarios where the model
363
+ is being exported, potentially with additional options like sliding window or parallel processing.
364
+
365
+ Args:
366
+ attention_mask: The initial attention mask, typically used to avoid attending to padding tokens.
367
+ input_shape: Shape of the input tensor, usually in the form (batch_size, sequence_length).
368
+ inputs_embeds: Embeddings of the input sequence, used to derive certain dimensions if needed.
369
+ past_key_values_length: Length of past key values, used in contexts like transformer decoders with caching.
370
+ sliding_window: Optional parameter. If provided, specifies the size of a sliding window for local attention.
371
+ use_parallel: Flag indicating whether to use parallel processing for attention computation.
372
+ parallel_step: Number of steps to use in parallel processing, relevant if `use_parallel` is True.
373
+ is_export: Flag indicating whether the attention mask is being prepared for model export.
374
+
375
+ Returns:
376
+ A 4D causal attention mask suitable for use in transformer models, ensuring correct causal masking.
377
+ """
378
+ attn_mask_converter = AttentionMaskConverter(
379
+ is_causal=True, sliding_window=sliding_window
380
+ )
381
+ key_value_length = input_shape[-1] + past_key_values_length
382
+
383
+ shape = attention_mask.shape
384
+ len_shape = len(shape)
385
+
386
+ attention_mask = attn_mask_converter.to_4d_export(
387
+ attention_mask,
388
+ input_shape[-1],
389
+ key_value_length=key_value_length,
390
+ dtype=inputs_embeds.dtype,
391
+ use_parallel=use_parallel,
392
+ parallel_step=parallel_step,
393
+ is_export=is_export,
394
+ )
395
+ return attention_mask
396
+
397
+
398
+ class CustomMBartDecoder(MBartDecoder):
399
+ def __init__(self, config):
400
+ super().__init__(config)
401
+ hidden_size = config.d_model
402
+ self.is_export = config.is_export
403
+ self.config_decoder = config
404
+
405
+ def forward(
406
+ self,
407
+ input_ids=None,
408
+ attention_mask=None,
409
+ encoder_hidden_states=None,
410
+ encoder_attention_mask=None,
411
+ head_mask=None,
412
+ cross_attn_head_mask=None,
413
+ past_key_values=None,
414
+ inputs_embeds=None,
415
+ use_cache=None,
416
+ output_attentions=None,
417
+ output_hidden_states=None,
418
+ return_dict=None,
419
+ ):
420
+ self.is_export = False if self.training else True
421
+
422
+ output_attentions = (
423
+ output_attentions
424
+ if output_attentions is not None
425
+ else self.config.output_attentions
426
+ )
427
+ output_hidden_states = (
428
+ output_hidden_states
429
+ if output_hidden_states is not None
430
+ else self.config.output_hidden_states
431
+ )
432
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
433
+ return_dict = (
434
+ return_dict if return_dict is not None else self.config.use_return_dict
435
+ )
436
+
437
+ # retrieve input_ids and inputs_embeds
438
+ if input_ids is not None and inputs_embeds is not None:
439
+ raise ValueError(
440
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
441
+ )
442
+ elif input_ids is not None:
443
+ input = input_ids
444
+ input_shape = input.shape
445
+ input_ids = input_ids.reshape([-1, input_shape[-1]])
446
+ elif inputs_embeds is not None:
447
+ input_shape = inputs_embeds.shape[:-1]
448
+ input = inputs_embeds[:, :, -1]
449
+ else:
450
+ raise ValueError(
451
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
452
+ )
453
+
454
+ # past_key_values_length
455
+ past_key_values_length = (
456
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
457
+ )
458
+
459
+ if inputs_embeds is None:
460
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
461
+
462
+ if self._use_flash_attention_2:
463
+ # 2d mask is passed through the layers
464
+ attention_mask = (
465
+ attention_mask
466
+ if (attention_mask is not None and 0 in attention_mask)
467
+ else None
468
+ )
469
+ else:
470
+ # 4d mask is passed through the layers
471
+ if self.is_export:
472
+ attention_mask = _prepare_4d_causal_attention_mask_export(
473
+ attention_mask,
474
+ input_shape,
475
+ inputs_embeds,
476
+ past_key_values_length,
477
+ use_parallel=self.config_decoder.use_parallel,
478
+ parallel_step=self.config_decoder.parallel_step,
479
+ is_export=self.is_export,
480
+ )
481
+ else:
482
+ attention_mask = _prepare_4d_causal_attention_mask(
483
+ attention_mask,
484
+ input_shape,
485
+ inputs_embeds,
486
+ past_key_values_length,
487
+ use_parallel=self.config_decoder.use_parallel,
488
+ parallel_step=self.config_decoder.parallel_step,
489
+ is_export=self.is_export,
490
+ )
491
+
492
+ # expand encoder attention mask
493
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
494
+ if self._use_flash_attention_2:
495
+ encoder_attention_mask = (
496
+ encoder_attention_mask if 0 in encoder_attention_mask else None
497
+ )
498
+ else:
499
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
500
+ encoder_attention_mask = _prepare_4d_attention_mask(
501
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
502
+ )
503
+
504
+ # embed positions
505
+ positions = self.embed_positions(input, past_key_values_length)
506
+
507
+ hidden_states = inputs_embeds + positions
508
+
509
+ hidden_states = self.layernorm_embedding(hidden_states)
510
+ hidden_states = nn.functional.dropout(
511
+ hidden_states, p=self.dropout, training=self.training
512
+ )
513
+ if self.gradient_checkpointing and self.training:
514
+ if use_cache:
515
+ print(
516
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
517
+ )
518
+ use_cache = False
519
+
520
+ # decoder layers
521
+ all_hidden_states = () if output_hidden_states else None
522
+ all_self_attns = () if output_attentions else None
523
+ all_cross_attentions = (
524
+ () if (output_attentions and encoder_hidden_states is not None) else None
525
+ )
526
+ next_decoder_cache = () if use_cache else None
527
+
528
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
529
+ for attn_mask, mask_name in zip(
530
+ [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
531
+ ):
532
+ if attn_mask is not None:
533
+ if attn_mask.size()[0] != len(self.layers):
534
+ raise ValueError(
535
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
536
+ f" {attn_mask.size()[0]}."
537
+ )
538
+ for idx, decoder_layer in enumerate(self.layers):
539
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
540
+ if output_hidden_states:
541
+ all_hidden_states += (hidden_states,)
542
+ if self.training:
543
+ dropout_probability = torch.rand([])
544
+ if dropout_probability < self.layerdrop:
545
+ continue
546
+
547
+ past_key_value = (
548
+ past_key_values[idx] if past_key_values is not None else None
549
+ )
550
+ if self.gradient_checkpointing and self.training:
551
+ layer_outputs = self._gradient_checkpointing_func(
552
+ decoder_layer.__call__,
553
+ hidden_states,
554
+ attention_mask,
555
+ encoder_hidden_states,
556
+ encoder_attention_mask,
557
+ head_mask[idx] if head_mask is not None else None,
558
+ (
559
+ cross_attn_head_mask[idx]
560
+ if cross_attn_head_mask is not None
561
+ else None
562
+ ),
563
+ None,
564
+ output_attentions,
565
+ use_cache,
566
+ )
567
+ else:
568
+ layer_outputs = decoder_layer(
569
+ hidden_states,
570
+ attention_mask=attention_mask,
571
+ encoder_hidden_states=encoder_hidden_states,
572
+ encoder_attention_mask=encoder_attention_mask,
573
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
574
+ cross_attn_layer_head_mask=(
575
+ cross_attn_head_mask[idx]
576
+ if cross_attn_head_mask is not None
577
+ else None
578
+ ),
579
+ past_key_value=past_key_value,
580
+ output_attentions=output_attentions,
581
+ use_cache=use_cache,
582
+ )
583
+ hidden_states = layer_outputs[0]
584
+
585
+ if self.is_export:
586
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
587
+ else:
588
+ if use_cache:
589
+ next_decoder_cache += (
590
+ layer_outputs[3 if output_attentions else 1],
591
+ )
592
+
593
+ if output_attentions:
594
+ all_self_attns += (layer_outputs[1],)
595
+
596
+ if encoder_hidden_states is not None:
597
+ all_cross_attentions += (layer_outputs[2],)
598
+
599
+ hidden_states = self.layer_norm(hidden_states)
600
+
601
+ # add hidden states from the last decoder layer
602
+ if output_hidden_states:
603
+ all_hidden_states += (hidden_states,)
604
+
605
+ if self.is_export:
606
+ next_cache = next_decoder_cache
607
+ else:
608
+ next_cache = next_decoder_cache if use_cache else None
609
+ if not return_dict:
610
+ return tuple(
611
+ v
612
+ for v in [
613
+ hidden_states,
614
+ next_cache,
615
+ all_hidden_states,
616
+ all_self_attns,
617
+ all_cross_attentions,
618
+ ]
619
+ if v is not None
620
+ )
621
+
622
+ return BaseModelOutputWithPastAndCrossAttentions(
623
+ last_hidden_state=hidden_states,
624
+ past_key_values=next_cache,
625
+ hidden_states=all_hidden_states,
626
+ attentions=all_self_attns,
627
+ cross_attentions=all_cross_attentions,
628
+ )
629
+
630
+
631
+ class CustomMBartForCausalLM(MBartForCausalLM):
632
+ def __init__(self, config):
633
+ super().__init__(config)
634
+ # Modify the decoder within MBartDecoderWrapper
635
+ self.model.decoder = CustomMBartDecoder(config)
636
+
637
+ def forward(
638
+ self,
639
+ input_ids=None,
640
+ attention_mask=None,
641
+ encoder_hidden_states=None,
642
+ encoder_attention_mask=None,
643
+ head_mask=None,
644
+ cross_attn_head_mask=None,
645
+ past_key_values=None,
646
+ inputs_embeds=None,
647
+ labels=None,
648
+ use_cache=None,
649
+ output_attentions=None,
650
+ output_hidden_states=None,
651
+ return_dict=None,
652
+ ):
653
+ output_attentions = (
654
+ output_attentions
655
+ if output_attentions is not None
656
+ else self.config.output_attentions
657
+ )
658
+ output_hidden_states = (
659
+ output_hidden_states
660
+ if output_hidden_states is not None
661
+ else self.config.output_hidden_states
662
+ )
663
+ return_dict = (
664
+ return_dict if return_dict is not None else self.config.use_return_dict
665
+ )
666
+
667
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
668
+ outputs = self.model.decoder(
669
+ input_ids=input_ids,
670
+ attention_mask=attention_mask,
671
+ encoder_hidden_states=encoder_hidden_states,
672
+ encoder_attention_mask=encoder_attention_mask,
673
+ head_mask=head_mask,
674
+ cross_attn_head_mask=cross_attn_head_mask,
675
+ past_key_values=past_key_values,
676
+ inputs_embeds=inputs_embeds,
677
+ use_cache=use_cache,
678
+ output_attentions=output_attentions,
679
+ output_hidden_states=output_hidden_states,
680
+ return_dict=return_dict,
681
+ )
682
+ logits = self.lm_head(outputs[0])
683
+
684
+ return CausalLMOutputWithCrossAttentions(
685
+ logits=logits,
686
+ past_key_values=outputs.past_key_values,
687
+ hidden_states=outputs.hidden_states,
688
+ attentions=outputs.attentions,
689
+ cross_attentions=outputs.cross_attentions,
690
+ )
691
+
692
+
693
+ class PPFormulaNet_Head(UniMERNetHead):
694
+ """
695
+ PPFormulaNet_Head
696
+ Args:
697
+ max_new_tokens (int): Maximum number of new tokens to generate. Default is 1536.
698
+ decoder_start_token_id (int): Start token ID for the decoder. Default is 0.
699
+ temperature (float): Temperature parameter for controlling randomness in sampling. Default is 0.2.
700
+ do_sample (bool): Flag to determine whether to use sampling for generation. Default is False.
701
+ top_p (float): Top-p (nucleus) sampling parameter for controlling diversity. Default is 0.95.
702
+ in_channels (int): Number of input channels for the model. Default is 1024.
703
+ decoder_layers (int): Number of layers in the decoder. Default is 8.
704
+ encoder_hidden_size (int): Size of the hidden layer in the encoder. Default is 1024.
705
+ decoder_ffn_dim (int): Dimension of the feed-forward network in the decoder. Default is 4096.
706
+ decoder_hidden_size (int): Size of the hidden layer in the decoder. Default is 1024.
707
+ is_export (bool): Flag indicating whether the model is to be exported. Default is False.
708
+ length_aware (bool): Flag to determine if the model should be aware of input sequence length. Default is True.
709
+ use_parallel (bool): Flag to enable or disable parallel processing. Default is False.
710
+ parallel_step (int): Number of steps to use in parallel processing. Default is 3.
711
+ """
712
+
713
+ def __init__(
714
+ self,
715
+ max_new_tokens=1536,
716
+ decoder_start_token_id=0,
717
+ temperature=0.2,
718
+ do_sample=False,
719
+ top_p=0.95,
720
+ in_channels=1024,
721
+ decoder_layers=8,
722
+ encoder_hidden_size=1024,
723
+ decoder_ffn_dim=4096,
724
+ decoder_hidden_size=1024,
725
+ is_export=False,
726
+ length_aware=True,
727
+ use_parallel=False,
728
+ parallel_step=3,
729
+ ):
730
+
731
+ super().__init__()
732
+
733
+ mbart_config_dict = {
734
+ "activation_dropout": 0.0,
735
+ "activation_function": "gelu",
736
+ "add_cross_attention": True,
737
+ "add_final_layer_norm": True,
738
+ "attention_dropout": 0.0,
739
+ "bos_token_id": 0,
740
+ "classifier_dropout": 0.0,
741
+ "d_model": decoder_hidden_size,
742
+ "decoder_attention_heads": 16,
743
+ "decoder_ffn_dim": decoder_ffn_dim,
744
+ "decoder_layerdrop": 0.0,
745
+ "decoder_layers": decoder_layers,
746
+ "dropout": 0.1,
747
+ "encoder_attention_heads": 16,
748
+ "encoder_ffn_dim": 4096,
749
+ "encoder_layerdrop": 0.0,
750
+ "encoder_layers": 12,
751
+ "eos_token_id": 2,
752
+ "forced_eos_token_id": 2,
753
+ "init_std": 0.02,
754
+ "is_decoder": True,
755
+ "is_encoder_decoder": False,
756
+ "output_hidden_states": False,
757
+ "max_position_embeddings": (
758
+ max_new_tokens + parallel_step if use_parallel else max_new_tokens
759
+ ),
760
+ "model_type": "mbart",
761
+ "num_hidden_layers": 12,
762
+ "pad_token_id": 1,
763
+ "scale_embedding": True,
764
+ "tie_word_embeddings": False,
765
+ "transformers_version": "4.40.0",
766
+ "use_cache": True,
767
+ "use_return_dict": True,
768
+ "vocab_size": 50000,
769
+ "_attn_implementation": "eager",
770
+ "hidden_size": decoder_hidden_size,
771
+ "use_parallel": use_parallel,
772
+ "parallel_step": int(parallel_step),
773
+ "is_export": is_export,
774
+ }
775
+ self.decoder_start_token_id = decoder_start_token_id
776
+ self.temperature = temperature
777
+ self.do_sample = do_sample
778
+ self.top_p = top_p
779
+ self.is_export = is_export
780
+ self.max_seq_len = max_new_tokens
781
+ self.config_decoder = MBartConfig(**mbart_config_dict)
782
+ self.encoder_hidden_size = encoder_hidden_size
783
+ self.decoder = CustomMBartForCausalLM(self.config_decoder)
784
+ if self.config_decoder.hidden_size != self.encoder_hidden_size:
785
+ self.enc_to_dec_proj = nn.Linear(
786
+ self.encoder_hidden_size, self.config_decoder.hidden_size
787
+ )
788
+ generation_config = {
789
+ "max_length": 1537,
790
+ "forced_eos_token_id": 2,
791
+ }
792
+ self.eos_token_id = generation_config["forced_eos_token_id"]
793
+ self.pad_token_id = self.config_decoder.pad_token_id
794
+ self.logits_processor = LogitsProcessorList()
795
+ self.logits_processor.append(
796
+ ForcedEOSTokenLogitsProcessor(
797
+ generation_config["max_length"],
798
+ generation_config["forced_eos_token_id"],
799
+ )
800
+ )
801
+ self.device = torch.device(get_device())
802
+
803
+ def prepare_inputs_for_generation(
804
+ self,
805
+ input_ids,
806
+ past_key_values=None,
807
+ attention_mask=None,
808
+ use_cache=None,
809
+ encoder_outputs=None,
810
+ **kwargs,
811
+ ):
812
+ decoder_inputs = self.prepare_inputs_for_generation_mbart(
813
+ input_ids, past_key_values=past_key_values
814
+ )
815
+ decoder_attention_mask = (
816
+ decoder_inputs["attention_mask"]
817
+ if "attention_mask" in decoder_inputs
818
+ else None
819
+ )
820
+ input_dict = {
821
+ "attention_mask": attention_mask,
822
+ "decoder_attention_mask": decoder_attention_mask,
823
+ "decoder_input_ids": decoder_inputs["input_ids"],
824
+ "past_key_values": decoder_inputs["past_key_values"],
825
+ "use_cache": use_cache,
826
+ }
827
+ return input_dict
828
+
829
+ def _extract_past_from_model_output(
830
+ self, outputs: ModelOutput, standardize_cache_format: bool = False
831
+ ):
832
+ past_key_values = None
833
+ if "past_key_values" in outputs:
834
+ past_key_values = outputs.past_key_values
835
+ elif "mems" in outputs:
836
+ past_key_values = outputs.mems
837
+ elif "past_buckets_states" in outputs:
838
+ past_key_values = outputs.past_buckets_states
839
+ return past_key_values
840
+
841
+ def _update_model_kwargs_for_generation(
842
+ self,
843
+ outputs: ModelOutput,
844
+ model_kwargs: Dict[str, Any],
845
+ is_encoder_decoder: bool = False,
846
+ standardize_cache_format: bool = False,
847
+ ) -> Dict[str, Any]:
848
+ # update past_key_values
849
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
850
+ outputs, standardize_cache_format=standardize_cache_format
851
+ )
852
+ if getattr(outputs, "state", None) is not None:
853
+ model_kwargs["state"] = outputs.state
854
+
855
+ # update token_type_ids with last value
856
+ if "token_type_ids" in model_kwargs:
857
+ token_type_ids = model_kwargs["token_type_ids"]
858
+ model_kwargs["token_type_ids"] = torch.concat(
859
+ [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
860
+ )
861
+
862
+ if not is_encoder_decoder:
863
+ # update attention mask
864
+ if "attention_mask" in model_kwargs:
865
+ attention_mask = model_kwargs["attention_mask"]
866
+ model_kwargs["attention_mask"] = torch.concat(
867
+ [
868
+ attention_mask,
869
+ attention_mask.new_ones((attention_mask.shape[0], 1)),
870
+ ],
871
+ dim=-1,
872
+ )
873
+ else:
874
+ # update decoder attention mask
875
+ if "decoder_attention_mask" in model_kwargs:
876
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
877
+ model_kwargs["decoder_attention_mask"] = torch.concat(
878
+ [
879
+ decoder_attention_mask,
880
+ decoder_attention_mask.new_ones(
881
+ (decoder_attention_mask.shape[0], 1)
882
+ ),
883
+ ],
884
+ dim=-1,
885
+ )
886
+
887
+ if (
888
+ "cache_position" in model_kwargs
889
+ and model_kwargs["cache_position"] is not None
890
+ ):
891
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
892
+ return model_kwargs
893
+
894
+ def stopping_criteria(self, input_ids):
895
+ if self.is_export:
896
+ return input_ids[:, -1].cpu() == torch.Tensor([self.eos_token_id])
897
+ is_done = torch.isin(input_ids[:, -1].cpu(), torch.Tensor([self.eos_token_id]))
898
+ return is_done
899
+
900
+ def stopping_criteria_parallel(self, input_ids):
901
+ parallel_step = self.config_decoder.parallel_step
902
+
903
+ if self.is_export:
904
+ is_done_list = []
905
+ for i in range(parallel_step, 0, -1):
906
+ cur_is_done = input_ids[:, -i] == torch.Tensor([self.eos_token_id])
907
+ is_done_list.append(cur_is_done)
908
+ is_done_list = torch.Tensor(is_done_list).permute([1, 0])
909
+ return is_done_list
910
+ else:
911
+ is_done = torch.isin(
912
+ input_ids[:, -parallel_step:],
913
+ torch.Tensor([self.eos_token_id]).reshape([1, 1]),
914
+ )
915
+ return torch.Tensor(is_done)
916
+
917
+ def generate_single_iter(
918
+ self,
919
+ decoder_input_ids=None,
920
+ decoder_attention_mask=None,
921
+ encoder_outputs=None,
922
+ past_key_values=None,
923
+ decoder_inputs_embeds=None,
924
+ labels=None,
925
+ use_cache=None,
926
+ output_attentions=None,
927
+ output_hidden_states=None,
928
+ return_dict=None,
929
+ **kwargs,
930
+ ):
931
+
932
+ encoder_hidden_states = encoder_outputs[0]
933
+ if self.config_decoder.hidden_size != self.encoder_hidden_size:
934
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
935
+ kwargs_decoder = {}
936
+ decoder_outputs = self.decoder(
937
+ input_ids=decoder_input_ids,
938
+ attention_mask=decoder_attention_mask,
939
+ encoder_hidden_states=encoder_hidden_states,
940
+ encoder_attention_mask=None,
941
+ inputs_embeds=None,
942
+ output_attentions=False,
943
+ output_hidden_states=output_hidden_states,
944
+ use_cache=use_cache,
945
+ past_key_values=past_key_values,
946
+ return_dict=return_dict,
947
+ **kwargs_decoder,
948
+ )
949
+
950
+ return Seq2SeqLMOutput(
951
+ loss=None,
952
+ logits=decoder_outputs.logits,
953
+ past_key_values=decoder_outputs.past_key_values,
954
+ decoder_hidden_states=decoder_outputs.hidden_states,
955
+ decoder_attentions=decoder_outputs.attentions,
956
+ cross_attentions=decoder_outputs.cross_attentions,
957
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
958
+ encoder_hidden_states=encoder_outputs.hidden_states,
959
+ encoder_attentions=encoder_outputs.attentions,
960
+ )
961
+
962
+ def _prepare_decoder_input_ids_for_generation(
963
+ self,
964
+ batch_size,
965
+ model_kwargs,
966
+ decoder_start_token_id=None,
967
+ bos_token_id=None,
968
+ ):
969
+
970
+ # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
971
+ # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
972
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
973
+ decoder_input_ids = model_kwargs.pop("decoder_input_ids")
974
+ elif "input_ids" in model_kwargs:
975
+ decoder_input_ids = model_kwargs.pop("input_ids")
976
+ else:
977
+ decoder_input_ids = None
978
+
979
+ # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
980
+ decoder_start_token_id = self._get_decoder_start_token_id(
981
+ decoder_start_token_id, bos_token_id
982
+ )
983
+
984
+ if isinstance(decoder_start_token_id, list):
985
+ if len(decoder_start_token_id) != batch_size:
986
+ raise ValueError(
987
+ f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
988
+ )
989
+ decoder_input_ids_start = torch.Tensor(
990
+ decoder_start_token_id
991
+ ).to(torch.int64)
992
+ decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
993
+ else:
994
+ use_parallel = self.config_decoder.use_parallel
995
+ parallel_step = self.config_decoder.parallel_step
996
+
997
+ if use_parallel:
998
+ decoder_input_ids_start = (
999
+ torch.ones(
1000
+ (batch_size, parallel_step),
1001
+ dtype=torch.int64,
1002
+ device=self.device,
1003
+ )
1004
+ * decoder_start_token_id
1005
+ )
1006
+ else:
1007
+ decoder_input_ids_start = (
1008
+ torch.ones(
1009
+ (batch_size, 1),
1010
+ dtype=torch.int64,
1011
+ device=self.device,
1012
+ )
1013
+ * decoder_start_token_id
1014
+ )
1015
+ # no user input -> use decoder_start_token_id as decoder_input_ids
1016
+ if decoder_input_ids is None:
1017
+ decoder_input_ids = decoder_input_ids_start
1018
+ # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
1019
+ elif (
1020
+ self.config.model_type == "vision-encoder-decoder"
1021
+ and "donut" in self.name_or_path.lower()
1022
+ ):
1023
+ pass
1024
+ elif self.config.model_type in ["whisper"]:
1025
+ pass
1026
+ # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
1027
+ # decoder_attention_mask if provided)
1028
+ elif (
1029
+ isinstance(decoder_start_token_id, int)
1030
+ and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
1031
+ ) or (
1032
+ isinstance(decoder_start_token_id, torch.Tensor)
1033
+ and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
1034
+ ):
1035
+ decoder_input_ids = torch.concat(
1036
+ [decoder_input_ids_start, decoder_input_ids], dim=-1
1037
+ )
1038
+ if "decoder_attention_mask" in model_kwargs:
1039
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
1040
+ decoder_attention_mask = torch.cat(
1041
+ (
1042
+ torch.ones_like(decoder_attention_mask)[:, :1],
1043
+ decoder_attention_mask,
1044
+ ),
1045
+ dim=-1,
1046
+ )
1047
+ model_kwargs["decoder_attention_mask"] = decoder_attention_mask
1048
+
1049
+ return decoder_input_ids, model_kwargs
1050
+
1051
+ @torch.no_grad()
1052
+ def generate_export(
1053
+ self,
1054
+ encoder_outputs,
1055
+ model_kwargs,
1056
+ ):
1057
+ use_parallel = self.config_decoder.use_parallel
1058
+ parallel_step = self.config_decoder.parallel_step
1059
+ batch_size = encoder_outputs["last_hidden_state"].shape[0]
1060
+ generation_config = {
1061
+ "decoder_start_token_id": 0,
1062
+ "bos_token_id": 0,
1063
+ }
1064
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
1065
+ batch_size=batch_size,
1066
+ model_kwargs=model_kwargs,
1067
+ decoder_start_token_id=generation_config["decoder_start_token_id"],
1068
+ bos_token_id=generation_config["bos_token_id"],
1069
+ )
1070
+ if not use_parallel:
1071
+ input_ids = input_ids.reshape([-1, 1])
1072
+ decoder_input_ids = input_ids
1073
+ model_kwargs["key use_cache"] = True
1074
+ batch_size, cur_len = input_ids.shape
1075
+
1076
+ if "inputs_embeds" in model_kwargs:
1077
+ cur_len = model_kwargs["inputs_embeds"].shape[1]
1078
+
1079
+ cache_position = torch.arange(cur_len)
1080
+ pad_token_id = self.pad_token_id
1081
+ eos_token_id = [self.eos_token_id]
1082
+ eos_token = self.eos_token_id
1083
+ if use_parallel:
1084
+ unfinished_sequences = torch.ones(
1085
+ [batch_size, parallel_step], dtype=torch.int64, device=self.device
1086
+ )
1087
+ parallel_length = math.ceil(self.max_seq_len // parallel_step)
1088
+ else:
1089
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.int64, device=self.device)
1090
+ parallel_length = self.max_seq_len
1091
+
1092
+ i_idx = 0
1093
+ past_key_values = []
1094
+ decoder_attention_heads = self.config_decoder.decoder_attention_heads
1095
+ decoder_attention_heads_dim = int(
1096
+ self.config_decoder.d_model / decoder_attention_heads
1097
+ )
1098
+ for i in range(self.config_decoder.decoder_layers):
1099
+ init_arr = torch.zeros(
1100
+ [batch_size, decoder_attention_heads, 0, decoder_attention_heads_dim]
1101
+ )
1102
+ cache = (init_arr, init_arr, init_arr, init_arr)
1103
+ past_key_values.append(cache)
1104
+
1105
+ while i_idx < parallel_length:
1106
+
1107
+ model_inputs = self.prepare_inputs_for_generation_export(
1108
+ past_key_values=past_key_values, **model_kwargs
1109
+ )
1110
+ decoder_attention_mask = torch.ones(input_ids.shape, device=self.device)
1111
+
1112
+ outputs = self.generate_single_iter(
1113
+ decoder_input_ids=decoder_input_ids,
1114
+ decoder_attention_mask=decoder_attention_mask,
1115
+ encoder_outputs=encoder_outputs,
1116
+ past_key_values=past_key_values,
1117
+ return_dict=True,
1118
+ output_attentions=False,
1119
+ output_hidden_states=False,
1120
+ )
1121
+
1122
+ if use_parallel:
1123
+ next_token_logits = outputs.logits[:, -parallel_step:, :]
1124
+ else:
1125
+ next_token_logits = outputs.logits[:, -1, :]
1126
+ next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
1127
+ next_tokens = torch.argmax(next_tokens_scores, dim=-1)
1128
+
1129
+ if eos_token_id is not None:
1130
+ # False
1131
+ if pad_token_id is None:
1132
+ raise ValueError(
1133
+ "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
1134
+ )
1135
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1136
+ 1 - unfinished_sequences
1137
+ )
1138
+ if use_parallel:
1139
+ input_ids = torch.concat([input_ids, next_tokens], dim=-1)
1140
+ decoder_input_ids = next_tokens
1141
+ else:
1142
+ input_ids = torch.concat(
1143
+ [input_ids, next_tokens.unsqueeze(1)], dim=-1
1144
+ )
1145
+ decoder_input_ids = next_tokens.unsqueeze(1)
1146
+
1147
+ past_length = past_key_values[0][0].shape[2]
1148
+
1149
+ past_key_values = outputs.past_key_values
1150
+ cache_position = cache_position[-1:] + 1
1151
+ if use_parallel:
1152
+ unfinished_sequences = (
1153
+ unfinished_sequences
1154
+ & ~self.stopping_criteria_parallel(input_ids).to(torch.int64).to(self.device)
1155
+ )
1156
+ else:
1157
+ unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
1158
+ input_ids
1159
+ ).to(torch.int64).to(self.device)
1160
+
1161
+ if (
1162
+ eos_token is not None
1163
+ and (
1164
+ torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
1165
+ >= 1
1166
+ ).all()
1167
+ ):
1168
+ break
1169
+ i_idx += 1
1170
+ # break
1171
+
1172
+ return input_ids
1173
+
1174
+ @torch.no_grad()
1175
+ def generate(
1176
+ self,
1177
+ encoder_outputs,
1178
+ model_kwargs,
1179
+ ):
1180
+ """
1181
+ Generate sequences from the model without computing gradients.
1182
+
1183
+ This method is used to generate sequences from the model based on the given encoder outputs.
1184
+ It does not compute gradients, making it suitable for inference.
1185
+
1186
+ Args:
1187
+ encoder_outputs: The outputs from the encoder, typically including hidden states necessary for generation.
1188
+ model_kwargs: Additional keyword arguments that may include parameters such as maximum length,
1189
+ temperature, top-k/top-p sampling parameters, and other generation-specific settings.
1190
+
1191
+ Returns:
1192
+ Generated sequences based on the encoder outputs and specified generation parameters.
1193
+ """
1194
+ use_parallel = self.config_decoder.use_parallel
1195
+ parallel_step = self.config_decoder.parallel_step
1196
+ batch_size = encoder_outputs["last_hidden_state"].shape[0]
1197
+ generation_config = {
1198
+ "decoder_start_token_id": 0,
1199
+ "bos_token_id": 0,
1200
+ }
1201
+
1202
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
1203
+ batch_size=batch_size,
1204
+ model_kwargs=model_kwargs,
1205
+ decoder_start_token_id=generation_config["decoder_start_token_id"],
1206
+ bos_token_id=generation_config["bos_token_id"],
1207
+ )
1208
+
1209
+ decoder_input_ids = input_ids
1210
+ model_kwargs["key use_cache"] = True
1211
+ batch_size, cur_len = input_ids.shape
1212
+
1213
+ if "inputs_embeds" in model_kwargs:
1214
+ cur_len = model_kwargs["inputs_embeds"].shape[1]
1215
+ model_kwargs["cache_position"] = torch.arange(cur_len)
1216
+ pad_token_id = self.pad_token_id
1217
+ eos_token_id = [self.eos_token_id]
1218
+ eos_token = self.eos_token_id
1219
+ if use_parallel:
1220
+ unfinished_sequences = torch.ones(
1221
+ [batch_size, parallel_step], dtype=torch.int64
1222
+ )
1223
+ parallel_length = math.ceil(self.max_seq_len // parallel_step)
1224
+ else:
1225
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.int64)
1226
+ parallel_length = self.max_seq_len
1227
+ past_key_values = []
1228
+
1229
+ for idx in range(parallel_length):
1230
+
1231
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1232
+ outputs = self.generate_single_iter(
1233
+ **model_inputs,
1234
+ encoder_outputs=encoder_outputs,
1235
+ return_dict=True,
1236
+ output_attentions=False,
1237
+ output_hidden_states=False,
1238
+ )
1239
+
1240
+ if use_parallel:
1241
+ next_token_logits = outputs.logits[:, :, :]
1242
+ else:
1243
+ next_token_logits = outputs.logits[:, -1, :]
1244
+
1245
+ next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
1246
+ next_tokens = torch.argmax(next_tokens_scores, dim=-1)
1247
+ if eos_token_id is not None:
1248
+ # False
1249
+ if pad_token_id is None:
1250
+ raise ValueError(
1251
+ "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
1252
+ )
1253
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1254
+ 1 - unfinished_sequences
1255
+ )
1256
+ if use_parallel:
1257
+ input_ids = torch.concat([input_ids, next_tokens], dim=-1)
1258
+ else:
1259
+ input_ids = torch.concat([input_ids, next_tokens[:, None]], dim=-1)
1260
+
1261
+ model_kwargs = self._update_model_kwargs_for_generation(
1262
+ outputs,
1263
+ model_kwargs,
1264
+ is_encoder_decoder=self.config_decoder.is_encoder_decoder,
1265
+ )
1266
+ if use_parallel:
1267
+ unfinished_sequences = (
1268
+ unfinished_sequences
1269
+ & ~self.stopping_criteria_parallel(input_ids).to(torch.int64)
1270
+ )
1271
+ else:
1272
+ unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
1273
+ input_ids
1274
+ ).to(torch.int64)
1275
+
1276
+ if (
1277
+ eos_token is not None
1278
+ and (
1279
+ torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
1280
+ >= 1
1281
+ ).all()
1282
+ ):
1283
+ break
1284
+ return input_ids
1285
+
1286
+ def forwad_train(
1287
+ self,
1288
+ encoder_outputs,
1289
+ decoder_input_ids,
1290
+ decoder_attention_mask,
1291
+ past_key_values=None,
1292
+ decoder_inputs_embeds=None,
1293
+ labels=None,
1294
+ use_cache=None,
1295
+ output_attentions=None,
1296
+ output_hidden_states=None,
1297
+ return_dict=None,
1298
+ **kwargs,
1299
+ ):
1300
+ """
1301
+ Forward pass for training the model.
1302
+
1303
+ Args:
1304
+ encoder_outputs: The outputs from the encoder, typically including hidden states.
1305
+ decoder_input_ids: Input IDs for the decoder.
1306
+ decoder_attention_mask: Attention mask for the decoder inputs to avoid attending to padding tokens.
1307
+ past_key_values: Previously computed key and value states for the decoder, used for fast generation.
1308
+ decoder_inputs_embeds: Optional embeddings for decoder inputs, used instead of decoder_input_ids if provided.
1309
+ labels: Labels for computing the training loss.
1310
+ use_cache: Whether to use a cache of past key values for faster generation.
1311
+ output_attentions: Whether to output attention weights.
1312
+ output_hidden_states: Whether to output hidden states of all layers.
1313
+ return_dict: Whether to return the output as a dictionary.
1314
+ **kwargs: Additional keyword arguments.
1315
+
1316
+ Returns:
1317
+ Depending on the `return_dict` flag, returns either a dictionary of model outputs or a tuple.
1318
+ """
1319
+ if self.config_decoder.use_parallel:
1320
+ batch = decoder_input_ids.shape[0]
1321
+ add_sos_token = self.config_decoder.parallel_step - 1
1322
+ start_token = torch.zeros([batch, add_sos_token]).to(torch.int64)
1323
+ start_mask = torch.ones([batch, add_sos_token]).to(torch.int64)
1324
+ decoder_input_ids = torch.concat([start_token, decoder_input_ids], dim=1)
1325
+ decoder_attention_mask = torch.concat(
1326
+ [start_mask, decoder_attention_mask], dim=1
1327
+ )
1328
+
1329
+ labels = decoder_input_ids * 1
1330
+ labels = labels.masked_fill_(labels == self.pad_token_id, -100)
1331
+ if self.config_decoder.use_parallel:
1332
+ input_decoder_input_ids = decoder_input_ids[
1333
+ :, : -self.config_decoder.parallel_step
1334
+ ]
1335
+ input_decoder_attention_mask = decoder_attention_mask[
1336
+ :, : -self.config_decoder.parallel_step
1337
+ ]
1338
+ else:
1339
+ input_decoder_input_ids = decoder_input_ids[:, :-1]
1340
+ input_decoder_attention_mask = decoder_attention_mask[:, :-1]
1341
+
1342
+ encoder_hidden_states = encoder_outputs[0]
1343
+ kwargs_decoder = {}
1344
+ if self.config_decoder.hidden_size != self.encoder_hidden_size:
1345
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
1346
+
1347
+ decoder_outputs = self.decoder(
1348
+ input_ids=input_decoder_input_ids,
1349
+ attention_mask=input_decoder_attention_mask,
1350
+ encoder_hidden_states=encoder_hidden_states,
1351
+ encoder_attention_mask=None,
1352
+ inputs_embeds=None,
1353
+ output_attentions=False,
1354
+ output_hidden_states=output_hidden_states,
1355
+ use_cache=use_cache,
1356
+ past_key_values=past_key_values,
1357
+ return_dict=return_dict,
1358
+ **kwargs_decoder,
1359
+ )
1360
+
1361
+ logits = decoder_outputs.logits
1362
+ return logits, labels
1363
+
1364
+ # forward for export
1365
+ def forward(self, inputs, targets=None):
1366
+ self.is_export = False if self.training else True
1367
+ if not self.training:
1368
+ encoder_outputs = inputs
1369
+ model_kwargs = {
1370
+ "output_attentions": False,
1371
+ "output_hidden_states": False,
1372
+ "use_cache": True,
1373
+ }
1374
+ if self.is_export:
1375
+ word_pred = self.generate_export(encoder_outputs, model_kwargs)
1376
+ else:
1377
+ word_pred = self.generate(encoder_outputs, model_kwargs)
1378
+
1379
+ return word_pred
1380
+ encoder_outputs, tgt_seq, mask = inputs
1381
+ logits, masked_labels = self.forwad_train(encoder_outputs, tgt_seq, mask)
1382
+
1383
+ return logits, masked_labels
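
For reference, here is a minimal, self-contained sketch of the parallel-step greedy decoding loop that PPFormulaNet_Head.generate implements above: each iteration predicts parallel_step tokens at once, finished sequences keep emitting the pad token, and decoding stops once every sequence has produced EOS. dummy_step, the token ids, and the vocabulary size are placeholders, not the package's real decoder.

import torch

EOS, PAD, VOCAB, PARALLEL_STEP, MAX_LEN = 2, 1, 10, 3, 12

def dummy_step(input_ids):
    # Stand-in for one decoder forward pass: logits for the next
    # PARALLEL_STEP positions of every sequence in the batch.
    torch.manual_seed(input_ids.shape[1])
    return torch.randn(input_ids.shape[0], PARALLEL_STEP, VOCAB)

def greedy_parallel_decode(batch_size):
    input_ids = torch.zeros(batch_size, PARALLEL_STEP, dtype=torch.int64)  # start tokens
    unfinished = torch.ones(batch_size, PARALLEL_STEP, dtype=torch.int64)
    for _ in range(MAX_LEN // PARALLEL_STEP):
        next_tokens = dummy_step(input_ids).argmax(dim=-1)  # (bsz, PARALLEL_STEP)
        # Sequences that already finished keep emitting PAD.
        next_tokens = next_tokens * unfinished + PAD * (1 - unfinished)
        input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        # A slot counts as finished once it has emitted EOS in the latest block.
        unfinished = unfinished & (~(input_ids[:, -PARALLEL_STEP:] == EOS)).to(torch.int64)
        if (input_ids == EOS).any(dim=1).all():  # every sequence has hit EOS
            break
    return input_ids

print(greedy_parallel_decode(2))
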