ipex-llm 2.2.0b20250104__py3-none-win_amd64.whl → 2.2.0b20250105.post0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ipex_llm/libs/bloom-api.dll +0 -0
- ipex_llm/libs/bloom.dll +0 -0
- ipex_llm/libs/gptneox-api.dll +0 -0
- ipex_llm/libs/gptneox.dll +0 -0
- ipex_llm/libs/libbloom_avx.dll +0 -0
- ipex_llm/libs/libbloom_vnni.dll +0 -0
- ipex_llm/libs/libgptneox_avx.dll +0 -0
- ipex_llm/libs/libgptneox_vnni.dll +0 -0
- ipex_llm/libs/libllama_avx.dll +0 -0
- ipex_llm/libs/libllama_vnni.dll +0 -0
- ipex_llm/libs/libstarcoder_avx.dll +0 -0
- ipex_llm/libs/libstarcoder_vnni.dll +0 -0
- ipex_llm/libs/llama-api.dll +0 -0
- ipex_llm/libs/llama.dll +0 -0
- ipex_llm/libs/main-bloom.exe +0 -0
- ipex_llm/libs/main-gptneox.exe +0 -0
- ipex_llm/libs/main-llama.exe +0 -0
- ipex_llm/libs/main-starcoder.exe +0 -0
- ipex_llm/libs/pipeline.dll +0 -0
- ipex_llm/libs/quantize-bloom.exe +0 -0
- ipex_llm/libs/quantize-bloom_vnni.exe +0 -0
- ipex_llm/libs/quantize-gptneox.exe +0 -0
- ipex_llm/libs/quantize-gptneox_vnni.exe +0 -0
- ipex_llm/libs/quantize-llama.exe +0 -0
- ipex_llm/libs/quantize-llama_vnni.exe +0 -0
- ipex_llm/libs/quantize-starcoder.exe +0 -0
- ipex_llm/libs/quantize-starcoder_vnni.exe +0 -0
- ipex_llm/libs/starcoder-api.dll +0 -0
- ipex_llm/libs/starcoder.dll +0 -0
- ipex_llm/transformers/convert.py +17 -132
- ipex_llm/transformers/lookup.py +2 -2
- ipex_llm/transformers/low_bit_linear.py +8 -8
- ipex_llm/transformers/models/chatglm2.py +1 -192
- ipex_llm/transformers/models/minicpmv.py +2 -2
- ipex_llm/transformers/models/sd.py +2 -2
- ipex_llm/transformers/models/utils.py +14 -89
- ipex_llm/transformers/npu_model.py +80 -50
- ipex_llm/transformers/npu_models/convert_mp.py +1 -1
- ipex_llm/transformers/npu_models/linear.py +15 -3
- ipex_llm/transformers/npu_models/lm_head.py +1 -90
- ipex_llm/transformers/npu_models/lm_head_linear.py +106 -0
- ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +5 -8
- ipex_llm/transformers/utils.py +5 -20
- {ipex_llm-2.2.0b20250104.dist-info → ipex_llm-2.2.0b20250105.post0.dist-info}/METADATA +40 -19
- {ipex_llm-2.2.0b20250104.dist-info → ipex_llm-2.2.0b20250105.post0.dist-info}/RECORD +51 -53
- ipex_llm/transformers/models/cohere.py +0 -589
- ipex_llm/transformers/models/falcon.py +0 -829
- ipex_llm/transformers/models/mixtral.py +0 -576
- {ipex_llm-2.2.0b20250104.data → ipex_llm-2.2.0b20250105.post0.data}/scripts/ipex-llm-init.bat +0 -0
- {ipex_llm-2.2.0b20250104.data → ipex_llm-2.2.0b20250105.post0.data}/scripts/llm-chat.ps1 +0 -0
- {ipex_llm-2.2.0b20250104.data → ipex_llm-2.2.0b20250105.post0.data}/scripts/llm-cli.ps1 +0 -0
- {ipex_llm-2.2.0b20250104.dist-info → ipex_llm-2.2.0b20250105.post0.dist-info}/WHEEL +0 -0
- {ipex_llm-2.2.0b20250104.dist-info → ipex_llm-2.2.0b20250105.post0.dist-info}/entry_points.txt +0 -0
- {ipex_llm-2.2.0b20250104.dist-info → ipex_llm-2.2.0b20250105.post0.dist-info}/top_level.txt +0 -0
@@ -1,829 +0,0 @@
|
|
1
|
-
#
|
2
|
-
# Copyright 2016 The BigDL Authors.
|
3
|
-
#
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
-
# you may not use this file except in compliance with the License.
|
6
|
-
# You may obtain a copy of the License at
|
7
|
-
#
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
-
#
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
-
# See the License for the specific language governing permissions and
|
14
|
-
# limitations under the License.
|
15
|
-
#
|
16
|
-
# Some parts of this file is adapted from
|
17
|
-
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/falcon/modeling_falcon.py
|
18
|
-
# which is licensed under Apache License 2.0:
|
19
|
-
#
|
20
|
-
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
|
21
|
-
#
|
22
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
23
|
-
# you may not use this file except in compliance with the License.
|
24
|
-
# You may obtain a copy of the License at
|
25
|
-
#
|
26
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
27
|
-
#
|
28
|
-
# Unless required by applicable law or agreed to in writing, software
|
29
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
30
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
31
|
-
# See the License for the specific language governing permissions and
|
32
|
-
# limitations under the License.
|
33
|
-
"""PyTorch Falcon model."""
|
34
|
-
|
35
|
-
import math
|
36
|
-
from typing import Optional, Tuple
|
37
|
-
|
38
|
-
import torch
|
39
|
-
from torch.nn import functional as F
|
40
|
-
from ipex_llm.utils.common import invalidInputError
|
41
|
-
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
42
|
-
import warnings
|
43
|
-
|
44
|
-
import os
|
45
|
-
|
46
|
-
# Number of extra sequence positions added on top of the current KV length
# whenever a KV cache is created or grown in this module. Read once at import
# time from the KV_CACHE_ALLOC_BLOCK_LENGTH environment variable (default 256).
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
47
|
-
|
48
|
-
|
49
|
-
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
50
|
-
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second one.

    The last dimension is cut at ``x.shape[-1] // 2``. With the leading part
    called ``a`` and the trailing part ``b``, the result is ``[-b, a]``
    concatenated along the last dimension.
    """
    split_at = x.shape[-1] // 2
    leading = x[..., :split_at]
    trailing = x[..., split_at:]
    return torch.cat((-trailing, leading), dim=-1)
|
55
|
-
|
56
|
-
|
57
|
-
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
58
|
-
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate the query and key tensors with a rotary position embedding.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine table of the rotary embedding; indexed by
            `position_ids`.
        sin (`torch.Tensor`): Sine table of the rotary embedding; indexed by
            `position_ids`.
        position_ids (`torch.Tensor`): Positions of the tokens that `q` and `k`
            correspond to. Offsetted ids can be supplied here when a KV cache
            is in use.
        unsqueeze_dim (`int`, *optional*, defaults to 1): Dimension inserted
            into `cos[position_ids]` / `sin[position_ids]` so they broadcast
            against `q` and `k`. Use 1 when `q`/`k` are laid out as
            [batch_size, heads, seq_len, head_dim] and 2 when they are
            [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and the rotated key, in that
        order.
    """
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)
    # Same formula for both tensors: t * cos + rotate_half(t) * sin.
    rotated_q, rotated_k = (
        (tensor * cos_sel) + (rotate_half(tensor) * sin_sel) for tensor in (q, k)
    )
    return rotated_q, rotated_k
