ipex-llm 2.2.0b20250106__py3-none-manylinux2010_x86_64.whl → 2.2.0b20250107__py3-none-manylinux2010_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ipex_llm/libs/libbloom_amx.so +0 -0
- ipex_llm/libs/libbloom_avx.so +0 -0
- ipex_llm/libs/libbloom_avx2.so +0 -0
- ipex_llm/libs/libbloom_avx512.so +0 -0
- ipex_llm/libs/libbloom_avxvnni.so +0 -0
- ipex_llm/libs/libgptneox_amx.so +0 -0
- ipex_llm/libs/libgptneox_avx.so +0 -0
- ipex_llm/libs/libgptneox_avx2.so +0 -0
- ipex_llm/libs/libgptneox_avx512.so +0 -0
- ipex_llm/libs/libgptneox_avxvnni.so +0 -0
- ipex_llm/libs/libllama_amx.so +0 -0
- ipex_llm/libs/libllama_avx.so +0 -0
- ipex_llm/libs/libllama_avx2.so +0 -0
- ipex_llm/libs/libllama_avx512.so +0 -0
- ipex_llm/libs/libllama_avxvnni.so +0 -0
- ipex_llm/libs/libstarcoder_amx.so +0 -0
- ipex_llm/libs/libstarcoder_avx.so +0 -0
- ipex_llm/libs/libstarcoder_avx2.so +0 -0
- ipex_llm/libs/libstarcoder_avx512.so +0 -0
- ipex_llm/libs/libstarcoder_avxvnni.so +0 -0
- ipex_llm/libs/quantize-bloom +0 -0
- ipex_llm/libs/quantize-gptneox +0 -0
- ipex_llm/libs/quantize-llama +0 -0
- ipex_llm/libs/quantize-starcoder +0 -0
- ipex_llm/transformers/convert.py +17 -132
- ipex_llm/transformers/lookup.py +2 -2
- ipex_llm/transformers/low_bit_linear.py +8 -8
- ipex_llm/transformers/models/chatglm2.py +1 -192
- ipex_llm/transformers/models/minicpmv.py +2 -2
- ipex_llm/transformers/models/sd.py +2 -2
- ipex_llm/transformers/models/utils.py +14 -89
- ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +5 -8
- ipex_llm/transformers/utils.py +5 -20
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/METADATA +40 -19
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/RECORD +41 -44
- ipex_llm/transformers/models/cohere.py +0 -589
- ipex_llm/transformers/models/falcon.py +0 -829
- ipex_llm/transformers/models/mixtral.py +0 -576
- {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250107.data}/scripts/ipex-llm-init +0 -0
- {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250107.data}/scripts/llm-chat +0 -0
- {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250107.data}/scripts/llm-cli +0 -0
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/WHEEL +0 -0
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/entry_points.txt +0 -0
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/top_level.txt +0 -0
ipex_llm/transformers/models/mixtral.py
@@ -1,576 +0,0 @@
-#
-# Copyright 2016 The BigDL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# Some parts of this file is adapted from
-# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
-
-# coding=utf-8
-# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-""" PyTorch Mixtral model."""
-import math
-from typing import Optional, Tuple, Union, List
-from transformers.modeling_outputs import MoeModelOutputWithPast
-from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_causal_attention_mask,
-)
-
-import torch
-from torch import nn
-import torch.nn.functional as F
-from ipex_llm.ggml.quantize import ggml_tensor_qtype
-from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.utils import should_use_fuse_rope
-from ipex_llm.transformers.models.utils import use_decoding_fast_path
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
-from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
-from ipex_llm.transformers.low_bit_linear import IQ2_XXS
-
-import os
-
-KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
-    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
-    to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
-                                                           n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def mixtral_moeblock_forward(self,
-                             hidden_states: torch.Tensor):
-    batch_size, sequence_length, hidden_dim = hidden_states.shape
-    hidden_states = hidden_states.view(-1, hidden_dim)
-    bs = hidden_states.shape[0]
-    # router_logits: (batch * sequence_length, n_experts)
-    router_logits = self.gate(hidden_states)
-
-    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-    # we cast back to the input dtype
-    routing_weights = routing_weights.to(hidden_states.dtype)
-
-    if bs == 1:
-        selected_experts = selected_experts[0].cpu().tolist()
-        for idx in range(self.top_k):
-            exp_id = selected_experts[idx]
-            expert_layer = self.experts[exp_id]
-            weight = routing_weights[:, idx]
-            if idx == 0:
-                final_hidden_states = expert_layer(hidden_states, weight)
-            else:
-                final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
-    elif bs < 256 and hidden_states.device.type == 'xpu':
-        final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
-                                          dtype=hidden_states.dtype, device=hidden_states.device)
-        import xe_linear
-        indexes = xe_linear.get_moe_indexes(selected_experts.to(torch.int32).cpu(), 8)
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            idx_list = indexes[0][expert_idx]
-            top_x_list = indexes[1][expert_idx]
-            if len(idx_list) == 0:
-                continue
-
-            top_x = torch.tensor(top_x_list, dtype=torch.long, device=hidden_states.device)
-            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state,
-                                                 routing_weights[top_x_list, idx_list, None])
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-    else:
-        final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device
-        )
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(selected_experts,
-                                                  num_classes=self.num_experts).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-
-            if top_x.shape[0] == 0:
-                continue
-
-            # in torch it is faster to index using lists than torch tensors
-            top_x_list = top_x.tolist()
-            idx_list = idx.tolist()
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state,
-                                                 routing_weights[top_x_list, idx_list, None])
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-
-    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
-    return final_hidden_states, router_logits
-
-
-def mixtral_attention_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor]=None,
-    position_ids: Optional[torch.LongTensor]=None,
-    past_key_value: Optional[Tuple[torch.Tensor]]=None,
-    output_attentions: bool=False,
-    use_cache: bool=False,
-    padding_mask: Optional[torch.Tensor]=None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-    device = hidden_states.device
-    # for flash attention
-    original_dtype = hidden_states.dtype
-
-    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
-                                                use_fuse_rope,
-                                                enough_kv_room,
-                                                bsz * q_len)
-
-    if decoding_fast_path:
-        hidden_states = hidden_states.view(1, -1)
-        cache_k = past_key_value.key_cache[self.layer_idx]
-        cache_v = past_key_value.value_cache[self.layer_idx]
-        kv_seq_len = cache_k.shape[-2]
-        import xe_linear
-        query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
-                                                                       self.q_proj.weight,
-                                                                       self.k_proj.weight,
-                                                                       self.v_proj.weight,
-                                                                       position_ids,
-                                                                       cache_k, cache_v,
-                                                                       self.q_proj.weight.qtype,
-                                                                       self.v_proj.weight.qtype,
-                                                                       kv_seq_len,
-                                                                       self.head_dim,
-                                                                       self.rotary_emb.base,)
-        kv_seq_len += 1
-        # update past_key_value's seem_tokens and kv caches.
-        if self.layer_idx == 0:
-            past_key_value.seen_tokens = kv_seq_len
-        past_key_value.key_cache[self.layer_idx] = key_states
-        past_key_value.value_cache[self.layer_idx] = value_states
-    # diasble it for now as it will cause output change for unknown reason
-    # elif decoding_fast_path and self.q_proj.qtype == IQ2_XXS:
-    #     # this path self.v_proj use q4_0
-    #     hidden_states = hidden_states.view(1, -1)
-    #     cache_k = past_key_value.key_cache[self.layer_idx]
-    #     cache_v = past_key_value.value_cache[self.layer_idx]
-    #     kv_seq_len = cache_k.shape[-2]
-    #     import xe_linear
-    #     query_states, key_states = xe_linear.forward_qk(hidden_states,
-    #                                                     self.q_proj.weight,
-    #                                                     self.k_proj.weight,
-    #                                                     position_ids,
-    #                                                     cache_k,
-    #                                                     self.q_proj.weight.qtype,
-    #                                                     kv_seq_len,
-    #                                                     self.head_dim,
-    #                                                     10000)
-    #     kv_seq_len += 1
-    #     # update past_key_value's seem_tokens and kv caches.
-    #     if self.layer_idx == 0:
-    #         past_key_value.seen_tokens = kv_seq_len
-    #     # update value_states
-    #     value_states = self.v_proj(hidden_states)
-    #     value_states = value_states.view(bsz, q_len,
-    #                                      self.num_key_value_heads, self.head_dim).transpose(1, 2)
-    #     new_size = (cache_v.size(0),
-    #                 cache_v.size(1),
-    #                 cache_v.size(2) + value_states.size(2),
-    #                 cache_v.size(3))
-    #     new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
-    #     new_cache_v[:, :, cache_v.size(2):cache_v.size(2)+value_states.size(2), :] = value_states
-
-    #     past_key_value.key_cache[self.layer_idx] = key_states
-    #     past_key_value.value_cache[self.layer_idx] = new_cache_v
-    else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len,
-                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len,
-                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                invalidInputError(False,
-                                  "The cache structure has changed since version v4.36. "
-                                  f"If you are using {self.__class__.__name__} for "
-                                  "auto-regressive decodingwith k/v caching, please make sure "
-                                  "to initialize the attention class with a layer index.")
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
-        if use_fuse_rope:
-            import xe_addons
-            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
-                                           query_states, key_states)
-        else:
-            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                            cos, sin, position_ids, "mixtral")
-
-        if past_key_value is not None:
-            # update the number of seen tokens
-            if self.layer_idx == 0:
-                past_key_value.seen_tokens += key_states.shape[-2]
-
-            # reuse k, v, self_attention
-            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
-            if len(past_key_value.key_cache) <= self.layer_idx:
-                past_key_value.key_cache.append(key_states)
-                past_key_value.value_cache.append(value_states)
-            else:
-                cache_k = past_key_value.key_cache[self.layer_idx]
-                cache_v = past_key_value.value_cache[self.layer_idx]
-
-                if not enough_kv_room:
-                    # allocate new
-                    new_c_k, new_c_v = extend_kv_cache(bsz,
-                                                       self.num_key_value_heads,  # Support GQA
-                                                       self.head_dim,
-                                                       cache_k.size(2),
-                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                       dtype=cache_k.dtype,
-                                                       device=device)
-
-                    new_c_k[:] = cache_k
-                    new_c_v[:] = cache_v
-                    cache_k = new_c_k
-                    cache_v = new_c_v
-
-                key_states, value_states = append_kv_cache(cache_k,
-                                                           cache_v,
-                                                           key_states,
-                                                           value_states)
-
-                # update past_key_value
-                past_key_value.key_cache[self.layer_idx] = key_states
-                past_key_value.value_cache[self.layer_idx] = value_states
-
-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
-
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                     dtype=attention_dtype)
-    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                         dtype=attention_dtype)
-
-    if fsdp_flag:
-        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
-                                                     key_states,
-                                                     value_states,
-                                                     is_causal=True)
-        attn_weights = None
-    elif use_sdp(query_states.shape[2], key_states.shape[2], self.head_dim, query_states):
-        import xe_addons
-        attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
-        attn_output = attn_output.view(query_states.shape)
-        attn_weights = None
-    else:
-        attn_weights = torch.matmul(
-            query_states.to(key_states.dtype),
-            key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            invalidInputError(
-                False,
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
-                f" but is {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                invalidInputError(
-                    False,
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                    f" but is {attention_mask.size()}"
-                )
-
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.\
-            softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-        invalidInputError(
-            False,
-            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)},"
-            f" but is {attn_output.size()}"
-        )
-
-    attn_output = attn_output.transpose(1, 2).contiguous()
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-    attn_output = self.o_proj(attn_output)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output, attn_weights, past_key_value
-
-
-def mixtral_mlp_forward(
-    self,
-    x: torch.Tensor,
-    routing_weights
-) -> torch.Tensor:
-    qtype = getattr(self.w1, "qtype", None)
-    if mlp_fusion_check(x, qtype, self.training):
-        import xe_linear
-        return self.w2(xe_linear.mlp_forward_xpu(
-            x, self.w1.weight.data, self.w3.weight.data,
-            x.shape[0], x.shape[1], self.w1.out_len,
-            SILU, qtype,
-        )) * routing_weights
-    else:
-        current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
-        current_hidden_states = self.w2(current_hidden_states)
-        return routing_weights * current_hidden_states
-
-
-def mixtral_model_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    output_router_logits: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-) -> Union[Tuple, MoeModelOutputWithPast]:
-    # to be compatible with transformers>=4.37.0
-    self._use_flash_attention_2 = self.config._attn_implementation == "flash_attention_2"
-
-    output_attentions = output_attentions if output_attentions is not None \
-        else self.config.output_attentions
-    output_router_logits = (
-        output_router_logits if output_router_logits is not None
-        else self.config.output_router_logits
-    )
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else
-        self.config.output_hidden_states
-    )
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    # retrieve input_ids and inputs_embeds
-    if input_ids is not None and inputs_embeds is not None:
-        invalidInputError(False, "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")  # noqa
-    elif input_ids is not None:
-        batch_size, seq_length = input_ids.shape
-    elif inputs_embeds is not None:
-        batch_size, seq_length, _ = inputs_embeds.shape
-    else:
-        invalidInputError(False, "You have to specify either decoder_input_ids or decoder_inputs_embeds")  # noqa
-
-    past_key_values_length = 0
-
-    if use_cache:
-        use_legacy_cache = not isinstance(past_key_values, Cache)
-        if use_legacy_cache:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-        past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-    if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length, seq_length + past_key_values_length,
-            dtype=torch.long, device=device
-        )
-        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-    else:
-        position_ids = position_ids.view(-1, seq_length).long()
-
-    if inputs_embeds is None:
-        inputs_embeds = self.embed_tokens(input_ids)
-
-    if attention_mask is not None and self._use_flash_attention_2 and use_cache:
-        is_padding_right = attention_mask[:, -1].sum().item() != batch_size
-        if is_padding_right:
-            invalidInputError(
-                False,
-                "You are attempting to perform batched generation with padding_side='right'"
-                " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "  # noqa
-                " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-            )
-
-    if self._use_flash_attention_2:
-        # 2d mask is passed through the layers
-        attention_mask = attention_mask \
-            if (attention_mask is not None and 0 in attention_mask) else None
-    else:
-        # 4d mask is passed through the layers
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-            sliding_window=self.config.sliding_window,
-        )
-
-    hidden_states = inputs_embeds
-
-    if self.gradient_checkpointing and self.training:
-        if use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."  # noqa
-            )
-            use_cache = False
-
-    # decoder layers
-    all_hidden_states = () if output_hidden_states else None
-    all_self_attns = () if output_attentions else None
-    all_router_logits = () if output_router_logits else None
-    next_decoder_cache = None
-
-    for decoder_layer in self.layers:
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        if self.gradient_checkpointing and self.training:
-            layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_values,
-                output_attentions,
-                output_router_logits,
-                use_cache,
-            )
-        else:
-            # bigdl-llm changes:
-            #
-            # Avoid moving `attention_mask`` and `position_ids`` to other devices multiple times.
-            #
-            # When the model is partitioned on two different devices using
-            # `accelerate`'s `dispatch``, a hook to move inputs to the correct device is
-            # added to each layer's `forward``, which will result in moving `attention_mask`
-            # and `position_ids`, which allocated on device:0, to other devices for each
-            # decoder layer not in device:0.
-            #
-            # To avoid this, we move `attention_mask` and `position_ids` to the device of
-            # the current layer before the forward call, so that the moving is only done once
-            # for each devices other than devie:0.
-            #
-            curr_device = decoder_layer.input_layernorm.weight.device
-            if attention_mask is not None:
-                attention_mask = attention_mask.to(curr_device)
-            if position_ids is not None:
-                position_ids = position_ids.to(curr_device)
-            # bigdl-llm changes end
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                output_router_logits=output_router_logits,
-                use_cache=use_cache,
-            )

-        hidden_states = layer_outputs[0]
-
-        if use_cache:
-            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-        if output_attentions:
-            all_self_attns += (layer_outputs[1],)
-
-        if output_router_logits:
-            all_router_logits += (layer_outputs[-1],)
-
-    hidden_states = self.norm(hidden_states)
-
-    # add hidden states from the last decoder layer
-    if output_hidden_states:
-        all_hidden_states += (hidden_states,)
-
-    next_cache = None
-    if use_cache:
-        next_cache = next_decoder_cache.to_legacy_cache() \
-            if use_legacy_cache else next_decoder_cache
-
-    if not return_dict:
-        return tuple(
-            v
-            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]  # noqa
-            if v is not None
-        )
-    return MoeModelOutputWithPast(
-        last_hidden_state=hidden_states,
-        past_key_values=next_cache,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attns,
-        router_logits=all_router_logits,
-    )