ipex-llm 2.2.0b20250106__py3-none-manylinux2010_x86_64.whl → 2.2.0b20250107__py3-none-manylinux2010_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. ipex_llm/libs/libbloom_amx.so +0 -0
  2. ipex_llm/libs/libbloom_avx.so +0 -0
  3. ipex_llm/libs/libbloom_avx2.so +0 -0
  4. ipex_llm/libs/libbloom_avx512.so +0 -0
  5. ipex_llm/libs/libbloom_avxvnni.so +0 -0
  6. ipex_llm/libs/libgptneox_amx.so +0 -0
  7. ipex_llm/libs/libgptneox_avx.so +0 -0
  8. ipex_llm/libs/libgptneox_avx2.so +0 -0
  9. ipex_llm/libs/libgptneox_avx512.so +0 -0
  10. ipex_llm/libs/libgptneox_avxvnni.so +0 -0
  11. ipex_llm/libs/libllama_amx.so +0 -0
  12. ipex_llm/libs/libllama_avx.so +0 -0
  13. ipex_llm/libs/libllama_avx2.so +0 -0
  14. ipex_llm/libs/libllama_avx512.so +0 -0
  15. ipex_llm/libs/libllama_avxvnni.so +0 -0
  16. ipex_llm/libs/libstarcoder_amx.so +0 -0
  17. ipex_llm/libs/libstarcoder_avx.so +0 -0
  18. ipex_llm/libs/libstarcoder_avx2.so +0 -0
  19. ipex_llm/libs/libstarcoder_avx512.so +0 -0
  20. ipex_llm/libs/libstarcoder_avxvnni.so +0 -0
  21. ipex_llm/libs/quantize-bloom +0 -0
  22. ipex_llm/libs/quantize-gptneox +0 -0
  23. ipex_llm/libs/quantize-llama +0 -0
  24. ipex_llm/libs/quantize-starcoder +0 -0
  25. ipex_llm/transformers/convert.py +17 -132
  26. ipex_llm/transformers/lookup.py +2 -2
  27. ipex_llm/transformers/low_bit_linear.py +8 -8
  28. ipex_llm/transformers/models/chatglm2.py +1 -192
  29. ipex_llm/transformers/models/minicpmv.py +2 -2
  30. ipex_llm/transformers/models/sd.py +2 -2
  31. ipex_llm/transformers/models/utils.py +14 -89
  32. ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +5 -8
  33. ipex_llm/transformers/utils.py +5 -20
  34. {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/METADATA +40 -19
  35. {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/RECORD +41 -44
  36. ipex_llm/transformers/models/cohere.py +0 -589
  37. ipex_llm/transformers/models/falcon.py +0 -829
  38. ipex_llm/transformers/models/mixtral.py +0 -576
  39. {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250107.data}/scripts/ipex-llm-init +0 -0
  40. {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250107.data}/scripts/llm-chat +0 -0
  41. {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250107.data}/scripts/llm-cli +0 -0
  42. {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/WHEEL +0 -0
  43. {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/entry_points.txt +0 -0
  44. {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/top_level.txt +0 -0
ipex_llm/transformers/models/cohere.py
@@ -1,589 +0,0 @@
-#
-# Copyright 2016 The BigDL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# Some parts of this file is adapted from
-# https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py
-
-# coding=utf-8
-# Copyright 2024 Cohere team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# This file is based on the LLama model definition file in transformers
-
-"""PyTorch Cohere model."""
-import math
-import torch
-import torch.nn.functional as F
-import torch.nn as nn
-import torch.utils.checkpoint
-from typing import Optional, Tuple, List
-from ipex_llm.transformers.models.utils import repeat_kv
-from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
-from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
-from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.utils import use_decoding_fast_path
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
-from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.kv import DynamicFp8Cache
-from ipex_llm.transformers.models.utils import should_use_fuse_rope
-from transformers.modeling_outputs import BaseModelOutputWithPast
-from ipex_llm.utils.common import invalidInputError
-try:
-    from transformers.cache_utils import Cache, DynamicCache
-except ImportError:
-    Cache = Tuple[torch.Tensor]
-
-KV_CACHE_ALLOC_BLOCK_LENGTH = 256
-
-
-def cohere_model_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-):
-    use_cache = use_cache if use_cache is not None \
-        else self.config.use_cache
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
-        if not isinstance(past_key_values, DynamicFp8Cache):
-            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    output_attentions = output_attentions if output_attentions is not None \
-        else self.config.output_attentions
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    if input_ids is not None and inputs_embeds is not None:
-        invalidInputError(False,
-                          "You cannot specify both input_ids and inputs_embeds at the same time")
-
-    if self.gradient_checkpointing and self.training and use_cache:
-        invalidInputError(False,
-                          "`use_cache=True` is incompatible "
-                          "with gradient checkpointing. Setting `use_cache=False`.")
-        use_cache = False
-
-    if inputs_embeds is None:
-        inputs_embeds = self.embed_tokens(input_ids)
-
-    past_seen_tokens = 0
-    if use_cache:  # kept for BC (cache positions)
-        if not isinstance(past_key_values, Cache):
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-        past_seen_tokens = past_key_values.get_seq_length()
-
-    if cache_position is None:
-        if isinstance(past_key_values, Cache):
-            invalidInputError(False, "cache_position is a required argument when using Cache.")
-        cache_position = torch.arange(
-            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-        )
-
-    if position_ids is None:
-        position_ids = cache_position.unsqueeze(0)
-
-    causal_mask = self._update_causal_mask(attention_mask,
-                                           inputs_embeds, cache_position, past_seen_tokens)
-
-    # embed positions
-    hidden_states = inputs_embeds
-
-    # decoder layers
-    all_hidden_states = () if output_hidden_states else None
-    all_self_attns = () if output_attentions else None
-    next_decoder_cache = None
-
-    for decoder_layer in self.layers:
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        if self.gradient_checkpointing and self.training:
-            layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
-                hidden_states,
-                causal_mask,
-                position_ids,
-                past_key_values,
-                output_attentions,
-                use_cache,
-                cache_position,
-            )
-        else:
-            # ipex-llm changes
-            curr_device = decoder_layer.input_layernorm.weight.device
-            if causal_mask is not None:
-                causal_mask = causal_mask.to(curr_device)
-            if position_ids is not None:
-                position_ids = position_ids.to(curr_device)
-            # ipex-llm changes end
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=causal_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-            )
-
-        hidden_states = layer_outputs[0]
-
-        if use_cache:
-            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-        if output_attentions:
-            all_self_attns += (layer_outputs[1],)
-
-    hidden_states = self.norm(hidden_states)
-
-    # add hidden states from the last decoder layer
-    if output_hidden_states:
-        all_hidden_states += (hidden_states,)
-
-    next_cache = next_decoder_cache if use_cache else None
-    if not return_dict:
-        return tuple(v for v in [hidden_states, next_cache,
-                                 all_hidden_states, all_self_attns] if v is not None)
-    return BaseModelOutputWithPast(
-        last_hidden_state=hidden_states,
-        past_key_values=next_cache,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attns,
-    )
-
-
-def cohere_model_forward_4_41(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-):
-    use_cache = use_cache if use_cache is not None \
-        else self.config.use_cache
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
-        if not isinstance(past_key_values, DynamicFp8Cache):
-            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    output_attentions = output_attentions if output_attentions is not None \
-        else self.config.output_attentions
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    if input_ids is not None and inputs_embeds is not None:
-        invalidInputError(False,
-                          "You cannot specify both input_ids and inputs_embeds at the same time")
-
-    if self.gradient_checkpointing and self.training and use_cache:
-        invalidInputError(False,
-                          "`use_cache=True` is incompatible "
-                          "with gradient checkpointing. Setting `use_cache=False`.")
-        use_cache = False
-
-    if inputs_embeds is None:
-        inputs_embeds = self.embed_tokens(input_ids)
-
-    past_seen_tokens = 0
-    return_legacy_cache = False
-    # kept for BC (non `Cache` `past_key_values` inputs)
-    if use_cache and not isinstance(past_key_values, Cache):
-        return_legacy_cache = True
-        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-
-    if cache_position is None:
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        cache_position = torch.arange(
-            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-        )
-
-    if position_ids is None:
-        position_ids = cache_position.unsqueeze(0)
-
-    causal_mask = self._update_causal_mask(
-        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-    )
-
-    # embed positions
-    hidden_states = inputs_embeds
-
-    # decoder layers
-    all_hidden_states = () if output_hidden_states else None
-    all_self_attns = () if output_attentions else None
-    next_decoder_cache = None
-
-    for decoder_layer in self.layers:
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        if self.gradient_checkpointing and self.training:
-            layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
-                hidden_states,
-                causal_mask,
-                position_ids,
-                past_key_values,
-                output_attentions,
-                use_cache,
-                cache_position,
-            )
-        else:
-            # ipex-llm changes
-            curr_device = decoder_layer.input_layernorm.weight.device
-            if causal_mask is not None:
-                causal_mask = causal_mask.to(curr_device)
-            if position_ids is not None:
-                position_ids = position_ids.to(curr_device)
-            # ipex-llm changes end
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=causal_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-            )
-
-        hidden_states = layer_outputs[0]
-
-        if use_cache:
-            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-        if output_attentions:
-            all_self_attns += (layer_outputs[1],)
-
-    hidden_states = self.norm(hidden_states)
-
-    # add hidden states from the last decoder layer
-    if output_hidden_states:
-        all_hidden_states += (hidden_states,)
-
-    next_cache = next_decoder_cache if use_cache else None
-    if return_legacy_cache:
-        next_cache = next_cache.to_legacy_cache()
-    if not return_dict:
-        return tuple(v for v in [hidden_states, next_cache,
-                                 all_hidden_states, all_self_attns] if v is not None)
-    return BaseModelOutputWithPast(
-        last_hidden_state=hidden_states,
-        past_key_values=next_cache,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attns,
-    )
-
-
-def cohere_attention_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
-    **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
-        forward_function = cohere_attention_forward_quantized
-    else:
-        forward_function = cohere_attention_forward_origin
-    return forward_function(
-        self=self,
-        hidden_states=hidden_states,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_value=past_key_value,
-        output_attentions=output_attentions,
-        use_cache=use_cache,
-        cache_position=cache_position,
-        **kwargs,
-    )
-
-
-def cohere_attention_forward_quantized(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
-    **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
-
-    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
-    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-    if self.use_qk_norm:
-        query_states = self.q_norm(query_states)
-        key_states = self.k_norm(key_states)
-
-    query_states = query_states.transpose(1, 2)
-    key_states = key_states.transpose(1, 2)
-    value_states = value_states.view(bsz, q_len,
-                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-    past_key_value = getattr(self, "past_key_value", past_key_value)
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-    cos, sin = self.rotary_emb(value_states, position_ids)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-    if past_key_value is not None:
-        # sin and cos are specific to RoPE models; position_ids needed for the static cache
-        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx,
-                                                         cache_kwargs, new_layout=True)
-    if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
-            and not hidden_states.requires_grad:
-        import xe_addons
-        attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
-        attn_weights = None
-    else:
-        key_states, value_states = restore_fp8_kv_cache(key_states,
                                                        value_states, query_states.dtype)
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states,
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-    invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
                      "`attn_output` should be of size "
                      f"{(bsz, self.num_heads, q_len, self.head_dim)},"
                      f" but is {attn_output.size()}")