|
85
|
-
|
86
|
-
|
87
|
-
def rw_attention_forward_7b(
    self,
    hidden_states: torch.Tensor,
    alibi: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]]=None,
    head_mask: Optional[torch.Tensor]=None,
    use_cache: bool=False,
    output_attentions: bool=False,
):
    """Attention forward pass backed by a pre-allocated, growable KV cache.

    Projects `hidden_states` with `self.query_key_value`, splits the result
    into query/key/value heads via `self._split_heads`, applies
    `self.maybe_rotary`, merges the new keys/values with `layer_past` (when
    given) and computes attention. With `alibi is None` the attention is done
    by `F.scaled_dot_product_attention`; otherwise scores, ALiBi bias, mask and
    softmax are computed explicitly.

    Args:
        hidden_states: Layer input. The inline comments imply
            [batch_size, seq_length, hidden_size] - TODO confirm with caller.
        alibi: Bias added to the raw attention scores, or None to take the
            scaled-dot-product-attention path.
        attention_mask: Only read on the ALiBi path, where it is used as the
            `masked_fill` condition (masked positions receive -1e9).
        layer_past: Optional (key, value) pair from a previous call; each is
            unpacked as a 3-D tensor and viewed as
            [batch_size, num_kv, -1, head_dim].
        head_mask: Optional multiplier applied to the attention probabilities
            (ALiBi path only).
        use_cache: When True the (key, value) pair is returned as `present`.
        output_attentions: On the ALiBi path the attention probabilities are
            appended to the result; on the other path
            `invalidInputError(False, ...)` is called (presumably raising -
            verify against `ipex_llm.utils.common`).

    Returns:
        `(output_tensor, present)`, plus `attention_probs` as a third element
        on the ALiBi path when `output_attentions` is set. `present` is None
        unless `use_cache is True`.
    """
    fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

    batch_size, q_length, _, _ = query_layer.shape

    # Fold the head axis into the batch axis: [batch * heads, q_length, head_dim].
    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        q_length,
        self.head_dim
    )
    key_layer = key_layer.transpose(1, 2).reshape(
        batch_size * self.num_kv,
        q_length,
        self.head_dim,
    )
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size * self.num_kv,
        q_length,
        self.head_dim
    )

    # query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
    # The rotary call is given the total length (new tokens + cached tokens).
    _, seq_len, _ = query_layer.shape
    if layer_past is not None:
        _, seq_len_past, _ = layer_past[0].shape

        seq_len = seq_len + seq_len_past
    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)

    # kv_length = number of key positions after the new tokens are appended.
    _, kv_length, _ = key_layer.shape
    if layer_past is not None:
        kv_length += layer_past[0].shape[-2]
    query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
    key_layer = key_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
    value_layer = value_layer.view(batch_size, self.num_kv, q_length, self.head_dim)

    device = hidden_states.device
    if layer_past is not None:
        # reuse k, v, self_attention
        cache_k = layer_past[0].view(batch_size, self.num_kv, -1, self.head_dim)
        cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
        # The per-head stride is compared with the space the grown sequence
        # needs; if it is too small a larger cache is allocated.
        # NOTE(review): this relies on the cache helpers returning tensors whose
        # head stride reflects the allocated capacity - confirm in models/utils.
        if cache_k.stride()[1] < kv_length * cache_k.size(3):
            # allocate new
            new_cache_k, new_cache_v = extend_kv_cache(
                batch_size,
                self.num_kv,
                self.head_dim,
                cache_k.size(2),
                kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
                dtype=cache_k.dtype,
                device=device
            )
            # Copy the existing entries into the newly allocated cache.
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            cache_k = new_cache_k
            cache_v = new_cache_v

        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)

    elif use_cache:
        # First call with caching enabled: allocate a cache with headroom and
        # copy the freshly computed keys/values into it.
        max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = init_kv_cache(
            batch_size,
            self.num_kv,
            self.head_dim,
            kv_length,
            max_cache_length,
            dtype=key_layer.dtype,
            device=device
        )
        new_key_states[:] = key_layer
        new_value_states[:] = value_layer
        key_layer = new_key_states
        value_layer = new_value_states

    # Back to the folded [batch * heads, length, head_dim] layout.
    query_layer = query_layer.view(batch_size*self.num_heads, -1, self.head_dim)
    key_layer = key_layer.view(batch_size*self.num_kv, -1, self.head_dim)
    value_layer = value_layer.view(batch_size*self.num_kv, -1, self.head_dim)
    _, kv_length, _ = key_layer.shape
    if use_cache is True:
        present = (key_layer, value_layer)
    else:
        present = None

    if alibi is None:
        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
        key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
        value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)

        # attn_output = F.scaled_dot_product_attention(
        #     query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
        # )
        if layer_past is not None:
            # With cached keys the query and key lengths differ, so an explicit
            # all-True (L, S) mask is passed instead of is_causal=True.
            L = query_layer_.shape[-2]
            S = key_layer_.shape[-2]
            attn_mask = torch.ones(L, S, dtype=torch.bool, device=query_layer_.device)
            attn_output = F.scaled_dot_product_attention(
                query_layer_, key_layer_, value_layer_, attn_mask, 0.0, is_causal=False
            )
        else:
            attn_output = F.scaled_dot_product_attention(
                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
            )

        # [batch, heads, q_length, head_dim] -> [batch, q_length, heads * head_dim]
        x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
        x = x.permute(0, 2, 1, 3)
        attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)

        output_tensor = self.dense(attn_output)

        outputs = (output_tensor, present)
        if output_attentions:
            # This path never materialises attention weights.
            invalidInputError(False,
                              f"'output_attentions' are not supported yet")
        return outputs
    else:
        # Additive mask: 0 where attention_mask is False, -1e9 where it is True.
        attention_mask_float = (attention_mask * 1.0) \
            .masked_fill(attention_mask, -1e9).to(torch.bfloat16)
        matmul_result = query_layer @ key_layer.transpose(-1, -2)

        # change view to [batch_size, num_heads, q_length, kv_length]
        attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

        # cast attention scores to fp32,
        # compute scaled softmax and cast back to initial dtype
        # - [batch_size, num_heads, q_length, kv_length]
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0,
        # whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
        if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
            attention_scores = attention_scores.to(torch.float32)
        # attn_weights = torch. \
        #     masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
        attention_probs = F.softmax(
            (attention_scores + alibi) * self.inv_norm_factor + attention_mask_float,
            dim=-1,
            dtype=hidden_states.dtype,
        )
        # [batch_size, num_heads, q_length, kv_length]
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # change view [batch_size x num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(
            batch_size * self.num_heads,
            q_length,
            kv_length
        )

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = attention_probs_reshaped @ value_layer

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = self._merge_heads(context_layer)

        output_tensor = self.dense(context_layer)

        outputs = (output_tensor, present)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs
