ipex-llm 2.2.0b20250108__py3-none-win_amd64.whl → 2.2.0b20250110__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ipex_llm/libs/bloom-api.dll +0 -0
  2. ipex_llm/libs/bloom.dll +0 -0
  3. ipex_llm/libs/gptneox-api.dll +0 -0
  4. ipex_llm/libs/gptneox.dll +0 -0
  5. ipex_llm/libs/libbloom_avx.dll +0 -0
  6. ipex_llm/libs/libbloom_vnni.dll +0 -0
  7. ipex_llm/libs/libgptneox_avx.dll +0 -0
  8. ipex_llm/libs/libgptneox_vnni.dll +0 -0
  9. ipex_llm/libs/libllama_avx.dll +0 -0
  10. ipex_llm/libs/libllama_vnni.dll +0 -0
  11. ipex_llm/libs/libstarcoder_avx.dll +0 -0
  12. ipex_llm/libs/libstarcoder_vnni.dll +0 -0
  13. ipex_llm/libs/llama-api.dll +0 -0
  14. ipex_llm/libs/llama.dll +0 -0
  15. ipex_llm/libs/main-bloom.exe +0 -0
  16. ipex_llm/libs/main-gptneox.exe +0 -0
  17. ipex_llm/libs/main-llama.exe +0 -0
  18. ipex_llm/libs/main-starcoder.exe +0 -0
  19. ipex_llm/libs/pipeline.dll +0 -0
  20. ipex_llm/libs/quantize-bloom.exe +0 -0
  21. ipex_llm/libs/quantize-bloom_vnni.exe +0 -0
  22. ipex_llm/libs/quantize-gptneox.exe +0 -0
  23. ipex_llm/libs/quantize-gptneox_vnni.exe +0 -0
  24. ipex_llm/libs/quantize-llama.exe +0 -0
  25. ipex_llm/libs/quantize-llama_vnni.exe +0 -0
  26. ipex_llm/libs/quantize-starcoder.exe +0 -0
  27. ipex_llm/libs/quantize-starcoder_vnni.exe +0 -0
  28. ipex_llm/libs/starcoder-api.dll +0 -0
  29. ipex_llm/libs/starcoder.dll +0 -0
  30. ipex_llm/transformers/convert.py +9 -22
  31. ipex_llm/transformers/convert_ipex.py +8 -1
  32. ipex_llm/transformers/low_bit_linear.py +5 -5
  33. ipex_llm/transformers/models/baichuan.py +8 -38
  34. ipex_llm/transformers/models/bert.py +2 -13
  35. ipex_llm/transformers/models/chatglm2.py +8 -31
  36. ipex_llm/transformers/models/chatglm4.py +9 -4
  37. ipex_llm/transformers/models/chatglm4v.py +1 -1
  38. ipex_llm/transformers/models/common.py +3 -1
  39. ipex_llm/transformers/models/glm.py +1 -1
  40. ipex_llm/transformers/models/internlm.py +6 -18
  41. ipex_llm/transformers/models/llama.py +1 -1
  42. ipex_llm/transformers/models/minicpm.py +1 -1
  43. ipex_llm/transformers/models/minicpm3.py +3 -1
  44. ipex_llm/transformers/models/mistral.py +1 -1
  45. ipex_llm/transformers/models/mllama.py +1 -1
  46. ipex_llm/transformers/models/phi3.py +8 -21
  47. ipex_llm/transformers/models/qwen.py +4 -2
  48. ipex_llm/transformers/models/qwen2.py +25 -309
  49. ipex_llm/transformers/models/qwen2_moe.py +4 -2
  50. ipex_llm/transformers/models/qwen2_vl.py +3 -1
  51. ipex_llm/transformers/models/stablelm.py +3 -1
  52. ipex_llm/transformers/models/starcoder2.py +3 -1
  53. ipex_llm/transformers/models/utils.py +7 -23
  54. ipex_llm/transformers/models/yuan.py +2 -1
  55. ipex_llm/transformers/npu_model.py +7 -3
  56. {ipex_llm-2.2.0b20250108.dist-info → ipex_llm-2.2.0b20250110.dist-info}/METADATA +20 -20
  57. {ipex_llm-2.2.0b20250108.dist-info → ipex_llm-2.2.0b20250110.dist-info}/RECORD +63 -63
  58. {ipex_llm-2.2.0b20250108.data → ipex_llm-2.2.0b20250110.data}/scripts/ipex-llm-init.bat +0 -0
  59. {ipex_llm-2.2.0b20250108.data → ipex_llm-2.2.0b20250110.data}/scripts/llm-chat.ps1 +0 -0
  60. {ipex_llm-2.2.0b20250108.data → ipex_llm-2.2.0b20250110.data}/scripts/llm-cli.ps1 +0 -0
  61. {ipex_llm-2.2.0b20250108.dist-info → ipex_llm-2.2.0b20250110.dist-info}/WHEEL +0 -0
  62. {ipex_llm-2.2.0b20250108.dist-info → ipex_llm-2.2.0b20250110.dist-info}/entry_points.txt +0 -0
  63. {ipex_llm-2.2.0b20250108.dist-info → ipex_llm-2.2.0b20250110.dist-info}/top_level.txt +0 -0
ipex_llm/libs/bloom.dll, gptneox.dll, llama.dll and the other binary libraries and executables (items 1-29 in the file list above) CHANGED
Binary files: no textual diff is shown for binary content.
ipex_llm/transformers/convert.py CHANGED
@@ -1325,7 +1325,6 @@ def _optimize_post(model):
  modeling_module_name = model.__class__.__module__
  module = importlib.import_module(modeling_module_name)
  from ipex_llm.transformers.models.chatglm2 import chatglm2_attention_forward
- from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
  from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
  from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
  from ipex_llm.transformers.models.chatglm2 import mlp_forward
@@ -1338,9 +1337,7 @@ def _optimize_post(model):
  convert_forward(model,
  module.ChatGLMModel,
  chatglm2_model_forward)
- convert_forward(model,
- module.RMSNorm,
- chatglm_rms_norm_forward)
+ convert_forward(model, module.RMSNorm, rms_norm_forward)
  convert_forward(model, module.MLP, mlp_forward)
  # for codegeex-nano
  if hasattr(model.config, "rope_ratio"):
@@ -1358,8 +1355,7 @@ def _optimize_post(model):
  # glm4 family
  modeling_module_name = model.__class__.__module__
  module = importlib.import_module(modeling_module_name)
- from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
- convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward)
+ convert_forward(model, module.RMSNorm, rms_norm_forward)

  if hasattr(model.transformer, "vision"):
  # glm4 vision family
@@ -1448,8 +1444,8 @@ def _optimize_post(model):
  elif model.config.model_type == "baichuan":
  modeling_module_name = model.__class__.__module__
  module = importlib.import_module(modeling_module_name)
- from ipex_llm.transformers.models.baichuan import baichuan_mlp_forward
- convert_forward(model, module.MLP, baichuan_mlp_forward)
+ convert_forward(model, module.RMSNorm, rms_norm_forward)
+ convert_forward(model, module.MLP, mlp_silu_forward)

  if model.config.hidden_size in [4096, 2048]:
  # baichuan-7B and baichuan2-7B
@@ -1458,7 +1454,6 @@ def _optimize_post(model):
  for i in range(len(model.model.layers)):
  setattr(model.model.layers[i].self_attn, "layer_idx", i)
  convert_forward(model, module.Attention, baichuan_attention_forward_7b)
- convert_forward(model, module.RMSNorm, rms_norm_forward)
  if model.config.vocab_size == 125696:
  # baichuan2-7B
  convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
@@ -1468,9 +1463,7 @@ def _optimize_post(model):
  elif model.config.hidden_size == 5120:
  # baichuan-13B and baichuan2-13B
  from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
- from ipex_llm.transformers.models.baichuan import baichuan_13b_rms_norm_forward
  convert_forward(model, module.BaichuanAttention, baichuan_attention_forward_13b)
- convert_forward(model, module.RMSNorm, baichuan_13b_rms_norm_forward)

  if model.config.vocab_size == 125696:
  # baichaun2-13B
@@ -1565,7 +1558,6 @@ def _optimize_post(model):
  from ipex_llm.transformers.models.qwen import qwen_attention_forward
  from ipex_llm.transformers.models.qwen import qwen_attention_forward_registered
  from ipex_llm.transformers.models.qwen import qwen_mlp_forward
- from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
  from ipex_llm.transformers.models.qwen import qwen_model_forward
  if model.config.max_position_embeddings == 8192 \
  and model.config.hidden_size == 4096:
@@ -1580,7 +1572,7 @@ def _optimize_post(model):
  )
  convert_forward(model,
  module.RMSNorm,
- chatglm_rms_norm_forward)
+ rms_norm_forward)
  convert_forward(model,
  module.QWenMLP,
  qwen_mlp_forward)
@@ -1598,6 +1590,9 @@ def _optimize_post(model):
  convert_forward(model,
  module.Qwen2ForCausalLM,
  qwen2_causal_lm_forward)
+ convert_forward(model,
+ module.Qwen2Model,
+ qwen2_model_forward)
  convert_forward(model,
  module.Qwen2RMSNorm,
  rms_norm_forward)
@@ -1610,12 +1605,6 @@ def _optimize_post(model):
  convert_forward(model,
  module.Qwen2SdpaAttention,
  qwen2_attention_forward)
- if version.parse(trans_version) >= version.parse("4.42"):
- from ipex_llm.transformers.models.qwen2 import qwen2_model_forward_4_42
- convert_forward(model, module.Qwen2Model, qwen2_model_forward_4_42)
- else:
- from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
- convert_forward(model, module.Qwen2Model, qwen2_model_forward)
  elif model.config.model_type == "qwen2_moe":
  # for Qwen1.5-MOE-A2.7B
  modeling_module_name = model.__class__.__module__
@@ -1827,9 +1816,7 @@ def _optimize_post(model):
  from ipex_llm.transformers.models.phi3 import attention_forward
  convert_forward(model, module.Phi3Attention, attention_forward)
  convert_forward(model, module.Phi3SdpaAttention, attention_forward)
- from ipex_llm.transformers.models.phi3 import mlp_forward
- convert_forward(model, module.Phi3MLP, mlp_forward)
- from ipex_llm.transformers.models.common import rms_norm_forward
+ convert_forward(model, module.Phi3MLP, mlp_silu_forward)
  convert_forward(model, module.Phi3RMSNorm, rms_norm_forward)
  if model.config.model_type == "phi3":
  from ipex_llm.transformers.models.phi3 import phi3_model_forward_wrapper
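The convert.py changes above all follow one pattern: model-specific RMSNorm and MLP forwards are dropped in favour of the shared rms_norm_forward and mlp_silu_forward helpers, wired in through convert_forward. For orientation, a minimal sketch of what a convert_forward(model, target_class, new_forward) helper of this shape does; this is an assumption for illustration, not the actual ipex_llm implementation:

    import types
    import torch

    def convert_forward(model: torch.nn.Module, target_class, new_forward):
        # Rebind `forward` on every sub-module that is an instance of target_class,
        # so a call like convert_forward(model, module.RMSNorm, rms_norm_forward)
        # routes that layer type through the shared implementation.
        for sub_module in model.modules():
            if isinstance(sub_module, target_class):
                sub_module.forward = types.MethodType(new_forward, sub_module)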
ipex_llm/transformers/convert_ipex.py CHANGED
@@ -52,7 +52,14 @@ import os


  def _ipex_optimize_rmsnorm(_model, supported_classes, is_tpp=False, is_woq=False):
- from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
+ try:
+ # old version use name `_IPEXRMSNorm`
+ from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion \
+ import _IPEXRMSNorm
+ except ImportError:
+ # new version use name `_IPEXRMSNormCPU`
+ from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion \
+ import _IPEXRMSNormCPU as _IPEXRMSNorm
  for supported_class in supported_classes:
  lowering_class_cpu(
  _model,
ipex_llm/transformers/low_bit_linear.py CHANGED
@@ -47,7 +47,7 @@ import os
  import torch
  import torch.distributed
  import torch.nn.functional as F
- from torch import Tensor, device, dtype, nn
+ from torch import Tensor, dtype, nn
  from operator import mul
  from functools import reduce
  from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
@@ -294,10 +294,10 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
  if hard_condition:
  return (
  batch_size > 1
- or (device in ["arc"] and qtype in [SYM_INT8, FP4])
- or (device in ["arc", "mtl"] and qtype in [FP8E4])
- or (device in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
- or (device in ["bmg"] and qtype in [SYM_INT4, FP8E5])
+ or (device_name in ["arc"] and qtype in [SYM_INT8, FP4])
+ or (device_name in ["arc", "mtl"] and qtype in [FP8E4])
+ or (device_name in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
+ or (device_name in ["bmg"] and qtype in [SYM_INT4, FP8E5])
  )
  return False

ipex_llm/transformers/models/baichuan.py CHANGED
@@ -30,8 +30,7 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp
  from ipex_llm.transformers.models.utils import update_past_key_value
  from ipex_llm.transformers.models.utils import should_use_fuse_rope
  from ipex_llm.transformers.models.utils import use_sdp
- from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
- from ipex_llm.transformers.models.utils import mlp_fusion_check
+ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
  from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
  from ipex_llm.transformers.kv import DynamicCompressFp8Cache, DynamicCompressCache
  import warnings
@@ -47,38 +46,6 @@ def pre_compute_inv_freq(module: torch.nn.Module):
  module.register_buffer("inv_freq", inv_freq, persistent=False)


- def baichuan_13b_rms_norm_forward(self, hidden_states):
- if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad):
- import xe_addons
- x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
- output = xe_addons.rms_norm(self.weight, x_2d, self.epsilon)
- return output.reshape(hidden_states.shape)
-
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
- def baichuan_mlp_forward(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- x_2d = x.view(-1, x.shape[-1])
- qtype = getattr(self.gate_proj, "qtype", None)
- if mlp_fusion_check(x_2d, qtype, self.training):
- import xe_linear
- if not x_2d.is_contiguous():
- x_2d = x_2d.contiguous()
- return self.down_proj(xe_linear.mlp_forward_xpu(
- x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
- x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
- SILU, qtype
- ))
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
  def baichuan_model_7b_forward(
  self,
  input_ids: torch.LongTensor = None,
@@ -105,7 +72,9 @@ def baichuan_model_7b_forward(
  if use_cache:
  inputs = input_ids if input_ids is not None else inputs_embeds
  use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
- use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
+ use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
+ self.config.num_attention_heads,
+ self.config.num_attention_heads)
  if use_compress_kv and not isinstance(past_key_values,
  DynamicCompressCache):
  if use_quantize_kv:
@@ -278,8 +247,6 @@ def baichuan_attention_forward_7b(
  key_states = key_states.to(hidden_states.dtype)

  # IPEX-LLM OPT: kv cache and quantize kv
- use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
-
  # [CompressKV]
  if use_compresskv:
  enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
@@ -290,6 +257,8 @@ def baichuan_attention_forward_7b(
  query_states, attention_mask, 1,
  self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
  else:
+ use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
+ self.num_heads, self.num_heads)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, device
@@ -340,7 +309,8 @@ def baichuan_attention_forward_13b(
  kv_seq_len += past_key_value[0].shape[2]

  # IPEX-LLM OPT: kv cache and quantize kv
- use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
+ use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
+ self.num_heads, self.num_heads)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, device
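The recurring change in baichuan.py, and in the model files below, is that use_quantize_kv_cache now receives explicit head counts instead of a precomputed ratio or nothing. A hedged usage sketch of the new call convention, with config attribute names taken from the hunks above; the exact module path to a representative low-bit linear depends on the model class and is only illustrative here:

    from ipex_llm.transformers.models.utils import use_quantize_kv_cache

    def decide_quantize_kv(model, inputs):
        # Mirror how the patched forwards now call the helper: pass both the
        # attention-head count and the KV-head count (equal for plain MHA models
        # such as Baichuan-7B, smaller for GQA/MQA models).
        cfg = model.config
        num_heads = cfg.num_attention_heads
        num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads)
        return use_quantize_kv_cache(model.model.layers[0].mlp.up_proj, inputs,
                                     num_heads, num_kv_heads)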
ipex_llm/transformers/models/bert.py CHANGED
@@ -36,24 +36,13 @@ import math
  import torch
  from typing import Optional, Tuple
  from transformers.models.bert.modeling_bert import BertSelfAttention, BertEncoder
+ from ipex_llm.transformers.models.common import merge_linear
  from ipex_llm.utils.common import invalidInputError


  def merge_qkv(module: torch.nn.Module):
  if isinstance(module, BertSelfAttention):
- q_w = module.query.weight.data
- k_w = module.key.weight.data
- v_w = module.value.weight.data
- q_b = module.query.bias.data
- k_b = module.key.bias.data
- v_b = module.value.bias.data
- new_w = torch.cat([q_w, k_w, v_w], dim=0)
- new_b = torch.cat([q_b, k_b, v_b], dim=-1)
- qkv = torch.nn.Linear(0, 0, bias=True)
- qkv.weight = torch.nn.Parameter(new_w, requires_grad=False)
- qkv.bias = torch.nn.Parameter(new_b, requires_grad=False)
- qkv.in_features = module.query.in_features
- qkv.out_features = module.query.out_features * 3
+ qkv = merge_linear([module.query, module.key, module.value])
  module.qkv = qkv
  del module.query
  del module.key
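The removed block above was hand-rolled Q/K/V concatenation; merge_linear from ipex_llm.transformers.models.common now does the same job. A sketch of what such a merge amounts to, reconstructed from the deleted lines (the real helper may differ in details):

    import torch

    def merge_linear_sketch(linears):
        # Concatenate the weights and biases of several Linear layers that share
        # the same input features into one fused Linear, as the deleted code did
        # for BertSelfAttention's query/key/value projections.
        new_weight = torch.cat([lin.weight.data for lin in linears], dim=0)
        new_bias = torch.cat([lin.bias.data for lin in linears], dim=-1)
        merged = torch.nn.Linear(0, 0, bias=True)
        merged.weight = torch.nn.Parameter(new_weight, requires_grad=False)
        merged.bias = torch.nn.Parameter(new_bias, requires_grad=False)
        merged.in_features = linears[0].in_features
        merged.out_features = sum(lin.out_features for lin in linears)
        return merged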
ipex_llm/transformers/models/chatglm2.py CHANGED
@@ -33,34 +33,6 @@ from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cac
  KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states
- go from (batch, num_key_value_heads, seqlen, head_dim) to
- (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
- def chatglm_rms_norm_forward(self, hidden_states):
- if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
- import xe_addons
- x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
- output = xe_addons.rms_norm(self.weight, x_2d, self.eps)
- return output.reshape(hidden_states.shape)
-
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
- return self.weight * hidden_states.to(input_dtype)
-
-
  def chatglm2_model_forward(
  self,
  input_ids,
@@ -91,8 +63,13 @@ def chatglm2_model_forward(

  if use_cache:
  use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
+ n_heads = self.config.num_attention_heads
+ if self.config.multi_query_attention:
+ n_kv_heads = self.config.multi_query_group_num
+ else:
+ n_kv_heads = n_heads
  use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
- input_ids)
+ input_ids, n_heads, n_kv_heads)
  if use_compress_kv and not isinstance(past_key_values,
  DynamicCompressCache):
  if use_quantize_kv:
@@ -285,8 +262,6 @@ def chatglm2_attention_forward(
  key_states[..., :rot_dim] = k_rot[...]

  # IPEX-LLM OPT: kv cache and quantize kv
- use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
-
  # [CompressKV]
  if use_compresskv:
  from transformers.configuration_utils import PretrainedConfig
@@ -300,6 +275,8 @@ def chatglm2_attention_forward(
  self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
  )
  else:
+ use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
+ n_head, n_kv_head)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, hidden_states.device
ipex_llm/transformers/models/chatglm4.py CHANGED
@@ -55,8 +55,13 @@ def chatglm4_model_forward(
  if use_cache:
  inputs = input_ids if input_ids is not None else inputs_embeds
  use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
- use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
- inputs)
+ n_heads = self.config.num_attention_heads
+ if self.config.multi_query_attention:
+ n_kv_heads = self.config.multi_query_group_num
+ else:
+ n_kv_heads = n_heads
+ use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj, inputs,
+ n_heads, n_kv_heads)
  if use_compress_kv and not isinstance(past_key_values,
  DynamicCompressCache):
  if use_quantize_kv:
@@ -211,8 +216,6 @@ def chatglm4_attention_forward(
  key_states[..., :rot_dim] = k_rot[...]

  # IPEX-LLM OPT: kv cache and quantize kv
- use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
-
  # [CompressKV]
  if use_compresskv:
  from transformers.configuration_utils import PretrainedConfig
@@ -226,6 +229,8 @@ def chatglm4_attention_forward(
  self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
  )
  else:
+ use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
+ n_head, n_kv_head)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, hidden_states.device
ipex_llm/transformers/models/chatglm4v.py CHANGED
@@ -230,7 +230,7 @@ def chatglm4v_attention_forward(
  key_states[..., :rot_dim] = k_rot[...]

  # IPEX-LLM OPT: kv cache and quantize kv
- use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
+ use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states, n_head, n_kv_head)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, hidden_states.device
ipex_llm/transformers/models/common.py CHANGED
@@ -157,8 +157,10 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
  weight = self.weight
  if hasattr(self, "variance_epsilon"):
  eps = self.variance_epsilon
- else:
+ elif hasattr(self, "epsilon"):
  eps = self.epsilon
+ else:
+ eps = self.eps

  if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
  import xe_addons
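With the extra elif branch, the shared rms_norm_forward covers the three epsilon attribute names used by the norms deleted elsewhere in this diff: variance_epsilon (Llama-style norms), epsilon (Baichuan-13B), and eps (ChatGLM). A hedged sketch of the non-XPU fallback path this helper reduces to, with the xe_addons XPU fast path shown above omitted:

    import torch

    def rms_norm_forward_fallback(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Resolve the epsilon attribute across naming conventions, then apply the
        # standard fp32 RMSNorm math (matches the deleted chatglm/baichuan fallbacks).
        if hasattr(self, "variance_epsilon"):
            eps = self.variance_epsilon
        elif hasattr(self, "epsilon"):
            eps = self.epsilon
        else:
            eps = self.eps
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        return self.weight * hidden_states.to(input_dtype)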
ipex_llm/transformers/models/glm.py CHANGED
@@ -147,7 +147,7 @@ def glm_model_forward_wrapper(origin_forward):
  use_cache = use_cache if use_cache is not None else self.config.use_cache
  use_cache = use_cache or inputs.device.type == 'xpu'
  use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
- self.config.num_attention_heads //
+ self.config.num_attention_heads,
  self.config.num_key_value_heads)

  if use_cache:
ipex_llm/transformers/models/internlm.py CHANGED
@@ -87,7 +87,8 @@ def internlm_attention_forward(
  )

  # IPEX-LLM OPT: kv cache and quantzie kv cache
- use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states)
+ use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states,
+ self.num_heads, self.num_heads)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, hidden_states.device
@@ -112,21 +113,6 @@ def internlm_attention_forward(
  return attn_output, attn_weights, past_key_value


- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
- The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to
- (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
  def internlm2_attention_forward(
  self,
  hidden_states: torch.Tensor,
@@ -171,7 +157,8 @@ def internlm2_attention_forward(
  )

  # IPEX-LLM OPT: kv cache and quantzie kv cache
- use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
+ use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
+ self.num_heads, self.num_key_value_heads)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, hidden_states.device
@@ -346,7 +333,8 @@ def internlm_xcomposser2_attention_forward(
  query_states, key_states, cos, sin, position_ids, "internlm")

  # IPEX-LLM OPT: kv cache and quantzie kv cache
- use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
+ use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
+ self.num_heads, self.num_key_value_heads)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, device
ipex_llm/transformers/models/llama.py CHANGED
@@ -72,7 +72,7 @@ def llama_model_forward(
  use_cache = True if inputs.device.type == "xpu" else use_cache
  use_quantize_kv = use_quantize_kv_cache(
  self.layers[0].mlp.down_proj, inputs,
- self.config.num_attention_heads // self.config.num_key_value_heads
+ self.config.num_attention_heads, self.config.num_key_value_heads
  )
  use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \
  isinstance(past_key_values, DynamicCompressCache)
ipex_llm/transformers/models/minicpm.py CHANGED
@@ -159,7 +159,7 @@ def minicpm_model_forward_wrapper(origin_forward):
  # IPEX-LLM OPT: kv cache and quantize kv cache
  inputs = input_ids if input_ids is not None else inputs_embeds
  use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
- self.config.num_attention_heads //
+ self.config.num_attention_heads,
  self.config.num_key_value_heads)
  use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
  isinstance(past_key_values, DynamicCompressCache)
ipex_llm/transformers/models/minicpm3.py CHANGED
@@ -66,7 +66,9 @@ def minicpm3_model_forward_wrapper(origin_forward):
  inputs = input_ids if input_ids is not None else inputs_embeds
  use_cache = use_cache if use_cache is not None else self.config.use_cache
  use_cache = True if inputs.device.type == "xpu" else use_cache
- use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+ num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+ use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
+ num_heads, num_kv_heads)
  if use_cache:
  if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
  past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
ipex_llm/transformers/models/mistral.py CHANGED
@@ -71,7 +71,7 @@ def mistral_model_forward(
  use_cache = use_cache if use_cache is not None else self.config.use_cache
  use_cache = use_cache or inputs.device.type == 'xpu'
  use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
- self.config.num_attention_heads //
+ self.config.num_attention_heads,
  self.config.num_key_value_heads)
  use_compress_kv = should_use_compresskv(inputs, inputs.size(1)) or \
  isinstance(past_key_values, DynamicCompressCache)
ipex_llm/transformers/models/mllama.py CHANGED
@@ -113,7 +113,7 @@ def mllama_text_model_forward(
  use_cache = True if inputs.device.type == "xpu" else use_cache
  use_quantize_kv = use_quantize_kv_cache(
  self.layers[0].mlp.down_proj, inputs,
- self.config.num_attention_heads // self.config.num_key_value_heads
+ self.config.num_attention_heads, self.config.num_key_value_heads
  )
  if use_cache:
  if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
ipex_llm/transformers/models/phi3.py CHANGED
@@ -39,7 +39,6 @@ import warnings
  from ipex_llm.transformers.models.common import attention_softmax
  from ipex_llm.transformers.models.common import scaled_dot_product_attention
  from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
- from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
  from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
  from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
  from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
@@ -213,24 +212,8 @@ def split_mlp(module: torch.nn.Module):

  del module.gate_up_proj

-
- def mlp_forward(
- self,
- hidden_states: torch.FloatTensor
- ) -> torch.FloatTensor:
- x_2d = hidden_states.view(-1, hidden_states.shape[-1])
- qtype = getattr(self.gate_proj, "qtype", None)
- if mlp_fusion_check(x_2d, qtype, self.training):
- x_2d = x_2d.contiguous()
- import xe_linear
- return self.down_proj(xe_linear.mlp_forward_xpu(
- x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
- x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_features,
- SILU, qtype
- ))
- return self.down_proj(
- self.activation_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
- )
+ # rename activation function
+ module.act_fn = module.activation_fn


  def phi3_model_forward_wrapper(origin_model_forward):
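Aliasing act_fn to activation_fn lets Phi-3's MLP go through the shared mlp_silu_forward instead of the deleted file-local mlp_forward. As an assumption about what that shared path reduces to on the non-fused route (it mirrors the fallback removed here and in baichuan.py):

    def mlp_silu_forward_sketch(self, x):
        # Gate-and-up SiLU MLP: down_proj(act_fn(gate_proj(x)) * up_proj(x)).
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))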
@@ -249,7 +232,9 @@ def phi3_model_forward_wrapper(origin_model_forward):
  # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
  use_cache = use_cache if use_cache is not None else self.config.use_cache
  inputs = input_ids if input_ids is not None else inputs_embeds
- use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+ num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+ use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
+ num_heads, num_kv_heads)
  use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
  isinstance(past_key_values, DynamicCompressCache)
  if use_cache:
@@ -305,7 +290,9 @@ def phi3v_model_forward_wrapper(origin_model_forward):
  ):
  # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
  use_cache = use_cache if use_cache is not None else self.config.use_cache
- use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids)
+ num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+ use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids,
+ num_heads, num_kv_heads)
  if use_cache:
  if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
  past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)