-
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-    attn_output = self.o_proj(attn_output)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output, attn_weights, past_key_value
-
-
-def cohere_attention_forward_origin(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-    device = hidden_states.device
-    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                use_fuse_rope,
                                                enough_kv_room,
                                                bsz * q_len)
-    if decoding_fast_path:
-        hidden_states = hidden_states.view(1, -1)
-        cache_k = past_key_value.key_cache[self.layer_idx]
-        cache_v = past_key_value.value_cache[self.layer_idx]
-        kv_seq_len = cache_k.shape[-2]
-        import xe_linear
-        query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
                                                                       self.q_proj.weight,
                                                                       self.k_proj.weight,
                                                                       self.v_proj.weight,
                                                                       position_ids,
                                                                       cache_k, cache_v,
                                                                       self.q_proj.weight.qtype,
                                                                       self.v_proj.weight.qtype,
                                                                       kv_seq_len,
                                                                       self.head_dim,
                                                                       self.rotary_emb.base,)
-        kv_seq_len += 1
-        # update past_key_value's seem_tokens and kv caches.
-        if self.layer_idx == 0:
-            past_key_value._seen_tokens = kv_seq_len
-        past_key_value.key_cache[self.layer_idx] = key_states
-        past_key_value.value_cache[self.layer_idx] = value_states
-    else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-        if self.use_qk_norm:
-            query_states = self.q_norm(query_states)
-            key_states = self.k_norm(key_states)
-
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                         self.head_dim).transpose(1, 2)
-
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                invalidInputError(
-                    False,
-                    "The cache structure has changed since version v4.36. "
-                    f"If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, "
-                    "please make sure to initialize the attention class with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, position_ids)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        if past_key_value is not None:
-            if self.layer_idx == 0:
-                past_key_value._seen_tokens += key_states.shape[-2]
-
-            if len(past_key_value.key_cache) <= self.layer_idx:
-                past_key_value.key_cache.append(key_states)
-                past_key_value.value_cache.append(value_states)
-            else:
-                cache_k = past_key_value.key_cache[self.layer_idx]
-                cache_v = past_key_value.value_cache[self.layer_idx]
-
-                if not enough_kv_room:
-                    # allocate new
-                    new_c_k, new_c_v = extend_kv_cache(bsz,
                                                       self.num_key_value_heads,  # Support GQA
                                                       self.head_dim,
                                                       cache_k.size(2),
                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                       dtype=cache_k.dtype,
                                                       device=device)
-
-                    new_c_k[:] = cache_k
-                    new_c_v[:] = cache_v
-                    cache_k = new_c_k
-                    cache_v = new_c_v
-
-                key_states, value_states = append_kv_cache(cache_k,
                                                           cache_v,
                                                           key_states,
                                                           value_states)
-
-                # update past_key_value
-                past_key_value.key_cache[self.layer_idx] = key_states
-                past_key_value.value_cache[self.layer_idx] = value_states
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    if not self.training and not hidden_states.requires_grad and \
-            use_flash_attention(query_states, key_states, attention_mask):
-        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
                                                     key_states.to(device, dtype=torch.float16),
                                                     value_states.to(device, dtype=torch.float16),
                                                     is_causal=True)
-        attn_weights = None
-    elif not self.training and not hidden_states.requires_grad and \
-            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import xe_addons
-        if attention_mask is not None:
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        else:
-            causal_mask = None
-        attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
-        attn_output = attn_output.view(query_states.shape)
-        attn_weights = None
-    else:
-        attn_weights = torch.matmul(query_states,
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-    invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
                      "`attn_output` should be of size "
                      f"{(bsz, self.num_heads, q_len, self.head_dim)},"
                      f" but is {attn_output.size()}")
-
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-    attn_output = self.o_proj(attn_output)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value