|
263
|
-
|
264
|
-
|
265
|
-
def rw_attention_forward_40b(
    self,
    hidden_states: torch.Tensor,
    alibi: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]]=None,
    head_mask: Optional[torch.Tensor]=None,
    use_cache: bool=False,
    output_attentions: bool=False,
):
    """Attention forward pass backed by a pre-allocated, growable KV cache.

    Same flow as `rw_attention_forward_7b`, except that keys and values use
    `self.num_heads` heads (no separate `num_kv`) and the ALiBi bias is viewed
    as [batch_size, num_heads, 1, -1] before being added to the scores.

    Args:
        hidden_states: Layer input. The inline comments imply
            [batch_size, seq_length, hidden_size] - TODO confirm with caller.
        alibi: Bias added to the raw attention scores, or None to take the
            scaled-dot-product-attention path.
        attention_mask: Only read on the ALiBi path, where it is used as the
            `masked_fill` condition (masked positions receive -1e9).
        layer_past: Optional (key, value) pair from a previous call; each is
            unpacked as a 3-D tensor and viewed as
            [batch_size, num_heads, -1, head_dim].
        head_mask: Optional multiplier applied to the attention probabilities
            (ALiBi path only).
        use_cache: When True the (key, value) pair is returned as `present`.
        output_attentions: On the ALiBi path the attention probabilities are
            appended to the result; on the other path
            `invalidInputError(False, ...)` is called (presumably raising -
            verify against `ipex_llm.utils.common`).

    Returns:
        `(output_tensor, present)`, plus `attention_probs` as a third element
        on the ALiBi path when `output_attentions` is set. `present` is None
        unless `use_cache is True`.
    """
    # [batch_size, seq_length, 3 x hidden_size]
    fused_qkv = self.query_key_value(hidden_states)

    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

    batch_size, q_length, _, _ = query_layer.shape

    # Fold the head axis into the batch axis: [batch * heads, q_length, head_dim].
    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        q_length,
        self.head_dim
    )
    key_layer = key_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        q_length,
        self.head_dim,
    )
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        q_length,
        self.head_dim
    )

    # The rotary call is given the total length (new tokens + cached tokens).
    _, seq_len, _ = query_layer.shape
    if layer_past is not None:
        _, seq_len_past, _ = layer_past[0].shape

        seq_len = seq_len + seq_len_past
    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)

    # kv_length = number of key positions after the new tokens are appended.
    _, kv_length, _ = key_layer.shape
    if layer_past is not None:
        kv_length += layer_past[0].shape[-2]
    query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
    key_layer = key_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
    value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)

    device = hidden_states.device
    if layer_past is not None:
        # reuse k, v, self_attention
        cache_k = layer_past[0].view(batch_size, self.num_heads, -1, self.head_dim)
        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
        # The per-head stride is compared with the space the grown sequence
        # needs; if it is too small a larger cache is allocated.
        if cache_k.stride()[1] < kv_length * cache_k.size(3):
            # allocate new
            new_cache_k, new_cache_v = extend_kv_cache(
                batch_size,
                self.num_heads,
                self.head_dim,
                cache_k.size(2),
                kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
                dtype=cache_k.dtype,
                device=device
            )
            # Copy the existing entries into the newly allocated cache.
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            cache_k = new_cache_k
            cache_v = new_cache_v

        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)

    elif use_cache:
        # First call with caching enabled: allocate a cache with headroom and
        # copy the freshly computed keys/values into it.
        max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = init_kv_cache(
            batch_size,
            self.num_heads,
            self.head_dim,
            kv_length,
            max_cache_length,
            dtype=key_layer.dtype,
            device=device
        )
        new_key_states[:] = key_layer
        new_value_states[:] = value_layer
        key_layer = new_key_states
        value_layer = new_value_states

    # Back to the folded [batch * heads, length, head_dim] layout.
    query_layer = query_layer.view(batch_size*self.num_heads, -1, self.head_dim)
    key_layer = key_layer.view(batch_size*self.num_heads, -1, self.head_dim)
    value_layer = value_layer.view(batch_size*self.num_heads, -1, self.head_dim)
    _, kv_length, _ = key_layer.shape
    if use_cache is True:
        present = (key_layer, value_layer)
    else:
        present = None

    if alibi is None:
        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
        key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
        value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)

        # The mask choice must depend on whether cached keys exist, not on
        # whether a cache is being returned. Keying on `present` made the
        # prompt pass with use_cache=True run without any causal mask, and a
        # cached step with use_cache=False run with is_causal=True on L != S.
        # This now matches rw_attention_forward_7b.
        if layer_past is not None:
            # With cached keys the query and key lengths differ, so an explicit
            # all-True (L, S) mask is passed instead of is_causal=True.
            L = query_layer_.shape[-2]
            S = key_layer_.shape[-2]
            attn_mask = torch.ones(L, S, dtype=torch.bool, device=query_layer_.device)
            attn_output = F.scaled_dot_product_attention(
                query_layer_, key_layer_, value_layer_, attn_mask, 0.0, is_causal=False
            )
        else:
            attn_output = F.scaled_dot_product_attention(
                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
            )

        # [batch, heads, q_length, head_dim] -> [batch, q_length, heads * head_dim]
        x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
        x = x.permute(0, 2, 1, 3)
        attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)

        output_tensor = self.dense(attn_output)

        outputs = (output_tensor, present)
        if output_attentions:
            # This path never materialises attention weights.
            invalidInputError(False,
                              f"'output_attentions' are not supported yet")
        return outputs
    else:
        # Additive mask: 0 where attention_mask is False, -1e9 where it is True.
        attention_mask_float = (attention_mask * 1.0) \
            .masked_fill(attention_mask, -1e9).to(torch.bfloat16)
        matmul_result = query_layer @ key_layer.transpose(-1, -2)

        # change view to [batch_size, num_heads, q_length, kv_length]
        attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

        # cast attention scores to fp32,
        # compute scaled softmax and cast back to initial dtype
        # - [batch_size, num_heads, q_length, kv_length]
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0,
        # whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
        if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
            attention_scores = attention_scores.to(torch.float32)
        attention_probs = F.softmax(
            (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1))
            * self.inv_norm_factor + attention_mask_float,
            dim=-1,
            dtype=hidden_states.dtype,
        )
        # [batch_size, num_heads, q_length, kv_length]
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # change view [batch_size x num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(
            batch_size * self.num_heads,
            q_length,
            kv_length
        )

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = attention_probs_reshaped @ value_layer

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = self._merge_heads(context_layer)

        output_tensor = self.dense(context_layer)

        outputs = (output_tensor, present)
        if output_attentions:
            outputs += (attention_probs,)
        return outputs
