ipex-llm 2.2.0b20250107__py3-none-manylinux2010_x86_64.whl → 2.2.0b20250109__py3-none-manylinux2010_x86_64.whl

This diff compares the contents of two publicly released versions of this package as published to their public registry. It is provided for informational purposes only.
Files changed (63)
  1. ipex_llm/libs/libbloom_amx.so +0 -0
  2. ipex_llm/libs/libbloom_avx.so +0 -0
  3. ipex_llm/libs/libbloom_avx2.so +0 -0
  4. ipex_llm/libs/libbloom_avx512.so +0 -0
  5. ipex_llm/libs/libbloom_avxvnni.so +0 -0
  6. ipex_llm/libs/libgptneox_amx.so +0 -0
  7. ipex_llm/libs/libgptneox_avx.so +0 -0
  8. ipex_llm/libs/libgptneox_avx2.so +0 -0
  9. ipex_llm/libs/libgptneox_avx512.so +0 -0
  10. ipex_llm/libs/libgptneox_avxvnni.so +0 -0
  11. ipex_llm/libs/libllama_amx.so +0 -0
  12. ipex_llm/libs/libllama_avx.so +0 -0
  13. ipex_llm/libs/libllama_avx2.so +0 -0
  14. ipex_llm/libs/libllama_avx512.so +0 -0
  15. ipex_llm/libs/libllama_avxvnni.so +0 -0
  16. ipex_llm/libs/libstarcoder_amx.so +0 -0
  17. ipex_llm/libs/libstarcoder_avx.so +0 -0
  18. ipex_llm/libs/libstarcoder_avx2.so +0 -0
  19. ipex_llm/libs/libstarcoder_avx512.so +0 -0
  20. ipex_llm/libs/libstarcoder_avxvnni.so +0 -0
  21. ipex_llm/libs/quantize-bloom +0 -0
  22. ipex_llm/libs/quantize-gptneox +0 -0
  23. ipex_llm/libs/quantize-llama +0 -0
  24. ipex_llm/libs/quantize-starcoder +0 -0
  25. ipex_llm/transformers/convert.py +20 -50
  26. ipex_llm/transformers/loader.py +1 -1
  27. ipex_llm/transformers/low_bit_linear.py +10 -25
  28. ipex_llm/transformers/model.py +0 -7
  29. ipex_llm/transformers/models/baichuan.py +7 -36
  30. ipex_llm/transformers/models/bert.py +2 -13
  31. ipex_llm/transformers/models/chatglm2.py +8 -31
  32. ipex_llm/transformers/models/chatglm4.py +9 -4
  33. ipex_llm/transformers/models/chatglm4v.py +2 -1
  34. ipex_llm/transformers/models/common.py +3 -1
  35. ipex_llm/transformers/models/glm.py +4 -2
  36. ipex_llm/transformers/models/internlm.py +6 -3
  37. ipex_llm/transformers/models/llama.py +2 -2
  38. ipex_llm/transformers/models/minicpm.py +3 -2
  39. ipex_llm/transformers/models/minicpm3.py +3 -1
  40. ipex_llm/transformers/models/minicpmv.py +1 -0
  41. ipex_llm/transformers/models/mistral.py +1 -1
  42. ipex_llm/transformers/models/mllama.py +1 -1
  43. ipex_llm/transformers/models/phi3.py +6 -2
  44. ipex_llm/transformers/models/qwen.py +4 -2
  45. ipex_llm/transformers/models/qwen2.py +4 -3
  46. ipex_llm/transformers/models/qwen2_moe.py +4 -2
  47. ipex_llm/transformers/models/qwen2_vl.py +3 -1
  48. ipex_llm/transformers/models/stablelm.py +3 -1
  49. ipex_llm/transformers/models/starcoder2.py +3 -1
  50. ipex_llm/transformers/models/utils.py +10 -19
  51. ipex_llm/transformers/models/yuan.py +2 -1
  52. ipex_llm/transformers/speculative.py +2 -14
  53. ipex_llm/transformers/utils.py +2 -14
  54. ipex_llm/transformers/xpu_ops.py +25 -19
  55. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250109.dist-info}/METADATA +20 -20
  56. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250109.dist-info}/RECORD +62 -63
  57. ipex_llm/transformers/models/gptj.py +0 -441
  58. {ipex_llm-2.2.0b20250107.data → ipex_llm-2.2.0b20250109.data}/scripts/ipex-llm-init +0 -0
  59. {ipex_llm-2.2.0b20250107.data → ipex_llm-2.2.0b20250109.data}/scripts/llm-chat +0 -0
  60. {ipex_llm-2.2.0b20250107.data → ipex_llm-2.2.0b20250109.data}/scripts/llm-cli +0 -0
  61. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250109.dist-info}/WHEEL +0 -0
  62. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250109.dist-info}/entry_points.txt +0 -0
  63. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250109.dist-info}/top_level.txt +0 -0

ipex_llm/transformers/convert.py

@@ -680,18 +680,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
  optimize_lm_head=optimize_lm_head
  )
  device = module.weight.data.device
- from ipex_llm.transformers.utils import get_ipex_version
- if get_ipex_version() < "2.1.10+xpu":
- new_linear._parameters['weight'] = nn.Parameter(module.weight)
- else:
- # only from 2.1, ipex provides matmul_bias_out
- # so we need to transpose weight
- new_weight = module.weight.transpose(0, 1).contiguous()
- new_linear._parameters['weight'] = nn.Parameter(new_weight)
- new_linear.weight_type = 2
+ new_linear._parameters['weight'] = nn.Parameter(module.weight)
  if module.bias is not None:
- new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
- .to(device)
+ new_linear._parameters['bias'] = nn.Parameter(module.bias.data).to(device)
  elif qtype == ggml_tensor_qtype["bf16"]:
  module.to(torch.bfloat16)
  if _USE_VLLM:

@@ -856,18 +847,9 @@ def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
  mp_group=mp_group,
  )
  device = module.weight.data.device
- from ipex_llm.transformers.utils import get_ipex_version
- if get_ipex_version() < "2.1.10+xpu":
- new_linear._parameters['weight'] = nn.Parameter(module.weight)
- else:
- # only from 2.1, ipex provides matmul_bias_out
- # so we need to transpose weight
- new_weight = module.weight.transpose(0, 1).contiguous()
- new_linear._parameters['weight'] = nn.Parameter(new_weight)
- new_linear.weight_type = 2
+ new_linear._parameters['weight'] = nn.Parameter(module.weight)
  if module.bias is not None:
- new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
- .to(device)
+ new_linear._parameters['bias'] = nn.Parameter(module.bias.data).to(device)
  elif qtype == ggml_tensor_qtype["bf16"]:
  module.to(torch.bfloat16)
  new_linear = BF16Linear(

@@ -1343,7 +1325,6 @@ def _optimize_post(model):
  modeling_module_name = model.__class__.__module__
  module = importlib.import_module(modeling_module_name)
  from ipex_llm.transformers.models.chatglm2 import chatglm2_attention_forward
- from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
  from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
  from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
  from ipex_llm.transformers.models.chatglm2 import mlp_forward

@@ -1356,9 +1337,7 @@ def _optimize_post(model):
  convert_forward(model,
  module.ChatGLMModel,
  chatglm2_model_forward)
- convert_forward(model,
- module.RMSNorm,
- chatglm_rms_norm_forward)
+ convert_forward(model, module.RMSNorm, rms_norm_forward)
  convert_forward(model, module.MLP, mlp_forward)
  # for codegeex-nano
  if hasattr(model.config, "rope_ratio"):

@@ -1376,8 +1355,7 @@ def _optimize_post(model):
  # glm4 family
  modeling_module_name = model.__class__.__module__
  module = importlib.import_module(modeling_module_name)
- from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
- convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward)
+ convert_forward(model, module.RMSNorm, rms_norm_forward)

  if hasattr(model.transformer, "vision"):
  # glm4 vision family

@@ -1429,6 +1407,7 @@ def _optimize_post(model):
  convert_forward(model, module.GlmRMSNorm, rms_norm_forward)
  convert_forward(model, module.GlmMLP, mlp_silu_forward)
  convert_forward(model, module.GlmAttention, glm_attention_forward)
+ convert_forward(model, module.GlmSdpaAttention, glm_attention_forward)
  glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
  convert_forward(model, module.GlmModel, glm_model_forward)

@@ -1437,10 +1416,12 @@ def _optimize_post(model):
  vision_module_name = model.model.vision.__class__.__module__
  vision_module = importlib.import_module(vision_module_name)
  from transformers.models.siglip.modeling_siglip import SiglipAttention
+ from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
  from ipex_llm.transformers.models.chatglm4v import vision_model_forward
  from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
  convert_forward(model, vision_module.VisionModel, vision_model_forward)
  convert_forward(model, SiglipAttention, siglip_attention_forward)
+ convert_forward(model, SiglipSdpaAttention, siglip_attention_forward)

  elif "mpt" in model.config.model_type:
  if model.config.architectures is not None:

@@ -1452,21 +1433,6 @@ def _optimize_post(model):
  module.MultiheadAttention,
  mpt_multihead_attention_forward
  )
- elif "gptj" in model.config.model_type:
- # dolly-v1-6b
- modeling_module_name = model.__class__.__module__
- module = importlib.import_module(modeling_module_name)
- from ipex_llm.transformers.models.gptj import gptj_attention_forward, gptj_model_forward,\
- gptj_block_forward
- convert_forward(model,
- module.GPTJAttention,
- gptj_attention_forward)
- convert_forward(model,
- module.GPTJModel,
- gptj_model_forward)
- convert_forward(model,
- module.GPTJBlock,
- gptj_block_forward)
  elif "bloom" in model.config.model_type:
  modeling_module_name = model.__class__.__module__
  module = importlib.import_module(modeling_module_name)

@@ -1478,8 +1444,8 @@ def _optimize_post(model):
  elif model.config.model_type == "baichuan":
  modeling_module_name = model.__class__.__module__
  module = importlib.import_module(modeling_module_name)
- from ipex_llm.transformers.models.baichuan import baichuan_mlp_forward
- convert_forward(model, module.MLP, baichuan_mlp_forward)
+ convert_forward(model, module.RMSNorm, rms_norm_forward)
+ convert_forward(model, module.MLP, mlp_silu_forward)

  if model.config.hidden_size in [4096, 2048]:
  # baichuan-7B and baichuan2-7B

@@ -1488,7 +1454,6 @@ def _optimize_post(model):
  for i in range(len(model.model.layers)):
  setattr(model.model.layers[i].self_attn, "layer_idx", i)
  convert_forward(model, module.Attention, baichuan_attention_forward_7b)
- convert_forward(model, module.RMSNorm, rms_norm_forward)
  if model.config.vocab_size == 125696:
  # baichuan2-7B
  convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)

@@ -1498,9 +1463,7 @@ def _optimize_post(model):
  elif model.config.hidden_size == 5120:
  # baichuan-13B and baichuan2-13B
  from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
- from ipex_llm.transformers.models.baichuan import baichuan_13b_rms_norm_forward
  convert_forward(model, module.BaichuanAttention, baichuan_attention_forward_13b)
- convert_forward(model, module.RMSNorm, baichuan_13b_rms_norm_forward)

  if model.config.vocab_size == 125696:
  # baichaun2-13B

@@ -1595,7 +1558,6 @@ def _optimize_post(model):
  from ipex_llm.transformers.models.qwen import qwen_attention_forward
  from ipex_llm.transformers.models.qwen import qwen_attention_forward_registered
  from ipex_llm.transformers.models.qwen import qwen_mlp_forward
- from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
  from ipex_llm.transformers.models.qwen import qwen_model_forward
  if model.config.max_position_embeddings == 8192 \
  and model.config.hidden_size == 4096:

@@ -1610,7 +1572,7 @@ def _optimize_post(model):
  )
  convert_forward(model,
  module.RMSNorm,
- chatglm_rms_norm_forward)
+ rms_norm_forward)
  convert_forward(model,
  module.QWenMLP,
  qwen_mlp_forward)

@@ -1691,8 +1653,10 @@ def _optimize_post(model):
  convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
  model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual)
  convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
+ convert_forward(model, module.VisionSdpaAttention, qwen2_vision_attention_forward)
  convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
  convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
+ convert_forward(model, module.Qwen2VLSdpaAttention, qwen2_vl_attention_forward)
  elif model.config.model_type == "aquila":
  modeling_module_name = model.__class__.__module__
  module = importlib.import_module(modeling_module_name)

@@ -1838,6 +1802,7 @@ def _optimize_post(model):
  from ipex_llm.transformers.models.starcoder2 import attention_forward
  from ipex_llm.transformers.models.starcoder2 import model_forward
  convert_forward(model, module.Starcoder2Attention, attention_forward)
+ convert_forward(model, module.Starcoder2SdpaAttention, attention_forward)
  convert_forward(model, module.Starcoder2Model, model_forward)
  elif model.config.model_type == "phi":
  # for phi-2

@@ -1853,6 +1818,7 @@ def _optimize_post(model):
  module = importlib.import_module(modeling_module_name)
  from ipex_llm.transformers.models.phi3 import attention_forward
  convert_forward(model, module.Phi3Attention, attention_forward)
+ convert_forward(model, module.Phi3SdpaAttention, attention_forward)
  from ipex_llm.transformers.models.phi3 import mlp_forward
  convert_forward(model, module.Phi3MLP, mlp_forward)
  from ipex_llm.transformers.models.common import rms_norm_forward

@@ -1896,6 +1862,8 @@ def _optimize_post(model):
  module.StableLmAttention,
  stablelm_attention_forward
  )
+ if hasattr(module, "StableLmSdpaAttention"):
+ convert_forward(model, module.StableLmSdpaAttention, stablelm_attention_forward)
  convert_forward(model,
  module.StableLmMLP,
  mlp_silu_forward)

@@ -1910,6 +1878,7 @@ def _optimize_post(model):
  from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
  from ipex_llm.transformers.models.minicpm import minicpm_decoder_layer_forward
  convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
+ convert_forward(model, module.MiniCPMSdpaAttention, minicpm_attention_forward)
  convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
  convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
  convert_forward(model, module.MiniCPMDecoderLayer, minicpm_decoder_layer_forward)

@@ -1925,6 +1894,7 @@ def _optimize_post(model):
  convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
  convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
  convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
+ convert_forward(model, module.MiniCPMSdpaAttention, minicpm3_attention_forward)
  minicpm3_model_forward = minicpm3_model_forward_wrapper(module.MiniCPM3Model.forward)
  convert_forward(model, module.MiniCPM3Model, minicpm3_model_forward)
  elif model.config.model_type == "minicpmv":

ipex_llm/transformers/loader.py

@@ -22,7 +22,7 @@ import time
  from datetime import date
  import argparse
  from ipex_llm.utils.common import invalidInputError
- from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
+ from transformers import AutoTokenizer, LlamaTokenizer

  LLAMA_IDS = ['llama', 'vicuna', 'merged-baize']


ipex_llm/transformers/low_bit_linear.py

@@ -51,8 +51,7 @@ from torch import Tensor, device, dtype, nn
  from operator import mul
  from functools import reduce
  from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
- from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_name, \
- get_ipex_version
+ from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_name
  from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm

  T = TypeVar("T", bound="torch.nn.Module")

@@ -286,7 +285,7 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
  or (
  qtype in [SYM_INT8, FP4, FP6, Q4_K, Q6_K]
  and batch_size <= 48
- and device_name in ["arc", "pvc", "mtl", "lnl", "arl"]
+ and device_name in ["arc", "pvc", "mtl", "arl"]
  and x.shape[1] % 256 == 0
  and output_len % 32 == 0
  )

@@ -759,9 +758,9 @@ class FP16Linear(nn.Linear):
  self.weight_length = self.out_len * self.in_len
  self.qtype = ggml_tensor_qtype["fp16"]
  self.mp_group = mp_group
- # weigh_type = 1 means original weight
- # weigh_type = 2 means weight has been transposed
- # weigh_type = 3 means weight has been transposed by esimd method
+ # weight_type = 1 means original weight
+ # weight_type = 2 means weight has been transposed
+ # weight_type = 3 means weight has been transposed by esimd method
  self.weight_type = 1
  self.optimize_lm_head = optimize_lm_head
  self.disable_fp16_opt = False

@@ -775,28 +774,14 @@ class FP16Linear(nn.Linear):

  x = x.to(torch.float16)
  if self.bias is not None and self.bias.dtype != x.dtype:
- self.bias.data = self.bias.data.to(x.dtype)
+ self.bias.data = self.bias.data.to(x.dtype)
  if self.weight is not None and self.weight.dtype != x.dtype:
  self.weight.data = self.weight.data.to(x.dtype)

  if not self.use_esimd_kernel(x):
- if (
- get_ipex_version() < "2.1.10+xpu"
- or get_xpu_device_name(x.device) not in ["arc", "pvc"]
- or self.disable_fp16_opt
- ):
- if self.weight_type == 2:
- self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
- requires_grad=False)
- self.weight_type = 1
- result = F.linear(x, self.weight, self.bias)
- else:
- if self.weight_type == 1:
- self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
- requires_grad=False)
- self.weight_type = 2
- result = torch.ops.torch_ipex.matmul_bias_out(x.contiguous(),
- self.weight, self.bias)
+ invalidInputError(self.weight_type == 1, "weight_type should be 1")
+ result = F.linear(x, self.weight, self.bias)
+
  if self.mp_group is not None:
  if get_use_vllm():
  result = self.mp_group.all_reduce(result)

@@ -852,7 +837,7 @@ class FP16Linear(nn.Linear):
  if self.disable_fp16_opt:
  return False
  # esimd kernel can only be used for Arc and Flex
- if gpu_type not in ["arc", "flex"]:
+ if gpu_type not in ["arc"]:
  return False
  # now esimd kernel can only be used for specific cases (llama2-7b shape)
  if self.in_len == 11008 and self.out_features == 4096:

ipex_llm/transformers/model.py

@@ -103,12 +103,6 @@ def save_low_bit(self, *args, **kwargs):
  self.to(origin_device)


- def _load_pre():
- from transformers import GPTJModel
- from ipex_llm.transformers.models.gptj import gptj_model_new_init
- GPTJModel.__init__ = gptj_model_new_init
-
-
  class _BaseAutoModelClass:
  HF_MODEL = None


@@ -495,7 +489,6 @@ class _BaseAutoModelClass:
  else:
  if quant_config is not None:
  kwargs["quantization_config"] = quant_config
- _load_pre()
  try:
  # To handle the input CUDA setting (such as 'device_map={"":0}'), ignore it
  kwargs.pop('device_map', None)

ipex_llm/transformers/models/baichuan.py

@@ -47,38 +47,6 @@ def pre_compute_inv_freq(module: torch.nn.Module):
  module.register_buffer("inv_freq", inv_freq, persistent=False)


- def baichuan_13b_rms_norm_forward(self, hidden_states):
- if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad):
- import xe_addons
- x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
- output = xe_addons.rms_norm(self.weight, x_2d, self.epsilon)
- return output.reshape(hidden_states.shape)
-
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
- def baichuan_mlp_forward(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- x_2d = x.view(-1, x.shape[-1])
- qtype = getattr(self.gate_proj, "qtype", None)
- if mlp_fusion_check(x_2d, qtype, self.training):
- import xe_linear
- if not x_2d.is_contiguous():
- x_2d = x_2d.contiguous()
- return self.down_proj(xe_linear.mlp_forward_xpu(
- x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
- x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
- SILU, qtype
- ))
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
  def baichuan_model_7b_forward(
  self,
  input_ids: torch.LongTensor = None,

@@ -105,7 +73,9 @@ def baichuan_model_7b_forward(
  if use_cache:
  inputs = input_ids if input_ids is not None else inputs_embeds
  use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
- use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
+ use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
+ self.config.num_attention_heads,
+ self.config.num_attention_heads)
  if use_compress_kv and not isinstance(past_key_values,
  DynamicCompressCache):
  if use_quantize_kv:

@@ -278,8 +248,6 @@ def baichuan_attention_forward_7b(
  key_states = key_states.to(hidden_states.dtype)

  # IPEX-LLM OPT: kv cache and quantize kv
- use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
-
  # [CompressKV]
  if use_compresskv:
  enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,

@@ -290,6 +258,8 @@
  query_states, attention_mask, 1,
  self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
  else:
+ use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
+ self.num_heads, self.num_heads)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, device

@@ -340,7 +310,8 @@ def baichuan_attention_forward_13b(
  kv_seq_len += past_key_value[0].shape[2]

  # IPEX-LLM OPT: kv cache and quantize kv
- use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
+ use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
+ self.num_heads, self.num_heads)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, device

ipex_llm/transformers/models/bert.py

@@ -36,24 +36,13 @@ import math
  import torch
  from typing import Optional, Tuple
  from transformers.models.bert.modeling_bert import BertSelfAttention, BertEncoder
+ from ipex_llm.transformers.models.common import merge_linear
  from ipex_llm.utils.common import invalidInputError


  def merge_qkv(module: torch.nn.Module):
  if isinstance(module, BertSelfAttention):
- q_w = module.query.weight.data
- k_w = module.key.weight.data
- v_w = module.value.weight.data
- q_b = module.query.bias.data
- k_b = module.key.bias.data
- v_b = module.value.bias.data
- new_w = torch.cat([q_w, k_w, v_w], dim=0)
- new_b = torch.cat([q_b, k_b, v_b], dim=-1)
- qkv = torch.nn.Linear(0, 0, bias=True)
- qkv.weight = torch.nn.Parameter(new_w, requires_grad=False)
- qkv.bias = torch.nn.Parameter(new_b, requires_grad=False)
- qkv.in_features = module.query.in_features
- qkv.out_features = module.query.out_features * 3
+ qkv = merge_linear([module.query, module.key, module.value])
  module.qkv = qkv
  del module.query
  del module.key
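
Note: the hunk above swaps the hand-rolled q/k/v concatenation for the shared merge_linear helper. The sketch below is illustrative only, not the ipex_llm.transformers.models.common implementation; its behavior is inferred from the call site and from the removed code. It merges several nn.Linear layers with the same in_features into one.

import torch

def merge_linear_sketch(linears):
    # Concatenate the weights (and biases, if all layers have one) of several
    # same-input Linear layers into a single wider Linear, as the removed
    # bert.py code did by hand for query/key/value.
    weight = torch.cat([lin.weight.data for lin in linears], dim=0)
    has_bias = all(lin.bias is not None for lin in linears)
    merged = torch.nn.Linear(linears[0].in_features, weight.size(0), bias=has_bias)
    merged.weight = torch.nn.Parameter(weight, requires_grad=False)
    if has_bias:
        bias = torch.cat([lin.bias.data for lin in linears], dim=-1)
        merged.bias = torch.nn.Parameter(bias, requires_grad=False)
    return merged

q, k, v = (torch.nn.Linear(16, 16) for _ in range(3))
x = torch.randn(2, 16)
merged = merge_linear_sketch([q, k, v])
# The merged projection produces q, k and v in one GEMM.
assert torch.allclose(merged(x), torch.cat([q(x), k(x), v(x)], dim=-1), atol=1e-5)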

ipex_llm/transformers/models/chatglm2.py

@@ -33,34 +33,6 @@ from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cac
  KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states
- go from (batch, num_key_value_heads, seqlen, head_dim) to
- (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
- def chatglm_rms_norm_forward(self, hidden_states):
- if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
- import xe_addons
- x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
- output = xe_addons.rms_norm(self.weight, x_2d, self.eps)
- return output.reshape(hidden_states.shape)
-
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
- return self.weight * hidden_states.to(input_dtype)
-
-
  def chatglm2_model_forward(
  self,
  input_ids,

@@ -91,8 +63,13 @@ def chatglm2_model_forward(

  if use_cache:
  use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
+ n_heads = self.config.num_attention_heads
+ if self.config.multi_query_attention:
+ n_kv_heads = self.config.multi_query_group_num
+ else:
+ n_kv_heads = n_heads
  use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
- input_ids)
+ input_ids, n_heads, n_kv_heads)
  if use_compress_kv and not isinstance(past_key_values,
  DynamicCompressCache):
  if use_quantize_kv:
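
Note: this hunk (and the matching chatglm4 and baichuan hunks) shows use_quantize_kv_cache now taking attention-head and kv-head counts. A minimal standalone sketch of how those counts fall out of a ChatGLM-style config; FakeChatGLMConfig is a hypothetical stand-in, not part of ipex-llm:

from dataclasses import dataclass

@dataclass
class FakeChatGLMConfig:
    # ChatGLM-style configs expose multi_query_attention / multi_query_group_num
    # instead of a num_key_value_heads field.
    num_attention_heads: int = 32
    multi_query_attention: bool = True
    multi_query_group_num: int = 2

def kv_head_counts(config):
    n_heads = config.num_attention_heads
    n_kv_heads = config.multi_query_group_num if config.multi_query_attention else n_heads
    return n_heads, n_kv_heads

print(kv_head_counts(FakeChatGLMConfig()))  # (32, 2)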

@@ -285,8 +262,6 @@ def chatglm2_attention_forward(
  key_states[..., :rot_dim] = k_rot[...]

  # IPEX-LLM OPT: kv cache and quantize kv
- use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
-
  # [CompressKV]
  if use_compresskv:
  from transformers.configuration_utils import PretrainedConfig

@@ -300,6 +275,8 @@
  self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
  )
  else:
+ use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
+ n_head, n_kv_head)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, hidden_states.device

ipex_llm/transformers/models/chatglm4.py

@@ -55,8 +55,13 @@ def chatglm4_model_forward(
  if use_cache:
  inputs = input_ids if input_ids is not None else inputs_embeds
  use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
- use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
- inputs)
+ n_heads = self.config.num_attention_heads
+ if self.config.multi_query_attention:
+ n_kv_heads = self.config.multi_query_group_num
+ else:
+ n_kv_heads = n_heads
+ use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj, inputs,
+ n_heads, n_kv_heads)
  if use_compress_kv and not isinstance(past_key_values,
  DynamicCompressCache):
  if use_quantize_kv:

@@ -211,8 +216,6 @@ def chatglm4_attention_forward(
  key_states[..., :rot_dim] = k_rot[...]

  # IPEX-LLM OPT: kv cache and quantize kv
- use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
-
  # [CompressKV]
  if use_compresskv:
  from transformers.configuration_utils import PretrainedConfig

@@ -226,6 +229,8 @@
  self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
  )
  else:
+ use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
+ n_head, n_kv_head)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, hidden_states.device

ipex_llm/transformers/models/chatglm4v.py

@@ -230,7 +230,7 @@ def chatglm4v_attention_forward(
  key_states[..., :rot_dim] = k_rot[...]

  # IPEX-LLM OPT: kv cache and quantize kv
- use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
+ use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states, n_head, n_kv_head)
  key_states, value_states = update_past_key_value(
  past_key_value, key_states, value_states,
  kv_seq_len, use_quantize_kv, hidden_states.device

@@ -301,6 +301,7 @@ def patch_embedding_forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L,

  def merge_qkv(module: torch.nn.Module):
  merge_qkv_base(module, "SiglipAttention")
+ merge_qkv_base(module, "SiglipSdpaAttention")


  def vision_model_forward(self: torch.nn.Module, image: torch.Tensor):

ipex_llm/transformers/models/common.py

@@ -157,8 +157,10 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
  weight = self.weight
  if hasattr(self, "variance_epsilon"):
  eps = self.variance_epsilon
- else:
+ elif hasattr(self, "epsilon"):
  eps = self.epsilon
+ else:
+ eps = self.eps

  if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
  import xe_addons
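
Note: the hunk above makes the shared rms_norm_forward tolerant of the different epsilon attribute names used across RMSNorm variants (variance_epsilon in most HF models, epsilon in baichuan-13B, eps in chatglm), which is what allows the chatglm/baichuan-specific norm forwards earlier in this diff to be dropped. A small standalone sketch of that attribute probing; the classes are illustrative stand-ins, not real model code:

import torch

class HFStyleRMSNorm(torch.nn.Module):       # exposes variance_epsilon
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

class ChatGLMStyleRMSNorm(torch.nn.Module):  # exposes eps
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

def resolve_eps(module):
    # Same precedence as the patched rms_norm_forward: variance_epsilon, then epsilon, then eps.
    for name in ("variance_epsilon", "epsilon", "eps"):
        if hasattr(module, name):
            return getattr(module, name)
    raise AttributeError("no epsilon attribute found")

print(resolve_eps(HFStyleRMSNorm(8)), resolve_eps(ChatGLMStyleRMSNorm(8)))  # 1e-06 1e-06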

ipex_llm/transformers/models/glm.py

@@ -37,6 +37,7 @@ import torch

  from typing import Optional, Tuple
  from transformers.cache_utils import Cache
+ from transformers.models.glm.modeling_glm import GlmAttention
  from transformers.models.glm.modeling_glm import apply_rotary_pos_emb
  from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
  from ipex_llm.transformers.models.common import merge_qkv_base

@@ -46,8 +47,9 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache


  def merge_qkv(module: torch.nn.Module):
- merge_qkv_base(module, "GlmAttention")
+ merge_qkv_base(module, GlmAttention)
  merge_qkv_base(module, "SiglipAttention")
+ merge_qkv_base(module, "SiglipSdpaAttention")


  def split_mlp(module: torch.nn.Module):

@@ -145,7 +147,7 @@ def glm_model_forward_wrapper(origin_forward):
  use_cache = use_cache if use_cache is not None else self.config.use_cache
  use_cache = use_cache or inputs.device.type == 'xpu'
  use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
- self.config.num_attention_heads //
+ self.config.num_attention_heads,
  self.config.num_key_value_heads)

  if use_cache: