ipex-llm 2.2.0b20250104-py3-none-win_amd64.whl → 2.2.0b20250105.post0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. ipex_llm/libs/bloom-api.dll +0 -0
  2. ipex_llm/libs/bloom.dll +0 -0
  3. ipex_llm/libs/gptneox-api.dll +0 -0
  4. ipex_llm/libs/gptneox.dll +0 -0
  5. ipex_llm/libs/libbloom_avx.dll +0 -0
  6. ipex_llm/libs/libbloom_vnni.dll +0 -0
  7. ipex_llm/libs/libgptneox_avx.dll +0 -0
  8. ipex_llm/libs/libgptneox_vnni.dll +0 -0
  9. ipex_llm/libs/libllama_avx.dll +0 -0
  10. ipex_llm/libs/libllama_vnni.dll +0 -0
  11. ipex_llm/libs/libstarcoder_avx.dll +0 -0
  12. ipex_llm/libs/libstarcoder_vnni.dll +0 -0
  13. ipex_llm/libs/llama-api.dll +0 -0
  14. ipex_llm/libs/llama.dll +0 -0
  15. ipex_llm/libs/main-bloom.exe +0 -0
  16. ipex_llm/libs/main-gptneox.exe +0 -0
  17. ipex_llm/libs/main-llama.exe +0 -0
  18. ipex_llm/libs/main-starcoder.exe +0 -0
  19. ipex_llm/libs/pipeline.dll +0 -0
  20. ipex_llm/libs/quantize-bloom.exe +0 -0
  21. ipex_llm/libs/quantize-bloom_vnni.exe +0 -0
  22. ipex_llm/libs/quantize-gptneox.exe +0 -0
  23. ipex_llm/libs/quantize-gptneox_vnni.exe +0 -0
  24. ipex_llm/libs/quantize-llama.exe +0 -0
  25. ipex_llm/libs/quantize-llama_vnni.exe +0 -0
  26. ipex_llm/libs/quantize-starcoder.exe +0 -0
  27. ipex_llm/libs/quantize-starcoder_vnni.exe +0 -0
  28. ipex_llm/libs/starcoder-api.dll +0 -0
  29. ipex_llm/libs/starcoder.dll +0 -0
  30. ipex_llm/transformers/convert.py +17 -132
  31. ipex_llm/transformers/lookup.py +2 -2
  32. ipex_llm/transformers/low_bit_linear.py +8 -8
  33. ipex_llm/transformers/models/chatglm2.py +1 -192
  34. ipex_llm/transformers/models/minicpmv.py +2 -2
  35. ipex_llm/transformers/models/sd.py +2 -2
  36. ipex_llm/transformers/models/utils.py +14 -89
  37. ipex_llm/transformers/npu_model.py +80 -50
  38. ipex_llm/transformers/npu_models/convert_mp.py +1 -1
  39. ipex_llm/transformers/npu_models/linear.py +15 -3
  40. ipex_llm/transformers/npu_models/lm_head.py +1 -90
  41. ipex_llm/transformers/npu_models/lm_head_linear.py +106 -0
  42. ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +5 -8
  43. ipex_llm/transformers/utils.py +5 -20
  44. {ipex_llm-2.2.0b20250104.dist-info → ipex_llm-2.2.0b20250105.post0.dist-info}/METADATA +40 -19
  45. {ipex_llm-2.2.0b20250104.dist-info → ipex_llm-2.2.0b20250105.post0.dist-info}/RECORD +51 -53
  46. ipex_llm/transformers/models/cohere.py +0 -589
  47. ipex_llm/transformers/models/falcon.py +0 -829
  48. ipex_llm/transformers/models/mixtral.py +0 -576
  49. {ipex_llm-2.2.0b20250104.data → ipex_llm-2.2.0b20250105.post0.data}/scripts/ipex-llm-init.bat +0 -0
  50. {ipex_llm-2.2.0b20250104.data → ipex_llm-2.2.0b20250105.post0.data}/scripts/llm-chat.ps1 +0 -0
  51. {ipex_llm-2.2.0b20250104.data → ipex_llm-2.2.0b20250105.post0.data}/scripts/llm-cli.ps1 +0 -0
  52. {ipex_llm-2.2.0b20250104.dist-info → ipex_llm-2.2.0b20250105.post0.dist-info}/WHEEL +0 -0
  53. {ipex_llm-2.2.0b20250104.dist-info → ipex_llm-2.2.0b20250105.post0.dist-info}/entry_points.txt +0 -0
  54. {ipex_llm-2.2.0b20250104.dist-info → ipex_llm-2.2.0b20250105.post0.dist-info}/top_level.txt +0 -0