|
443
|
-
|
444
|
-
|
445
|
-
def falcon_attention_forward(
    self,
    hidden_states: torch.Tensor,
    alibi: Optional[torch.Tensor],
    attention_mask: torch.Tensor,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]]=None,
    head_mask: Optional[torch.Tensor]=None,
    use_cache: bool=False,
    output_attentions: bool=False,
):
    """Falcon attention forward pass backed by a pre-allocated, growable KV cache.

    Projects `hidden_states` with `self.query_key_value`, splits heads with
    `self._split_heads`, applies `self.maybe_rotary`, merges new keys/values
    with `layer_past` (when given) and computes attention. Without ALiBi the
    attention is either computed by hand (when the weights are requested) or by
    `F.scaled_dot_product_attention`; with ALiBi the scores, bias, mask and
    softmax are computed explicitly.

    Args:
        hidden_states: Layer input. The inline comments imply
            [batch_size, seq_length, hidden_size] - TODO confirm with caller.
        alibi: Optional bias; when given it is viewed as
            [batch_size, num_heads, 1, -1] and added to the raw scores.
        attention_mask: Used as the `masked_fill` condition to build an
            additive mask (masked positions receive -1e9).
        layer_past: Optional (key, value) pair from a previous call; each is
            viewed as [batch_size, num_kv_heads, -1, head_dim].
        head_mask: Optional multiplier applied to the attention probabilities
            (ALiBi path only).
        use_cache: When truthy the (key, value) pair is returned as `present`.
        output_attentions: When truthy the attention weights are returned as a
            third element.

    Returns:
        `(output_tensor, present)` or, with `output_attentions`,
        `(output_tensor, present, attention_weights)`.
    """
    # [batch_size, seq_length, 3 x hidden_size]
    fused_qkv = self.query_key_value(hidden_states)
    # Number of key/value heads used for every reshape below.
    num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

    batch_size, query_length, _, _ = query_layer.shape

    # Fold the head axis into the batch axis: [batch * heads, length, head_dim].
    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads,
        query_length,
        self.head_dim
    )
    key_layer = key_layer.transpose(1, 2).reshape(
        batch_size * num_kv_heads,
        query_length,
        self.head_dim,
    )
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size * num_kv_heads,
        query_length,
        self.head_dim
    )

    # Unlike the rw_* variants, the rotary call receives the cached length
    # only (0 on the first call), not the total length.
    # NOTE(review): shape[1] is the sequence axis only if layer_past[0] is
    # 3-D [batch * kv_heads, seq, head_dim] - confirm with caller.
    past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

    # kv_length = number of key positions after the new tokens are appended.
    _, kv_length, _ = key_layer.shape
    if layer_past is not None:
        kv_length += layer_past[0].shape[-2]
    query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
    key_layer = key_layer.view(batch_size, num_kv_heads, query_length, self.head_dim)
    value_layer = value_layer.view(batch_size, num_kv_heads, query_length, self.head_dim)
    device = hidden_states.device
    if layer_past is not None:
        # reuse k, v, self_attention
        cache_k = layer_past[0].view(batch_size, num_kv_heads, -1, self.head_dim)
        cache_v = layer_past[1].view(batch_size, num_kv_heads, -1, self.head_dim)
        # The per-head stride is compared with the space the grown sequence
        # needs; if it is too small a larger cache is allocated.
        if cache_k.stride()[1] < kv_length * cache_k.size(3):
            # allocate new
            new_cache_k, new_cache_v = extend_kv_cache(
                batch_size,
                num_kv_heads,
                self.head_dim,
                cache_k.size(2),
                kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH,
                dtype=cache_k.dtype,
                device=device
            )
            # Copy the existing entries into the newly allocated cache.
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            cache_k = new_cache_k
            cache_v = new_cache_v

        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)

    elif use_cache:
        # First call with caching enabled: allocate a cache with headroom and
        # copy the freshly computed keys/values into it.
        max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = init_kv_cache(
            batch_size,
            num_kv_heads,
            self.head_dim,
            kv_length,
            max_cache_length,
            dtype=key_layer.dtype,
            device=device
        )
        new_key_states[:] = key_layer
        new_value_states[:] = value_layer
        key_layer = new_key_states
        value_layer = new_value_states

    # Back to the folded [batch * heads, length, head_dim] layout.
    query_layer = query_layer.view(batch_size * self.num_heads, -1, self.head_dim)
    key_layer = key_layer.view(batch_size * num_kv_heads, -1, self.head_dim)
    value_layer = value_layer.view(batch_size * num_kv_heads, -1, self.head_dim)
    _, kv_length, _ = key_layer.shape
    if use_cache:
        present = (key_layer, value_layer)
    else:
        present = None

    # Additive mask: 0 where attention_mask is False, -1e9 where it is True.
    attention_mask_float = (attention_mask * 1.0) \
        .masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)

    query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
    key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
    value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)

    if alibi is None:
        if output_attentions:
            # F.scaled_dot_product_attention doesn't return the attention weights, so we have
            # to do it by hand if we want them
            attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
            attention_scores /= math.sqrt(self.head_dim)

            attention_scores = F.softmax(
                attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
            )
            attn_output = attention_scores @ value_layer_
        else:
            attn_output = F.scaled_dot_product_attention(
                query_layer_,
                key_layer_,
                value_layer_,
                attention_mask_float,
                0.0,
                is_causal=False
            )
            attention_scores = None

        # [batch, heads, q_length, head_dim] -> [batch, q_length, heads * head_dim]
        attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
        attn_output = attn_output.permute(0, 2, 1, 3)
        attn_output = attn_output.reshape(
            batch_size,
            query_length,
            self.num_heads * self.head_dim
        )

        output_tensor = self.dense(attn_output)

        if output_attentions:
            return output_tensor, present, attention_scores
        else:
            return output_tensor, present

    else:
        matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)

        # change view to [batch_size, num_heads, q_length, kv_length]
        attention_scores = matmul_result.view(
            batch_size,
            self.num_heads,
            query_length,
            kv_length
        )

        # cast attention scores to fp32,
        # compute scaled softmax and cast back to initial dtype
        # - [batch_size, num_heads, q_length, kv_length]
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0,
        # whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
        if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
            attention_scores = attention_scores.to(torch.float32)
        # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
        # adding (alibi * self.inv_norm_factor) to attention_mask_float.
        # I think this would be mathematically
        # equivalent and more performant, but there might be a numerical difference.
        # If you're reading this
        # and you'd like to experiment and maybe file a PR, feel free!
        attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
        attention_logits *= self.inv_norm_factor
        attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1,
                                    dtype=hidden_states.dtype)
        # [batch_size, num_heads, q_length, kv_length]
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # change view [batch_size, num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(
            batch_size,
            self.num_heads,
            query_length,
            kv_length
        )

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)

        # change view [batch_size, q_length, num_heads * head_dim]
        context_layer = self._merge_heads(context_layer)

        output_tensor = self.dense(context_layer)

        if output_attentions:
            return output_tensor, present, attention_probs
        else:
            return output_tensor, present
