ipex-llm 2.2.0b20250107__py3-none-manylinux2010_x86_64.whl → 2.2.0b20250108__py3-none-manylinux2010_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ipex_llm/libs/libbloom_amx.so +0 -0
  2. ipex_llm/libs/libbloom_avx.so +0 -0
  3. ipex_llm/libs/libbloom_avx2.so +0 -0
  4. ipex_llm/libs/libbloom_avx512.so +0 -0
  5. ipex_llm/libs/libbloom_avxvnni.so +0 -0
  6. ipex_llm/libs/libgptneox_amx.so +0 -0
  7. ipex_llm/libs/libgptneox_avx.so +0 -0
  8. ipex_llm/libs/libgptneox_avx2.so +0 -0
  9. ipex_llm/libs/libgptneox_avx512.so +0 -0
  10. ipex_llm/libs/libgptneox_avxvnni.so +0 -0
  11. ipex_llm/libs/libllama_amx.so +0 -0
  12. ipex_llm/libs/libllama_avx.so +0 -0
  13. ipex_llm/libs/libllama_avx2.so +0 -0
  14. ipex_llm/libs/libllama_avx512.so +0 -0
  15. ipex_llm/libs/libllama_avxvnni.so +0 -0
  16. ipex_llm/libs/libstarcoder_amx.so +0 -0
  17. ipex_llm/libs/libstarcoder_avx.so +0 -0
  18. ipex_llm/libs/libstarcoder_avx2.so +0 -0
  19. ipex_llm/libs/libstarcoder_avx512.so +0 -0
  20. ipex_llm/libs/libstarcoder_avxvnni.so +0 -0
  21. ipex_llm/libs/quantize-bloom +0 -0
  22. ipex_llm/libs/quantize-gptneox +0 -0
  23. ipex_llm/libs/quantize-llama +0 -0
  24. ipex_llm/libs/quantize-starcoder +0 -0
  25. ipex_llm/transformers/convert.py +15 -37
  26. ipex_llm/transformers/loader.py +1 -1
  27. ipex_llm/transformers/low_bit_linear.py +10 -25
  28. ipex_llm/transformers/model.py +0 -7
  29. ipex_llm/transformers/models/chatglm4v.py +1 -0
  30. ipex_llm/transformers/models/glm.py +3 -1
  31. ipex_llm/transformers/models/llama.py +1 -1
  32. ipex_llm/transformers/models/minicpm.py +2 -1
  33. ipex_llm/transformers/models/minicpmv.py +1 -0
  34. ipex_llm/transformers/models/utils.py +3 -16
  35. ipex_llm/transformers/speculative.py +2 -14
  36. ipex_llm/transformers/utils.py +2 -14
  37. ipex_llm/transformers/xpu_ops.py +25 -19
  38. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250108.dist-info}/METADATA +20 -20
  39. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250108.dist-info}/RECORD +45 -46
  40. ipex_llm/transformers/models/gptj.py +0 -441
  41. {ipex_llm-2.2.0b20250107.data → ipex_llm-2.2.0b20250108.data}/scripts/ipex-llm-init +0 -0
  42. {ipex_llm-2.2.0b20250107.data → ipex_llm-2.2.0b20250108.data}/scripts/llm-chat +0 -0
  43. {ipex_llm-2.2.0b20250107.data → ipex_llm-2.2.0b20250108.data}/scripts/llm-cli +0 -0
  44. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250108.dist-info}/WHEEL +0 -0
  45. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250108.dist-info}/entry_points.txt +0 -0
  46. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250108.dist-info}/top_level.txt +0 -0
Binary files differ for entries 1–24 above (the libbloom/libgptneox/libllama/libstarcoder .so variants and the quantize-* executables); no text diff is shown for them.
ipex_llm/transformers/convert.py
@@ -680,18 +680,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     optimize_lm_head=optimize_lm_head
                 )
                 device = module.weight.data.device
-                from ipex_llm.transformers.utils import get_ipex_version
-                if get_ipex_version() < "2.1.10+xpu":
-                    new_linear._parameters['weight'] = nn.Parameter(module.weight)
-                else:
-                    # only from 2.1, ipex provides matmul_bias_out
-                    # so we need to transpose weight
-                    new_weight = module.weight.transpose(0, 1).contiguous()
-                    new_linear._parameters['weight'] = nn.Parameter(new_weight)
-                    new_linear.weight_type = 2
+                new_linear._parameters['weight'] = nn.Parameter(module.weight)
                 if module.bias is not None:
-                    new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                        .to(device)
+                    new_linear._parameters['bias'] = nn.Parameter(module.bias.data).to(device)
             elif qtype == ggml_tensor_qtype["bf16"]:
                 module.to(torch.bfloat16)
                 if _USE_VLLM:
@@ -856,18 +847,9 @@ def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
                     mp_group=mp_group,
                 )
                 device = module.weight.data.device
-                from ipex_llm.transformers.utils import get_ipex_version
-                if get_ipex_version() < "2.1.10+xpu":
-                    new_linear._parameters['weight'] = nn.Parameter(module.weight)
-                else:
-                    # only from 2.1, ipex provides matmul_bias_out
-                    # so we need to transpose weight
-                    new_weight = module.weight.transpose(0, 1).contiguous()
-                    new_linear._parameters['weight'] = nn.Parameter(new_weight)
-                    new_linear.weight_type = 2
+                new_linear._parameters['weight'] = nn.Parameter(module.weight)
                 if module.bias is not None:
-                    new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                        .to(device)
+                    new_linear._parameters['bias'] = nn.Parameter(module.bias.data).to(device)
             elif qtype == ggml_tensor_qtype["bf16"]:
                 module.to(torch.bfloat16)
                 new_linear = BF16Linear(
@@ -1429,6 +1411,7 @@ def _optimize_post(model):
        convert_forward(model, module.GlmRMSNorm, rms_norm_forward)
        convert_forward(model, module.GlmMLP, mlp_silu_forward)
        convert_forward(model, module.GlmAttention, glm_attention_forward)
+       convert_forward(model, module.GlmSdpaAttention, glm_attention_forward)
        glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
        convert_forward(model, module.GlmModel, glm_model_forward)

@@ -1437,10 +1420,12 @@ def _optimize_post(model):
            vision_module_name = model.model.vision.__class__.__module__
            vision_module = importlib.import_module(vision_module_name)
            from transformers.models.siglip.modeling_siglip import SiglipAttention
+           from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
            from ipex_llm.transformers.models.chatglm4v import vision_model_forward
            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
            convert_forward(model, vision_module.VisionModel, vision_model_forward)
            convert_forward(model, SiglipAttention, siglip_attention_forward)
+           convert_forward(model, SiglipSdpaAttention, siglip_attention_forward)

    elif "mpt" in model.config.model_type:
        if model.config.architectures is not None:
@@ -1452,21 +1437,6 @@ def _optimize_post(model):
                            module.MultiheadAttention,
                            mpt_multihead_attention_forward
                            )
-   elif "gptj" in model.config.model_type:
-       # dolly-v1-6b
-       modeling_module_name = model.__class__.__module__
-       module = importlib.import_module(modeling_module_name)
-       from ipex_llm.transformers.models.gptj import gptj_attention_forward, gptj_model_forward,\
-           gptj_block_forward
-       convert_forward(model,
-                       module.GPTJAttention,
-                       gptj_attention_forward)
-       convert_forward(model,
-                       module.GPTJModel,
-                       gptj_model_forward)
-       convert_forward(model,
-                       module.GPTJBlock,
-                       gptj_block_forward)
    elif "bloom" in model.config.model_type:
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
@@ -1691,8 +1661,10 @@ def _optimize_post(model):
        convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
        model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual)
        convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
+       convert_forward(model, module.VisionSdpaAttention, qwen2_vision_attention_forward)
        convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
        convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
+       convert_forward(model, module.Qwen2VLSdpaAttention, qwen2_vl_attention_forward)
    elif model.config.model_type == "aquila":
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
@@ -1838,6 +1810,7 @@ def _optimize_post(model):
        from ipex_llm.transformers.models.starcoder2 import attention_forward
        from ipex_llm.transformers.models.starcoder2 import model_forward
        convert_forward(model, module.Starcoder2Attention, attention_forward)
+       convert_forward(model, module.Starcoder2SdpaAttention, attention_forward)
        convert_forward(model, module.Starcoder2Model, model_forward)
    elif model.config.model_type == "phi":
        # for phi-2
@@ -1853,6 +1826,7 @@ def _optimize_post(model):
        module = importlib.import_module(modeling_module_name)
        from ipex_llm.transformers.models.phi3 import attention_forward
        convert_forward(model, module.Phi3Attention, attention_forward)
+       convert_forward(model, module.Phi3SdpaAttention, attention_forward)
        from ipex_llm.transformers.models.phi3 import mlp_forward
        convert_forward(model, module.Phi3MLP, mlp_forward)
        from ipex_llm.transformers.models.common import rms_norm_forward
@@ -1896,6 +1870,8 @@ def _optimize_post(model):
                        module.StableLmAttention,
                        stablelm_attention_forward
                        )
+       if hasattr(module, "StableLmSdpaAttention"):
+           convert_forward(model, module.StableLmSdpaAttention, stablelm_attention_forward)
        convert_forward(model,
                        module.StableLmMLP,
                        mlp_silu_forward)
@@ -1910,6 +1886,7 @@ def _optimize_post(model):
        from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
        from ipex_llm.transformers.models.minicpm import minicpm_decoder_layer_forward
        convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
+       convert_forward(model, module.MiniCPMSdpaAttention, minicpm_attention_forward)
        convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
        convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
        convert_forward(model, module.MiniCPMDecoderLayer, minicpm_decoder_layer_forward)
@@ -1925,6 +1902,7 @@ def _optimize_post(model):
        convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
        convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
        convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
+       convert_forward(model, module.MiniCPMSdpaAttention, minicpm3_attention_forward)
        minicpm3_model_forward = minicpm3_model_forward_wrapper(module.MiniCPM3Model.forward)
        convert_forward(model, module.MiniCPM3Model, minicpm3_model_forward)
    elif model.config.model_type == "minicpmv":
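
Note on the convert.py hunks above: the added lines register the same optimized forward for the `*SdpaAttention` classes as for the eager attention classes, so models whose config dispatches to SDPA attention still get patched. The sketch below only illustrates the general shape of a convert_forward-style helper (rebinding `forward` on every submodule of a target class); it is an assumption for illustration, not the actual ipex_llm implementation.

import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
    # Illustrative only: rebind `forward` on each submodule that is an
    # instance of `target_cls`, mirroring the convert_forward(model, cls, fn)
    # call pattern used in _optimize_post above.
    for sub_module in model.modules():
        if isinstance(sub_module, target_cls):
            sub_module.forward = types.MethodType(new_forward, sub_module)
    return model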
ipex_llm/transformers/loader.py
@@ -22,7 +22,7 @@ import time
 from datetime import date
 import argparse
 from ipex_llm.utils.common import invalidInputError
-from transformers import AutoTokenizer, GPTJForCausalLM, LlamaTokenizer
+from transformers import AutoTokenizer, LlamaTokenizer

 LLAMA_IDS = ['llama', 'vicuna', 'merged-baize']

ipex_llm/transformers/low_bit_linear.py
@@ -51,8 +51,7 @@ from torch import Tensor, device, dtype, nn
 from operator import mul
 from functools import reduce
 from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
-from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_name, \
-    get_ipex_version
+from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_name
 from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm

 T = TypeVar("T", bound="torch.nn.Module")
@@ -286,7 +285,7 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
        or (
            qtype in [SYM_INT8, FP4, FP6, Q4_K, Q6_K]
            and batch_size <= 48
-           and device_name in ["arc", "pvc", "mtl", "lnl", "arl"]
+           and device_name in ["arc", "pvc", "mtl", "arl"]
            and x.shape[1] % 256 == 0
            and output_len % 32 == 0
        )
@@ -759,9 +758,9 @@ class FP16Linear(nn.Linear):
        self.weight_length = self.out_len * self.in_len
        self.qtype = ggml_tensor_qtype["fp16"]
        self.mp_group = mp_group
-       # weigh_type = 1 means original weight
-       # weigh_type = 2 means weight has been transposed
-       # weigh_type = 3 means weight has been transposed by esimd method
+       # weight_type = 1 means original weight
+       # weight_type = 2 means weight has been transposed
+       # weight_type = 3 means weight has been transposed by esimd method
        self.weight_type = 1
        self.optimize_lm_head = optimize_lm_head
        self.disable_fp16_opt = False
@@ -775,28 +774,14 @@ class FP16Linear(nn.Linear):

        x = x.to(torch.float16)
        if self.bias is not None and self.bias.dtype != x.dtype:
-               self.bias.data = self.bias.data.to(x.dtype)
+           self.bias.data = self.bias.data.to(x.dtype)
        if self.weight is not None and self.weight.dtype != x.dtype:
            self.weight.data = self.weight.data.to(x.dtype)

        if not self.use_esimd_kernel(x):
-           if (
-               get_ipex_version() < "2.1.10+xpu"
-               or get_xpu_device_name(x.device) not in ["arc", "pvc"]
-               or self.disable_fp16_opt
-           ):
-               if self.weight_type == 2:
-                   self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
-                                                    requires_grad=False)
-                   self.weight_type = 1
-               result = F.linear(x, self.weight, self.bias)
-           else:
-               if self.weight_type == 1:
-                   self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
-                                                    requires_grad=False)
-                   self.weight_type = 2
-               result = torch.ops.torch_ipex.matmul_bias_out(x.contiguous(),
-                                                             self.weight, self.bias)
+           invalidInputError(self.weight_type == 1, "weight_type should be 1")
+           result = F.linear(x, self.weight, self.bias)
+
        if self.mp_group is not None:
            if get_use_vllm():
                result = self.mp_group.all_reduce(result)
@@ -852,7 +837,7 @@ class FP16Linear(nn.Linear):
        if self.disable_fp16_opt:
            return False
        # esimd kernel can only be used for Arc and Flex
-       if gpu_type not in ["arc", "flex"]:
+       if gpu_type not in ["arc"]:
            return False
        # now esimd kernel can only be used for specific cases (llama2-7b shape)
        if self.in_len == 11008 and self.out_features == 4096:
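
Note on the FP16Linear change above: with the IPEX-version branch gone, the weight stays in its original (out_features, in_features) layout (weight_type == 1) and the non-ESIMD path reduces to a plain fp16 F.linear. A minimal stand-alone check of that equivalence, using only stock PyTorch:

import torch
import torch.nn.functional as F

x = torch.randn(4, 16, dtype=torch.float16)
weight = torch.randn(8, 16, dtype=torch.float16)  # (out_features, in_features)
bias = torch.randn(8, dtype=torch.float16)

out = F.linear(x, weight, bias)            # what the simplified forward computes
ref = x @ weight.transpose(0, 1) + bias    # i.e. x @ W.T + b
assert torch.allclose(out, ref, atol=1e-3)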
ipex_llm/transformers/model.py
@@ -103,12 +103,6 @@ def save_low_bit(self, *args, **kwargs):
    self.to(origin_device)


-def _load_pre():
-    from transformers import GPTJModel
-    from ipex_llm.transformers.models.gptj import gptj_model_new_init
-    GPTJModel.__init__ = gptj_model_new_init
-
-
class _BaseAutoModelClass:
    HF_MODEL = None

@@ -495,7 +489,6 @@ class _BaseAutoModelClass:
            else:
                if quant_config is not None:
                    kwargs["quantization_config"] = quant_config
-               _load_pre()
                try:
                    # To handle the input CUDA setting (such as 'device_map={"":0}'), ignore it
                    kwargs.pop('device_map', None)
ipex_llm/transformers/models/chatglm4v.py
@@ -301,6 +301,7 @@ def patch_embedding_forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L,

def merge_qkv(module: torch.nn.Module):
    merge_qkv_base(module, "SiglipAttention")
+   merge_qkv_base(module, "SiglipSdpaAttention")


def vision_model_forward(self: torch.nn.Module, image: torch.Tensor):
ipex_llm/transformers/models/glm.py
@@ -37,6 +37,7 @@ import torch

from typing import Optional, Tuple
from transformers.cache_utils import Cache
+from transformers.models.glm.modeling_glm import GlmAttention
from transformers.models.glm.modeling_glm import apply_rotary_pos_emb
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
from ipex_llm.transformers.models.common import merge_qkv_base
@@ -46,8 +47,9 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache


def merge_qkv(module: torch.nn.Module):
-   merge_qkv_base(module, "GlmAttention")
+   merge_qkv_base(module, GlmAttention)
    merge_qkv_base(module, "SiglipAttention")
+   merge_qkv_base(module, "SiglipSdpaAttention")


def split_mlp(module: torch.nn.Module):
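
Note on the merge_qkv changes above and in the following hunks: merge_qkv_base is now also applied to the SDPA attention variants, and for GLM it receives the GlmAttention class itself rather than a string. Its implementation is not part of this diff; the sketch below only illustrates the usual idea of fusing separate q/k/v projections into one linear, assuming the module exposes q_proj/k_proj/v_proj (hypothetical names, not necessarily how ipex_llm does it).

import torch

def merge_qkv_sketch(attn: torch.nn.Module):
    # Hypothetical illustration: concatenate the q/k/v projection weights so a
    # single matmul produces all three projections at once.
    q, k, v = attn.q_proj, attn.k_proj, attn.v_proj
    qkv = torch.nn.Linear(q.in_features,
                          q.out_features + k.out_features + v.out_features,
                          bias=q.bias is not None)
    qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
    if q.bias is not None:
        qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
    attn.qkv_proj = qkv
    del attn.q_proj, attn.k_proj, attn.v_proj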
ipex_llm/transformers/models/llama.py
@@ -116,7 +116,7 @@ def llama_model_forward(


def merge_qkv(module: torch.nn.Module):
-   return merge_qkv_base(module, LlamaAttention)
+   merge_qkv_base(module, LlamaAttention)


def llama_attention_forward(
ipex_llm/transformers/models/minicpm.py
@@ -51,7 +51,8 @@ from transformers.cache_utils import Cache


def merge_qkv(module: torch.nn.Module):
-   return merge_qkv_base(module, "MiniCPMAttention")
+   merge_qkv_base(module, "MiniCPMAttention")
+   merge_qkv_base(module, "MiniCPMSdpaAttention")


def apply_residual_scale(module: torch.nn.Module):
ipex_llm/transformers/models/minicpmv.py
@@ -36,6 +36,7 @@ from transformers.generation.logits_process import RepetitionPenaltyLogitsProces
# MiniCPM-V-2_5 and MiniCPM-V-2_6
def merge_qkv(module: torch.nn.Module):
    merge_qkv_base(module, "SiglipAttention")
+   merge_qkv_base(module, "SiglipSdpaAttention")
    merge_qkv_base(module, "Idefics2VisionAttention")

ipex_llm/transformers/models/utils.py
@@ -19,7 +19,7 @@ import torch
import warnings
from ipex_llm.utils.common import invalidInputError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
-from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_name
+from ipex_llm.transformers.utils import get_xpu_device_name
from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4,\
    FP6, ASYM_INT4

@@ -168,7 +168,7 @@ def should_use_fuse_rope(hidden_states, position_ids, training):

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
-                       "mixtral", "qwen2", "yuan", "stablelm", "qwen2_moe"]:
+                       "qwen2", "yuan", "stablelm", "qwen2_moe"]:
        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -183,7 +183,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed
-   elif model_family in ["gptj", "chatglm"]:
+   elif model_family in ["chatglm"]:
        q_embed = (q * cos) + (rotate_every_two(q) * sin)
        k_embed = (k * cos) + (rotate_every_two(k) * sin)
        return q_embed, k_embed
@@ -192,19 +192,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
                          f"{model_family} is not supported.")


-def apply_ipex_rotate_every_two(q, k, cos, sin):
-    # ipex's apply_rotary_embedding_two_qk can change the origin storage,
-    # so q/k will get the result directly.
-    from ipex_llm.transformers.utils import get_ipex_version
-    if get_ipex_version() >= "2.1.10+xpu":
-        torch.ops.torch_ipex.apply_rotary_embedding_two_qk(
-            q, k, sin, cos, q, k
-        )
-    else:
-        torch.ops.torch_ipex.apply_rotary_embedding(q, sin, cos, q)
-        torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)
-
-
def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
    # to determinate if is enough kv cache room in transformers==4.36
    # seq_len for current seq len
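
Note on apply_rotary_pos_emb above: the llama-family branch rotates the two halves of the head dimension (rotate_half), while the chatglm branch (formerly also gptj) rotates interleaved pairs (rotate_every_two). Neither helper appears in this diff; the definitions below are the conventional ones these names usually refer to, shown for reference only:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) over the two halves of the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...) over interleaved pairs.
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]
    return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)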
ipex_llm/transformers/speculative.py
@@ -432,8 +432,7 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
    from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
        extend_kv_cache
    enough_kv_room = True
-   if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral",
-                         "gptj", "opt"]:
+   if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral", "opt"]:
        return past_key_values, False
    cache_k = past_key_values[0][0]
    if model_type == "chatglm":
@@ -527,7 +526,7 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
             v[:-(new_cache_size), :, :, :])
            for k, v in past_key_values
        ]
-   elif self.config.model_type in ["baichuan", "gptj"]:
+   elif self.config.model_type in ["baichuan"]:
        past_key_values = [
            (k[:, :, :-(new_cache_size), :],
             v[:, :, :-(new_cache_size), :])
@@ -796,13 +795,6 @@ def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_
                                    device=verify_input_ids.device)
        position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
        forward_args["position_ids"] = position_ids
-   elif self.config.model_type == "gptj":
-       past_length = past_key_values[0][0].size(2)
-       input_len = verify_input_ids.shape[1]
-       position_ids = torch.arange(past_length, input_len + past_length,
-                                   dtype=torch.long, device=verify_input_ids.device)
-       position_ids = position_ids.unsqueeze(0).view(-1, input_len)
-       forward_args["position_ids"] = position_ids

    return self(**forward_args)

@@ -971,10 +963,6 @@ def speculative_generate(self,
            past_key_value_len = past_key_values[0][0].shape[0]
            position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
            forward_args["position_ids"] = position_ids
-       elif self.config.model_type == "gptj":
-           past_length = draft_past_key_values[0][0].size(2)
-           position_ids = torch.Tensor([[past_length]]).long().to(self.device)
-           forward_args["position_ids"] = position_ids

        if _enable_ipex:
            if any(keyword in self.config.model_type
ipex_llm/transformers/utils.py
@@ -154,24 +154,12 @@ def get_autocast_dtype(x):
                          f"Device {x.device} is not supported.")


-_ipex_version = None
-
-
-def get_ipex_version():
-
-    global _ipex_version
-    if _ipex_version is not None:
-        return _ipex_version
-
-    import intel_extension_for_pytorch as ipex
-    _ipex_version = ipex.__version__
-    return _ipex_version
-
-
def get_xpu_device_name(device: torch.device):
    if device.type != "xpu":
        return device.type
    else:
+       # possiable device name:
+       # ["arc", "pvc", "mtl", "lnl", "bmg", "arl", "legacy", "unknown"]
        import xe_linear
        return xe_linear.get_xpu_device_name(device)

ipex_llm/transformers/xpu_ops.py
@@ -20,9 +20,9 @@ import xe_batch
import xe_addons


-@torch.library.register_fake("ipex_llm::forward_new")
-def _(x, weight, qtype, input_size):
-    return torch.empty_like(x)
+# @torch.library.register_fake("ipex_llm::forward_new")
+# def _(x, weight, qtype, input_size):
+#     return ???


# @torch.library.register_fake("ipex_llm::dequant")
@@ -32,32 +32,38 @@ def _(x, weight, qtype, input_size):

@torch.library.register_fake("ipex_llm::mlp_forward_xpu")
def _(x, weight1, weight2, batch_size, state_size, output_size, act_type, qtype):
-   return torch.empty_like(x)
+   return torch.empty([batch_size, output_size],
+                      dtype=x.dtype, device=x.device)


-# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
-# def _(time_decay, time_first, key, value, num_state, den_state, max_state)
-#     return ???
+@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4")
+def _(time_decay, time_first, key, value, num_state, den_state, max_state):
+   return torch.empty_like(key)


-# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
-# def _(time_decay, time_first, receptance, key, value, state)
-#     return ???
+@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5")
+def _(time_decay, time_first, receptance, key, value, state):
+   bsz, n_heads, seq_len, head_dim = key.shape
+   return torch.empty([bsz, seq_len, n_heads, head_dim],
+                      dtype=key.dtype, device=key.device)


-# @torch.library.register_fake("ipex_llm::rwkv_time_shift")
-# def _(hidden, shifted, mix):
-#     return ???
+@torch.library.register_fake("ipex_llm::rwkv_time_shift")
+def _(hidden, shifted, mix):
+   bsz, seq_len, hidden_size = hidden.shape
+   return torch.empty([mix.size(0), bsz, seq_len, hidden_size],
+                      dtype=hidden.dtype, device=hidden.device)


-# @torch.library.register_fake("ipex_llm::dequantize_rows")
-# def _(x, weight, qtype, state_size, output_size):
-#     return ???
+@torch.library.register_fake("ipex_llm::dequantize_rows")
+def _(x, weight, qtype, state_size, output_size):
+   return torch.empty([x.size(0), x.size(1), state_size],
+                      dtype=torch.float, device=weight.device)


-@torch.library.register_fake("ipex_llm::batch_forward")
-def _(x, weight, qtype):
-    return torch.empty_like(x)
+# @torch.library.register_fake("ipex_llm::batch_forward")
+# def _(x, weight, qtype):
+#     return ???


@torch.library.register_fake("ipex_llm::sdp")