ipex_llm/transformers/models/mixtral.py (deleted)
@@ -1,576 +0,0 @@
- #
- # Copyright 2016 The BigDL Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- # Some parts of this file is adapted from
- # https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
-
- # coding=utf-8
- # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
- # and OPT implementations in this library. It has been modified from its
- # original forms to accommodate minor architectural differences compared
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """ PyTorch Mixtral model."""
- import math
- from typing import Optional, Tuple, Union, List
- from transformers.modeling_outputs import MoeModelOutputWithPast
- from transformers.cache_utils import Cache, DynamicCache
- from transformers.modeling_attn_mask_utils import (
-     _prepare_4d_causal_attention_mask,
- )
-
- import torch
- from torch import nn
- import torch.nn.functional as F
- from ipex_llm.ggml.quantize import ggml_tensor_qtype
- from ipex_llm.utils.common import invalidInputError
- from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
- from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
- from ipex_llm.transformers.models.utils import should_use_fuse_rope
- from ipex_llm.transformers.models.utils import use_decoding_fast_path
- from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
- from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
- from ipex_llm.transformers.low_bit_linear import IQ2_XXS
-
- import os
-
- KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
-
-
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-     """
-     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
-     The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
-     to (batch, num_attention_heads, seqlen, head_dim)
-     """
-     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-     if n_rep == 1:
-         return hidden_states
-     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
-                                                            n_rep, slen, head_dim)
-     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
- def mixtral_moeblock_forward(self,
-                              hidden_states: torch.Tensor):
-     batch_size, sequence_length, hidden_dim = hidden_states.shape
-     hidden_states = hidden_states.view(-1, hidden_dim)
-     bs = hidden_states.shape[0]
-     # router_logits: (batch * sequence_length, n_experts)
-     router_logits = self.gate(hidden_states)
-
-     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-     routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-     routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-     # we cast back to the input dtype
-     routing_weights = routing_weights.to(hidden_states.dtype)
-
-     if bs == 1:
-         selected_experts = selected_experts[0].cpu().tolist()
-         for idx in range(self.top_k):
-             exp_id = selected_experts[idx]
-             expert_layer = self.experts[exp_id]
-             weight = routing_weights[:, idx]
-             if idx == 0:
-                 final_hidden_states = expert_layer(hidden_states, weight)
-             else:
-                 final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
-     elif bs < 256 and hidden_states.device.type == 'xpu':
-         final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
-                                           dtype=hidden_states.dtype, device=hidden_states.device)
-         import xe_linear
-         indexes = xe_linear.get_moe_indexes(selected_experts.to(torch.int32).cpu(), 8)
-         for expert_idx in range(self.num_experts):
-             expert_layer = self.experts[expert_idx]
-             idx_list = indexes[0][expert_idx]
-             top_x_list = indexes[1][expert_idx]
-             if len(idx_list) == 0:
-                 continue
-
-             top_x = torch.tensor(top_x_list, dtype=torch.long, device=hidden_states.device)
-             current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-             current_hidden_states = expert_layer(current_state,
-                                                  routing_weights[top_x_list, idx_list, None])
-             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-     else:
-         final_hidden_states = torch.zeros(
-             (batch_size * sequence_length, hidden_dim),
-             dtype=hidden_states.dtype,
-             device=hidden_states.device
-         )
-         # One hot encode the selected experts to create an expert mask
-         # this will be used to easily index which expert is going to be sollicitated
-         expert_mask = torch.nn.functional.one_hot(selected_experts,
-                                                   num_classes=self.num_experts).permute(2, 1, 0)
-
-         # Loop over all available experts in the model and perform the computation on each expert
-         for expert_idx in range(self.num_experts):
-             expert_layer = self.experts[expert_idx]
-             idx, top_x = torch.where(expert_mask[expert_idx])
-
-             if top_x.shape[0] == 0:
-                 continue
-
-             # in torch it is faster to index using lists than torch tensors
-             top_x_list = top_x.tolist()
-             idx_list = idx.tolist()
-
-             # Index the correct hidden states and compute the expert hidden state for
-             # the current expert. We need to make sure to multiply the output hidden
-             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-             current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-             current_hidden_states = expert_layer(current_state,
-                                                  routing_weights[top_x_list, idx_list, None])
-
-             # However `index_add_` only support torch tensors for indexing so we'll use
-             # the `top_x` tensor here.
-             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-
-     final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
-     return final_hidden_states, router_logits
-
-
- def mixtral_attention_forward(
-     self,
-     hidden_states: torch.Tensor,
-     attention_mask: Optional[torch.Tensor]=None,
-     position_ids: Optional[torch.LongTensor]=None,
-     past_key_value: Optional[Tuple[torch.Tensor]]=None,
-     output_attentions: bool=False,
-     use_cache: bool=False,
-     padding_mask: Optional[torch.Tensor]=None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-     bsz, q_len, _ = hidden_states.size()
-     device = hidden_states.device
-     # for flash attention
-     original_dtype = hidden_states.dtype
-
-     use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
-     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
-     decoding_fast_path = use_decoding_fast_path(self.q_proj,
-                                                 use_fuse_rope,
-                                                 enough_kv_room,
-                                                 bsz * q_len)
-
-     if decoding_fast_path:
-         hidden_states = hidden_states.view(1, -1)
-         cache_k = past_key_value.key_cache[self.layer_idx]
-         cache_v = past_key_value.value_cache[self.layer_idx]
-         kv_seq_len = cache_k.shape[-2]
-         import xe_linear
-         query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
-                                                                        self.q_proj.weight,
-                                                                        self.k_proj.weight,
-                                                                        self.v_proj.weight,
-                                                                        position_ids,
-                                                                        cache_k, cache_v,
-                                                                        self.q_proj.weight.qtype,
-                                                                        self.v_proj.weight.qtype,
-                                                                        kv_seq_len,
-                                                                        self.head_dim,
-                                                                        self.rotary_emb.base,)
-         kv_seq_len += 1
-         # update past_key_value's seem_tokens and kv caches.
-         if self.layer_idx == 0:
-             past_key_value.seen_tokens = kv_seq_len
-         past_key_value.key_cache[self.layer_idx] = key_states
-         past_key_value.value_cache[self.layer_idx] = value_states
-     # diasble it for now as it will cause output change for unknown reason
-     # elif decoding_fast_path and self.q_proj.qtype == IQ2_XXS:
-     #     # this path self.v_proj use q4_0
-     #     hidden_states = hidden_states.view(1, -1)
-     #     cache_k = past_key_value.key_cache[self.layer_idx]
-     #     cache_v = past_key_value.value_cache[self.layer_idx]
-     #     kv_seq_len = cache_k.shape[-2]
-     #     import xe_linear
-     #     query_states, key_states = xe_linear.forward_qk(hidden_states,
-     #                                                     self.q_proj.weight,
-     #                                                     self.k_proj.weight,
-     #                                                     position_ids,
-     #                                                     cache_k,
-     #                                                     self.q_proj.weight.qtype,
-     #                                                     kv_seq_len,
-     #                                                     self.head_dim,
-     #                                                     10000)
-     #     kv_seq_len += 1
-     #     # update past_key_value's seem_tokens and kv caches.
-     #     if self.layer_idx == 0:
-     #         past_key_value.seen_tokens = kv_seq_len
-     #     # update value_states
-     #     value_states = self.v_proj(hidden_states)
-     #     value_states = value_states.view(bsz, q_len,
-     #                                      self.num_key_value_heads, self.head_dim).transpose(1, 2)
-     #     new_size = (cache_v.size(0),
-     #                 cache_v.size(1),
-     #                 cache_v.size(2) + value_states.size(2),
-     #                 cache_v.size(3))
-     #     new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
-     #     new_cache_v[:, :, cache_v.size(2):cache_v.size(2)+value_states.size(2), :] = value_states
-
-     #     past_key_value.key_cache[self.layer_idx] = key_states
-     #     past_key_value.value_cache[self.layer_idx] = new_cache_v
-     else:
-         query_states = self.q_proj(hidden_states)
-         key_states = self.k_proj(hidden_states)
-         value_states = self.v_proj(hidden_states)
-
-         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(bsz, q_len,
-                                      self.num_key_value_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(bsz, q_len,
-                                          self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-         kv_seq_len = key_states.shape[-2]
-         if past_key_value is not None:
-             if self.layer_idx is None:
-                 invalidInputError(False,
-                                   "The cache structure has changed since version v4.36. "
-                                   f"If you are using {self.__class__.__name__} for "
-                                   "auto-regressive decodingwith k/v caching, please make sure "
-                                   "to initialize the attention class with a layer index.")
-             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
-         if use_fuse_rope:
-             import xe_addons
-             xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
-                                            query_states, key_states)
-         else:
-             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-             query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                             cos, sin, position_ids, "mixtral")
-
-         if past_key_value is not None:
-             # update the number of seen tokens
-             if self.layer_idx == 0:
-                 past_key_value.seen_tokens += key_states.shape[-2]
-
-             # reuse k, v, self_attention
-             # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
-             if len(past_key_value.key_cache) <= self.layer_idx:
-                 past_key_value.key_cache.append(key_states)
-                 past_key_value.value_cache.append(value_states)
-             else:
-                 cache_k = past_key_value.key_cache[self.layer_idx]
-                 cache_v = past_key_value.value_cache[self.layer_idx]
-
-                 if not enough_kv_room:
-                     # allocate new
-                     new_c_k, new_c_v = extend_kv_cache(bsz,
-                                                        self.num_key_value_heads,  # Support GQA
-                                                        self.head_dim,
-                                                        cache_k.size(2),
-                                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                        dtype=cache_k.dtype,
-                                                        device=device)
-
-                     new_c_k[:] = cache_k
-                     new_c_v[:] = cache_v
-                     cache_k = new_c_k
-                     cache_v = new_c_v
-
-                 key_states, value_states = append_kv_cache(cache_k,
-                                                            cache_v,
-                                                            key_states,
-                                                            value_states)
-
-                 # update past_key_value
-                 past_key_value.key_cache[self.layer_idx] = key_states
-                 past_key_value.value_cache[self.layer_idx] = value_states
-
-     if not self.training and not hidden_states.requires_grad:
-         fsdp_flag = use_flash_attention(query_states, key_states)
-     else:
-         fsdp_flag = False
-     if fsdp_flag:
-         attention_dtype = torch.float16  # use fp16 for flash attention
-     else:
-         attention_dtype = original_dtype
-
-     # repeat k/v heads if n_kv_heads < n_heads
-     key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                      dtype=attention_dtype)
-     value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                          dtype=attention_dtype)
-
-     if fsdp_flag:
-         attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
-                                                      key_states,
-                                                      value_states,
-                                                      is_causal=True)
-         attn_weights = None
-     elif use_sdp(query_states.shape[2], key_states.shape[2], self.head_dim, query_states):
-         import xe_addons
-         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
-         attn_output = attn_output.view(query_states.shape)
-         attn_weights = None
-     else:
-         attn_weights = torch.matmul(
-             query_states.to(key_states.dtype),
-             key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-             invalidInputError(
-                 False,
-                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
-                 f" but is {attn_weights.size()}"
-             )
-
-         if attention_mask is not None:
-             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                 invalidInputError(
-                     False,
-                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                     f" but is {attention_mask.size()}"
-                 )
-
-             attn_weights = attn_weights + attention_mask
-
-         # upcast attention to fp32
-         attn_weights = nn.functional.\
-             softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
-         attn_output = torch.matmul(attn_weights, value_states)
-
-     if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-         invalidInputError(
-             False,
-             f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)},"
-             f" but is {attn_output.size()}"
-         )
-
-     attn_output = attn_output.transpose(1, 2).contiguous()
-     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-     attn_output = self.o_proj(attn_output)
-
-     if not output_attentions:
-         attn_weights = None
-
-     return attn_output, attn_weights, past_key_value
-
-
- def mixtral_mlp_forward(
-     self,
-     x: torch.Tensor,
-     routing_weights
- ) -> torch.Tensor:
-     qtype = getattr(self.w1, "qtype", None)
-     if mlp_fusion_check(x, qtype, self.training):
-         import xe_linear
-         return self.w2(xe_linear.mlp_forward_xpu(
-             x, self.w1.weight.data, self.w3.weight.data,
-             x.shape[0], x.shape[1], self.w1.out_len,
-             SILU, qtype,
-         )) * routing_weights
-     else:
-         current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
-         current_hidden_states = self.w2(current_hidden_states)
-         return routing_weights * current_hidden_states
-
-
- def mixtral_model_forward(
-     self,
-     input_ids: torch.LongTensor = None,
-     attention_mask: Optional[torch.Tensor] = None,
-     position_ids: Optional[torch.LongTensor] = None,
-     past_key_values: Optional[List[torch.FloatTensor]] = None,
-     inputs_embeds: Optional[torch.FloatTensor] = None,
-     use_cache: Optional[bool] = None,
-     output_attentions: Optional[bool] = None,
-     output_hidden_states: Optional[bool] = None,
-     output_router_logits: Optional[bool] = None,
-     return_dict: Optional[bool] = None,
- ) -> Union[Tuple, MoeModelOutputWithPast]:
-     # to be compatible with transformers>=4.37.0
-     self._use_flash_attention_2 = self.config._attn_implementation == "flash_attention_2"
-
-     output_attentions = output_attentions if output_attentions is not None \
-         else self.config.output_attentions
-     output_router_logits = (
-         output_router_logits if output_router_logits is not None
-         else self.config.output_router_logits
-     )
-     output_hidden_states = (
-         output_hidden_states if output_hidden_states is not None else
-         self.config.output_hidden_states
-     )
-     use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-     # retrieve input_ids and inputs_embeds
-     if input_ids is not None and inputs_embeds is not None:
-         invalidInputError(False, "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")  # noqa
-     elif input_ids is not None:
-         batch_size, seq_length = input_ids.shape
-     elif inputs_embeds is not None:
-         batch_size, seq_length, _ = inputs_embeds.shape
-     else:
-         invalidInputError(False, "You have to specify either decoder_input_ids or decoder_inputs_embeds")  # noqa
-
-     past_key_values_length = 0
-
-     if use_cache:
-         use_legacy_cache = not isinstance(past_key_values, Cache)
-         if use_legacy_cache:
-             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-         past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-     if position_ids is None:
-         device = input_ids.device if input_ids is not None else inputs_embeds.device
-         position_ids = torch.arange(
-             past_key_values_length, seq_length + past_key_values_length,
-             dtype=torch.long, device=device
-         )
-         position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-     else:
-         position_ids = position_ids.view(-1, seq_length).long()
-
-     if inputs_embeds is None:
-         inputs_embeds = self.embed_tokens(input_ids)
-
-     if attention_mask is not None and self._use_flash_attention_2 and use_cache:
-         is_padding_right = attention_mask[:, -1].sum().item() != batch_size
-         if is_padding_right:
-             invalidInputError(
-                 False,
-                 "You are attempting to perform batched generation with padding_side='right'"
-                 " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "  # noqa
-                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-             )
-
-     if self._use_flash_attention_2:
-         # 2d mask is passed through the layers
-         attention_mask = attention_mask \
-             if (attention_mask is not None and 0 in attention_mask) else None
-     else:
-         # 4d mask is passed through the layers
-         attention_mask = _prepare_4d_causal_attention_mask(
-             attention_mask,
-             (batch_size, seq_length),
-             inputs_embeds,
-             past_key_values_length,
-             sliding_window=self.config.sliding_window,
-         )
-
-     hidden_states = inputs_embeds
-
-     if self.gradient_checkpointing and self.training:
-         if use_cache:
-             logger.warning_once(
-                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."  # noqa
-             )
-             use_cache = False
-
-     # decoder layers
-     all_hidden_states = () if output_hidden_states else None
-     all_self_attns = () if output_attentions else None
-     all_router_logits = () if output_router_logits else None
-     next_decoder_cache = None
-
-     for decoder_layer in self.layers:
-         if output_hidden_states:
-             all_hidden_states += (hidden_states,)
-
-         if self.gradient_checkpointing and self.training:
-             layer_outputs = self._gradient_checkpointing_func(
-                 decoder_layer.__call__,
-                 hidden_states,
-                 attention_mask,
-                 position_ids,
-                 past_key_values,
-                 output_attentions,
-                 output_router_logits,
-                 use_cache,
-             )
-         else:
-             # bigdl-llm changes:
-             #
-             # Avoid moving `attention_mask`` and `position_ids`` to other devices multiple times.
-             #
-             # When the model is partitioned on two different devices using
-             # `accelerate`'s `dispatch``, a hook to move inputs to the correct device is
-             # added to each layer's `forward``, which will result in moving `attention_mask`
-             # and `position_ids`, which allocated on device:0, to other devices for each
-             # decoder layer not in device:0.
-             #
-             # To avoid this, we move `attention_mask` and `position_ids` to the device of
-             # the current layer before the forward call, so that the moving is only done once
-             # for each devices other than devie:0.
-             #
-             curr_device = decoder_layer.input_layernorm.weight.device
-             if attention_mask is not None:
-                 attention_mask = attention_mask.to(curr_device)
-             if position_ids is not None:
-                 position_ids = position_ids.to(curr_device)
-             # bigdl-llm changes end
-             layer_outputs = decoder_layer(
-                 hidden_states,
-                 attention_mask=attention_mask,
-                 position_ids=position_ids,
-                 past_key_value=past_key_values,
-                 output_attentions=output_attentions,
-                 output_router_logits=output_router_logits,
-                 use_cache=use_cache,
-             )
-
-         hidden_states = layer_outputs[0]
-
-         if use_cache:
-             next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-         if output_attentions:
-             all_self_attns += (layer_outputs[1],)
-
-         if output_router_logits:
-             all_router_logits += (layer_outputs[-1],)
-
-     hidden_states = self.norm(hidden_states)
-
-     # add hidden states from the last decoder layer
-     if output_hidden_states:
-         all_hidden_states += (hidden_states,)
-
-     next_cache = None
-     if use_cache:
-         next_cache = next_decoder_cache.to_legacy_cache() \
-             if use_legacy_cache else next_decoder_cache
-
-     if not return_dict:
-         return tuple(
-             v
-             for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]  # noqa
-             if v is not None
-         )
-     return MoeModelOutputWithPast(
-         last_hidden_state=hidden_states,
-         past_key_values=next_cache,
-         hidden_states=all_hidden_states,
-         attentions=all_self_attns,
-         router_logits=all_router_logits,
-     )