|
635
|
-
|
636
|
-
|
637
|
-
def falcon_attention_forward_4_36(
    self,
    hidden_states: torch.Tensor,
    alibi: Optional[torch.Tensor],
    attention_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor]=None,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]]=None,
    head_mask: Optional[torch.Tensor]=None,
    use_cache: bool=False,
    output_attentions: bool=False,
    **kwargs,
):
    """Falcon self-attention forward pass that keeps keys/values in a growable cache.

    Based on transformers==4.36.0
    https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/falcon/modeling_falcon.py

    Args:
        self: the attention module being patched. This function reads
            ``query_key_value``, ``dense``, ``_split_heads``, ``_merge_heads``,
            ``rotary_emb``, ``attention_dropout``, ``num_heads``, ``num_kv_heads``,
            ``head_dim``, ``new_decoder_architecture``, ``inv_norm_factor``,
            ``is_causal``, ``training`` and ``_use_sdpa`` from it.
        hidden_states: input activations fed to ``self.query_key_value``.
        alibi: ALiBi bias, viewed here as ``(batch_size, num_heads, 1, -1)``.
            When ``None`` rotary position embeddings are applied instead.
        attention_mask: additive mask added to the attention logits on the eager
            paths and forwarded as ``attn_mask`` on the SDPA paths.
        position_ids: forwarded to ``apply_rotary_pos_emb`` (rotary path only).
        layer_past: optional ``(key, value)`` pair from previous steps.
        head_mask: optional multiplicative mask on the attention probabilities;
            only applied on the eager ALiBi path.
        use_cache: when true, the (possibly extended) key/value tensors are
            returned as ``present``.
        output_attentions: when true, the attention probabilities are returned as
            a third tuple element, which also disables the SDPA paths.
        **kwargs: only inspected for the deprecated ``padding_mask`` key.

    Returns:
        ``(attn_output, present)`` or, if ``output_attentions`` is set,
        ``(attn_output, present, attention_probs)``. ``present`` is ``None``
        unless ``use_cache`` is true.
    """
    if "padding_mask" in kwargs:
        # Adjacent literals are joined at compile time, so the message stays a single
        # clean sentence instead of carrying source indentation from a backslash
        # continuation inside the literal.
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
            "Please make sure to use `attention_mask` instead."
        )

    # [batch_size, seq_length, 3 x hidden_size]
    fused_qkv = self.query_key_value(hidden_states)
    num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
    # 3 x [batch_size, seq_length, num_heads, head_dim]
    query_layer, key_layer, value_layer = self._split_heads(fused_qkv)

    batch_size, query_length, _, _ = query_layer.shape

    # Put the head axis in front of the sequence axis.
    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size, self.num_heads, query_length, self.head_dim)
    key_layer = key_layer.transpose(1, 2).reshape(
        batch_size, num_kv_heads, query_length, self.head_dim)
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size, num_kv_heads, query_length, self.head_dim)

    kv_seq_len = key_layer.shape[-2]
    device = hidden_states.device

    if layer_past is not None:
        # Total key/value length = cached steps + the steps of this call.
        kv_seq_len += layer_past[0].shape[-2]

    if alibi is None:
        # No ALiBi bias: positions are injected through rotary embeddings.
        cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
        query_layer, key_layer = apply_rotary_pos_emb(
            query_layer, key_layer, cos, sin, position_ids)

    if layer_past is not None:
        # reuse k, v, self_attention
        cache_k = layer_past[0].view(batch_size, self.num_heads, -1, self.head_dim)
        cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
        # NOTE(review): presumably a per-head stride no larger than the used
        # seq_len * head_dim means the backing buffer has no spare room left --
        # confirm against extend_kv_cache / append_kv_cache.
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
            # allocate new
            new_cache_k, new_cache_v = extend_kv_cache(
                batch_size,
                self.num_heads,
                self.head_dim,
                cache_k.size(2),
                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                dtype=cache_k.dtype,
                device=device
            )
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            cache_k = new_cache_k
            cache_v = new_cache_v

        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)

    elif use_cache:
        # First step with caching enabled: create the cache and copy k/v into it.
        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
        new_key_states, new_value_states = init_kv_cache(
            batch_size,
            self.num_heads,
            self.head_dim,
            kv_seq_len,
            max_cache_length,
            dtype=key_layer.dtype,
            device=device
        )
        new_key_states[:] = key_layer
        new_value_states[:] = value_layer
        key_layer = new_key_states
        value_layer = new_value_states

    query_layer = query_layer.view(batch_size, self.num_heads, -1, self.head_dim)
    key_layer = key_layer.view(batch_size, self.num_heads, -1, self.head_dim)
    value_layer = value_layer.view(batch_size, self.num_heads, -1, self.head_dim)

    kv_length = key_layer.shape[-2]
    present = (key_layer, value_layer) if use_cache else None

    # SDPA with memory-efficient backend is currently (torch==2.1.2)
    # bugged with non-contiguous inputs with custom attn_mask,
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    if query_layer.device.type == "cuda" and attention_mask is not None:
        query_layer = query_layer.contiguous()
        key_layer = key_layer.contiguous()
        value_layer = value_layer.contiguous()

    if alibi is None:
        if self._use_sdpa and not output_attentions:
            attn_output = F.scaled_dot_product_attention(
                query_layer,
                key_layer,
                value_layer,
                attn_mask=attention_mask,
                dropout_p=0.0,
                # The query_length > 1 is necessary to match with
                # AttentionMaskConverter.to_causal_4d that does not create a causal mask
                # in case query_length == 1.
                is_causal=self.is_causal and attention_mask is None and query_length > 1,
            )
            attention_probs = None
        else:
            attention_scores = query_layer @ key_layer.transpose(-1, -2)
            attention_scores /= math.sqrt(self.head_dim)

            attention_probs = F.softmax(
                attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
            # It is unclear why neither dropout nor head_mask is applied here
            # (while it is with alibi).
            attn_output = attention_probs @ value_layer

        # [batch_size, num_heads, q_length, head_dim]
        #   -> [batch_size, q_length, num_heads * head_dim]
        attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
        attn_output = attn_output.permute(0, 2, 1, 3)
        attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
    elif self._use_sdpa and not output_attentions and head_mask is None:
        # The ALiBi tensor itself is not passed to the fused kernel on this path.
        # NOTE(review): presumably the caller has already folded it into
        # attention_mask -- confirm against the model-level forward.
        attn_output = F.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout.p if self.training else 0.0,
            is_causal=self.is_causal and attention_mask is None and query_length > 1,
        )
        attention_probs = None
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(
            batch_size, query_length, self.num_heads * self.head_dim)
    else:
        matmul_result = query_layer @ key_layer.transpose(-1, -2)

        # change view to [batch_size, num_heads, q_length, kv_length]
        attention_scores = matmul_result.view(
            batch_size, self.num_heads, query_length, kv_length)

        # cast attention scores to fp32, compute scaled softmax and cast back to
        # initial dtype - [batch_size, num_heads, q_length, kv_length]
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32`
        # have a minimum value of `-3.4e+38`
        if input_dtype in (torch.float16, torch.bfloat16):
            attention_scores = attention_scores.to(torch.float32)

        attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
        attention_logits *= self.inv_norm_factor
        attention_probs = F.softmax(
            attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
        # [batch_size, num_heads, q_length, kv_length]
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # change view [batch_size, num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(
            batch_size, self.num_heads, query_length, kv_length)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1)

        # change view [batch_size, q_length, num_heads * head_dim]
        attn_output = self._merge_heads(attn_output)

    # Every path ends with the same output projection.
    attn_output = self.dense(attn_output)

    if output_attentions:
        return attn_output, present, attention_probs
    return attn_output, present
|