ipex-llm 2.2.0b20250106__py3-none-win_amd64.whl → 2.2.0b20250107__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
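The file-level summary below can be reproduced locally from the two wheels. A minimal sketch using only the Python standard library is shown here; the `compare_wheels` helper and the paths in the trailing comment are illustrative, not part of ipex-llm:

import zipfile

def compare_wheels(old_path: str, new_path: str) -> None:
    # A wheel is a zip archive; compare member names, sizes, and CRC checksums.
    with zipfile.ZipFile(old_path) as old_whl, zipfile.ZipFile(new_path) as new_whl:
        old_members = {i.filename: (i.file_size, i.CRC) for i in old_whl.infolist()}
        new_members = {i.filename: (i.file_size, i.CRC) for i in new_whl.infolist()}
        for name in sorted(old_members.keys() | new_members.keys()):
            if name not in new_members:
                print(f"removed  {name}")
            elif name not in old_members:
                print(f"added    {name}")
            elif old_members[name] != new_members[name]:
                print(f"changed  {name}")

# compare_wheels("ipex_llm-2.2.0b20250106-py3-none-win_amd64.whl",
#                "ipex_llm-2.2.0b20250107-py3-none-win_amd64.whl")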
- ipex_llm/libs/bloom-api.dll +0 -0
- ipex_llm/libs/bloom.dll +0 -0
- ipex_llm/libs/gptneox-api.dll +0 -0
- ipex_llm/libs/gptneox.dll +0 -0
- ipex_llm/libs/libbloom_avx.dll +0 -0
- ipex_llm/libs/libbloom_vnni.dll +0 -0
- ipex_llm/libs/libgptneox_avx.dll +0 -0
- ipex_llm/libs/libgptneox_vnni.dll +0 -0
- ipex_llm/libs/libllama_avx.dll +0 -0
- ipex_llm/libs/libllama_vnni.dll +0 -0
- ipex_llm/libs/libstarcoder_avx.dll +0 -0
- ipex_llm/libs/libstarcoder_vnni.dll +0 -0
- ipex_llm/libs/llama-api.dll +0 -0
- ipex_llm/libs/llama.dll +0 -0
- ipex_llm/libs/main-bloom.exe +0 -0
- ipex_llm/libs/main-gptneox.exe +0 -0
- ipex_llm/libs/main-llama.exe +0 -0
- ipex_llm/libs/main-starcoder.exe +0 -0
- ipex_llm/libs/pipeline.dll +0 -0
- ipex_llm/libs/quantize-bloom.exe +0 -0
- ipex_llm/libs/quantize-bloom_vnni.exe +0 -0
- ipex_llm/libs/quantize-gptneox.exe +0 -0
- ipex_llm/libs/quantize-gptneox_vnni.exe +0 -0
- ipex_llm/libs/quantize-llama.exe +0 -0
- ipex_llm/libs/quantize-llama_vnni.exe +0 -0
- ipex_llm/libs/quantize-starcoder.exe +0 -0
- ipex_llm/libs/quantize-starcoder_vnni.exe +0 -0
- ipex_llm/libs/starcoder-api.dll +0 -0
- ipex_llm/libs/starcoder.dll +0 -0
- ipex_llm/transformers/convert.py +17 -132
- ipex_llm/transformers/lookup.py +2 -2
- ipex_llm/transformers/low_bit_linear.py +8 -8
- ipex_llm/transformers/models/chatglm2.py +1 -192
- ipex_llm/transformers/models/minicpmv.py +2 -2
- ipex_llm/transformers/models/sd.py +2 -2
- ipex_llm/transformers/models/utils.py +14 -89
- ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +5 -8
- ipex_llm/transformers/utils.py +5 -20
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/METADATA +40 -19
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/RECORD +46 -49
- ipex_llm/transformers/models/cohere.py +0 -589
- ipex_llm/transformers/models/falcon.py +0 -829
- ipex_llm/transformers/models/mixtral.py +0 -576
- {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250107.data}/scripts/ipex-llm-init.bat +0 -0
- {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250107.data}/scripts/llm-chat.ps1 +0 -0
- {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250107.data}/scripts/llm-cli.ps1 +0 -0
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/WHEEL +0 -0
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/entry_points.txt +0 -0
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250107.dist-info}/top_level.txt +0 -0
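All of the Python modules listed above belong to ipex-llm's transformers integration and are exercised indirectly when a model is loaded through `ipex_llm.transformers`. A rough usage sketch follows; the model id and generation settings are placeholders, and keyword arguments can differ between nightly builds:

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model id
# Low-bit conversion and the forward patches from convert.py are applied at load time.
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, trust_remote_code=True)
model = model.to("xpu")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("What does this nightly update change?", return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))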
ipex_llm/libs/bloom-api.dll
CHANGED
Binary file

ipex_llm/libs/bloom.dll
CHANGED
Binary file

ipex_llm/libs/gptneox-api.dll
CHANGED
Binary file

ipex_llm/libs/gptneox.dll
CHANGED
Binary file

ipex_llm/libs/libbloom_avx.dll
CHANGED
Binary file

ipex_llm/libs/libbloom_vnni.dll
CHANGED
Binary file

ipex_llm/libs/libgptneox_avx.dll
CHANGED
Binary file

ipex_llm/libs/libgptneox_vnni.dll
CHANGED
Binary file

ipex_llm/libs/libllama_avx.dll
CHANGED
Binary file

ipex_llm/libs/libllama_vnni.dll
CHANGED
Binary file

ipex_llm/libs/libstarcoder_avx.dll
CHANGED
Binary file

ipex_llm/libs/libstarcoder_vnni.dll
CHANGED
Binary file

ipex_llm/libs/llama-api.dll
CHANGED
Binary file

ipex_llm/libs/llama.dll
CHANGED
Binary file

ipex_llm/libs/main-bloom.exe
CHANGED
Binary file

ipex_llm/libs/main-gptneox.exe
CHANGED
Binary file

ipex_llm/libs/main-llama.exe
CHANGED
Binary file

ipex_llm/libs/main-starcoder.exe
CHANGED
Binary file

ipex_llm/libs/pipeline.dll
CHANGED
Binary file

ipex_llm/libs/quantize-bloom.exe
CHANGED
Binary file

ipex_llm/libs/quantize-bloom_vnni.exe
CHANGED
Binary file

ipex_llm/libs/quantize-gptneox.exe
CHANGED
Binary file

ipex_llm/libs/quantize-gptneox_vnni.exe
CHANGED
Binary file

ipex_llm/libs/quantize-llama.exe
CHANGED
Binary file

ipex_llm/libs/quantize-llama_vnni.exe
CHANGED
Binary file

ipex_llm/libs/quantize-starcoder.exe
CHANGED
Binary file

ipex_llm/libs/quantize-starcoder_vnni.exe
CHANGED
Binary file

ipex_llm/libs/starcoder-api.dll
CHANGED
Binary file

ipex_llm/libs/starcoder.dll
CHANGED
Binary file
ipex_llm/transformers/convert.py
CHANGED
@@ -1052,7 +1052,8 @@ def _optimize_pre(model, qtype=None):
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "megrezo"
     elif model.config.model_type == "chatglm":
-        if hasattr(model.config, 'padded_vocab_size') and
+        if hasattr(model.config, 'padded_vocab_size') and \
+                model.config.padded_vocab_size in [65024, 64896]:
             # chatglm2 and chatglm3
             from ipex_llm.transformers.models.chatglm2 import split_mlp
             model.apply(split_mlp)
@@ -1337,7 +1338,7 @@ def _optimize_post(model):
         and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]
     ):
         if hasattr(model.config, 'padded_vocab_size') and \
-                model.config.padded_vocab_size
+                model.config.padded_vocab_size in [65024, 64896]:
             # chatglm2-6b, chatglm2-6b-32k, chatglm3-6b, chatglm3-6b-32k, chatglm3-6b-128k
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
@@ -1359,27 +1360,9 @@ def _optimize_post(model):
                             module.RMSNorm,
                             chatglm_rms_norm_forward)
             convert_forward(model, module.MLP, mlp_forward)
-
-
-
-            modeling_module_name = model.__class__.__module__
-            module = importlib.import_module(modeling_module_name)
-            from ipex_llm.transformers.models.chatglm2 import codegeex_attention_forward
-            from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
-            from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
-            from ipex_llm.transformers.models.chatglm2 import codegeex_model_forward
-            convert_forward(model,
-                            module.SelfAttention,
-                            codegeex_attention_forward)
-            convert_forward(model,
-                            module.GLMTransformer,
-                            chatglm2_encoder_forward)
-            convert_forward(model,
-                            module.ChatGLMModel,
-                            codegeex_model_forward)
-            convert_forward(model,
-                            module.RMSNorm,
-                            chatglm_rms_norm_forward)
+            # for codegeex-nano
+            if hasattr(model.config, "rope_ratio"):
+                model.transformer.rotary_pos_emb.rope_ratio = model.config.rope_ratio
     elif hasattr(model.config, 'vocab_size') and model.config.vocab_size == 130528:
         # chatglm-6b
         modeling_module_name = model.__class__.__module__
@@ -1492,44 +1475,6 @@ def _optimize_post(model):
                         module.BloomAttention,
                         bloom_attention_forward
                         )
-    elif "falcon" in model.config.model_type or "RefinedWeb" in model.config.model_type:
-        if model.config.architectures is not None:
-            modeling_module_name = model.__class__.__module__
-            module = importlib.import_module(modeling_module_name)
-            if "RWForCausalLM" in model.config.architectures:
-                if model.config.hidden_size == 4544:
-                    # falcon-7b need to check performance drop after kv cache support.
-                    # from ipex_llm.transformers.models.falcon import rw_attention_forward_7b
-                    # convert_forward(model,
-                    #                 module.Attention,
-                    #                 rw_attention_forward_7b
-                    #                 )
-                    pass
-                else:
-                    # falcon-40b
-                    from ipex_llm.transformers.models.falcon import rw_attention_forward_40b
-                    convert_forward(model,
-                                    module.Attention,
-                                    rw_attention_forward_40b
-                                    )
-            elif "FalconForCausalLM" in model.config.architectures:
-                if model.config.hidden_size != 4544:
-                    # falcon-180b and new falcon-40b
-                    if version.parse(trans_version) >= version.parse("4.36.0"):
-                        # transformers version >= 4.36.0
-                        from ipex_llm.transformers.models.falcon import \
-                            falcon_attention_forward_4_36
-
-                        convert_forward(model,
-                                        module.FalconAttention,
-                                        falcon_attention_forward_4_36
-                                        )
-                    else:
-                        from ipex_llm.transformers.models.falcon import falcon_attention_forward
-                        convert_forward(model,
-                                        module.FalconAttention,
-                                        falcon_attention_forward
-                                        )
     elif model.config.model_type == "baichuan":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -1748,31 +1693,6 @@ def _optimize_post(model):
         convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
         convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
         convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
-    elif model.config.model_type == "cohere":
-        # for CohereForAI/c4ai-command-r-v01
-        invalidInputError(version.parse(trans_version) >= version.parse("4.40.0"),
-                          "Please upgrade transformers to 4.40.0 or higher version "
-                          "to run Mixtral models.")
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        if version.parse(trans_version) >= version.parse("4.41.0"):
-            from ipex_llm.transformers.models.cohere import cohere_model_forward_4_41
-            convert_forward(model,
-                            module.CohereModel,
-                            cohere_model_forward_4_41)
-        else:
-            from ipex_llm.transformers.models.cohere import cohere_model_forward
-            convert_forward(model,
-                            module.CohereModel,
-                            cohere_model_forward)
-
-        from ipex_llm.transformers.models.cohere import cohere_attention_forward
-        convert_forward(model,
-                        module.CohereAttention,
-                        cohere_attention_forward)
-        convert_forward(model,
-                        module.CohereMLP,
-                        mlp_silu_forward)
     elif model.config.model_type == "aquila":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -1784,31 +1704,6 @@ def _optimize_post(model):
         convert_forward(model,
                         module.AquilaRMSNorm,
                         rms_norm_forward)
-    elif model.config.model_type == "mixtral":
-        # For mistralai/Mixtral-8x7B-v0.1
-        invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"),
-                          "Please upgrade transformers to 4.36.0 or higher version "
-                          "to run Mixtral models.")
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.mixtral import mixtral_moeblock_forward, \
-            mixtral_attention_forward, mixtral_mlp_forward, mixtral_model_forward
-        convert_forward(model,
-                        module.MixtralAttention,
-                        mixtral_attention_forward)
-        convert_forward(model,
-                        module.MixtralRMSNorm,
-                        rms_norm_forward)
-        convert_forward(model,
-                        module.MixtralSparseMoeBlock,
-                        mixtral_moeblock_forward)
-        convert_forward(model,
-                        module.MixtralBLockSparseTop2MLP,
-                        mixtral_mlp_forward)
-        convert_forward(model,
-                        module.MixtralModel,
-                        mixtral_model_forward)
-
     elif model.config.model_type == "phi-msft" and \
             hasattr(model.config, "num_local_experts"):
         # For phixtral, limit the condition to avoid applying on phi-2 hosted by ModelScope
@@ -1823,29 +1718,19 @@ def _optimize_post(model):
                         module.MLP,
                         phixtral_mlp_forward)
     elif model.config.model_type == "mistral":
-
-
-            # For DiscoResearch/mixtral-7b-8expert
-            invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"),
-                              "Please upgrade transformers to 4.36.0 or higher version "
-                              "to run Mixtral models.")
-            modeling_module_name = model.__class__.__module__
-            module = importlib.import_module(modeling_module_name)
-            convert_forward(model, module.MistralRMSNorm, rms_norm_forward)
-        else:
-            modeling_module_name = model.__class__.__module__
-            module = importlib.import_module(modeling_module_name)
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)

-
-
-
-
+        from ipex_llm.transformers.models.mistral import mistral_model_forward
+        from ipex_llm.transformers.models.mistral import mistral_attention_forward
+        from ipex_llm.transformers.models.common import rms_norm_forward
+        from ipex_llm.transformers.models.common import mlp_silu_forward

-
-
-
-
-
+        convert_forward(model, module.MistralModel, mistral_model_forward)
+        convert_forward(model, module.MistralAttention, mistral_attention_forward)
+        convert_forward(model, module.MistralSdpaAttention, mistral_attention_forward)
+        convert_forward(model, module.MistralRMSNorm, rms_norm_forward)
+        convert_forward(model, module.MistralMLP, mlp_silu_forward)
     elif model.config.model_type == "gemma":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
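Most of the convert.py hunks above add or remove `convert_forward(...)` calls that swap a Hugging Face module's forward method for an ipex-llm optimized one. The sketch below only illustrates the mechanism; `replace_forward` is a hypothetical name, not ipex-llm's actual implementation:

import torch

def replace_forward(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    # Walk the module tree and rebind `forward` on every instance of the target class.
    for sub_module in model.modules():
        if isinstance(sub_module, target_cls):
            sub_module.forward = new_forward.__get__(sub_module, target_cls)

# e.g. replace_forward(model, module.MistralAttention, mistral_attention_forward)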
ipex_llm/transformers/lookup.py
CHANGED
@@ -33,7 +33,7 @@ from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to
     _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks,\
     _prepare_generate_args_4_45
 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.utils import
+from ipex_llm.transformers.utils import get_xpu_device_name

 logger = logging.getLogger("ipex_llm.lookup")

@@ -295,7 +295,7 @@ def lookup_generate(self,
     invalidInputError(input_ids.shape[0] == 1,
                       "Prompt lookup is currently not supported with batch inference.")

-    device_name =
+    device_name = get_xpu_device_name(input_ids.device)

     candidates_generator = PromptLookupCandidateGenerator(
         num_output_tokens=num_output_tokens,
ipex_llm/transformers/low_bit_linear.py
CHANGED
@@ -51,7 +51,7 @@ from torch import Tensor, device, dtype, nn
 from operator import mul
 from functools import reduce
 from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
-from ipex_llm.transformers.utils import get_autocast_dtype,
+from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_name, \
     get_ipex_version
 from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm

@@ -266,7 +266,7 @@ def reshape_lm_head_input(x):


 def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
-
+    device_name = get_xpu_device_name(x.device)
     batch_size = x.shape[0]
     hard_condition = (
         x.dtype in [torch.float, torch.half]
@@ -286,7 +286,7 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
         or (
             qtype in [SYM_INT8, FP4, FP6, Q4_K, Q6_K]
             and batch_size <= 48
-            and
+            and device_name in ["arc", "pvc", "mtl", "lnl", "arl"]
             and x.shape[1] % 256 == 0
             and output_len % 32 == 0
         )
@@ -295,8 +295,8 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
     if hard_condition:
         return (
             batch_size > 1
-            or (device in ["arc"
-            or (device in ["arc", "
+            or (device in ["arc"] and qtype in [SYM_INT8, FP4])
+            or (device in ["arc", "mtl"] and qtype in [FP8E4])
             or (device in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
             or (device in ["bmg"] and qtype in [SYM_INT4, FP8E5])
         )
@@ -603,7 +603,7 @@ class LowBitLinear(nn.Linear):
         # empty cache before and after lm_head at first token when input > 1024
         # on arc or IPEX_LLM_LOW_MEM is set to 1 at inference time.
         if self.device is None:
-            self.device =
+            self.device = get_xpu_device_name(self.weight.data.device)
             self.low_memory_mode = \
                 self.low_memory_mode and \
                 (self.device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1")
@@ -782,7 +782,7 @@ class FP16Linear(nn.Linear):
         if not self.use_esimd_kernel(x):
             if (
                 get_ipex_version() < "2.1.10+xpu"
-                or
+                or get_xpu_device_name(x.device) not in ["arc", "pvc"]
                 or self.disable_fp16_opt
             ):
                 if self.weight_type == 2:
@@ -848,7 +848,7 @@ class FP16Linear(nn.Linear):
             return result.to(x.dtype)

     def use_esimd_kernel(self, x):
-        gpu_type =
+        gpu_type = get_xpu_device_name(x.device)
         if self.disable_fp16_opt:
             return False
         # esimd kernel can only be used for Arc and Flex
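The low_bit_linear.py and lookup.py changes above (and the minicpmv.py and sd.py changes below) route device checks through `get_xpu_device_name` from `ipex_llm.transformers.utils`, which reduces a device to a short code such as "arc", "pvc", "mtl", "lnl" or "bmg". The sketch below only illustrates the idea, assuming a PyTorch build with XPU support; the `guess_xpu_device_name` helper and its name table are hypothetical and not the library's implementation:

import torch

# Hypothetical mapping from device marketing names to short codes.
_NAME_HINTS = {
    "arc": "arc",
    "data center gpu max": "pvc",
    "data center gpu flex": "flex",
}

def guess_xpu_device_name(device: torch.device) -> str:
    if device.type != "xpu":
        return device.type  # e.g. "cpu"
    full_name = torch.xpu.get_device_name(device).lower()
    for hint, short_code in _NAME_HINTS.items():
        if hint in full_name:
            return short_code
    return "others"  # a real helper would also cover integrated GPUs such as "mtl" or "lnl"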
ipex_llm/transformers/models/chatglm2.py
CHANGED
@@ -269,7 +269,7 @@ def chatglm2_attention_forward(
     # IPEX-LLM OPT: fuse rope
     inv_freq, position_ids = rotary_pos_emb
     rot_dim = inv_freq.size(-1) * 2
-    if should_use_fuse_rope(hidden_states,
+    if should_use_fuse_rope(hidden_states, position_ids, self.training):
         import xe_addons
         xe_addons.rotary_two_inplaced(inv_freq, position_ids,
                                       query_states[..., :rot_dim], key_states[..., :rot_dim])
@@ -321,197 +321,6 @@ def chatglm2_attention_forward(
     return output, past_key_value


-@torch.jit.script
-def apply_rotary_pos_emb_original(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
-    # x: [sq, b, np, hn]
-    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
-    rot_dim = rope_cache.shape[-2] * 2
-    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
-    # truncate to support variable sizes
-    rope_cache = rope_cache[:sq]
-    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
-    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
-    x_out2 = torch.stack(
-        [
-            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
-            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
-        ],
-        -1,
-    )
-    x_out2 = x_out2.flatten(3)
-    return torch.cat((x_out2, x_pass), dim=-1)
-
-
-def codegeex_model_forward(
-    self,
-    input_ids,
-    position_ids: Optional[torch.Tensor]=None,
-    attention_mask: Optional[torch.BoolTensor]=None,
-    full_attention_mask: Optional[torch.BoolTensor]=None,
-    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
-    inputs_embeds: Optional[torch.Tensor]=None,
-    use_cache: Optional[bool]=None,
-    output_hidden_states: Optional[bool]=None,
-    return_dict: Optional[bool]=None,
-):
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    if inputs_embeds is None:
-        batch_size, seq_length = input_ids.shape
-        inputs_embeds = self.embedding(input_ids)
-    else:
-        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
-        seq_length, batch_size, _ = inputs_embeds.shape
-        input_ids = torch.empty((batch_size, seq_length),
-                                dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-
-    if full_attention_mask is None:
-        if (attention_mask is not None and not attention_mask.all()) or (
-                past_key_values and seq_length != 1):
-            full_attention_mask = self.get_masks(input_ids,
-                                                 past_key_values,
-                                                 padding_mask=attention_mask)
-
-    # ipex-llm changes begin
-    # 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`
-    # 2. generate `causal_mask` and replace `full_attention_mask` with it
-    if position_ids is None:
-        if past_key_values is None:
-            position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
-        else:
-            if isinstance(past_key_values, DynamicCompressCache):
-                kv_length = past_key_values.get_seq_length()
-            else:
-                kv_length = past_key_values[0][0].size(0)
-            position_ids = torch.arange(kv_length, kv_length + seq_length,
-                                        dtype=torch.int64, device=inputs_embeds.device)
-        position_ids = position_ids.repeat(batch_size, 1)
-    use_fuse_rope = input_ids.device.type == "xpu" and not self.training
-
-    # Rotary positional embeddings
-    rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
-    if position_ids is not None:
-        rotary_pos_emb = rotary_pos_emb[position_ids]
-    else:
-        rotary_pos_emb = rotary_pos_emb[None, :seq_length]
-    if use_fuse_rope:
-        # Repeat cos sin here, call only once for each token.
-        # Chatglm2's rotary embedding is similar to gptj's, is rotate_every_two.
-        # If put this to attension forward, it will generate too many times.
-        cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1)
-        cos = cos.squeeze(-1)
-        sin = sin.squeeze(-1)
-        cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
-        sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
-        rotary_pos_emb = (cos, sin)
-    else:
-        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
-
-    # `full_attention_mask` is not None only when
-    # `past_key_values` is not None and `seq_length` > 1
-    if full_attention_mask is not None:
-        causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
-                                  dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-        mask_value = torch.finfo(inputs_embeds.dtype).min
-        causal_mask.masked_fill_(full_attention_mask, mask_value)
-    elif self.training or (inputs_embeds.device.type != "xpu" and past_key_values is None):
-        full_attention_mask = self.get_masks(input_ids,
-                                             past_key_values,
-                                             padding_mask=attention_mask)
-        causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
-                                  dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-        mask_value = torch.finfo(inputs_embeds.dtype).min
-        causal_mask.masked_fill_(full_attention_mask, mask_value)
-    else:
-        causal_mask = None
-
-    # Run encoder.
-    hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
-        inputs_embeds, causal_mask,
-        rotary_pos_emb=rotary_pos_emb,
-        kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
-    )
-    # ipex-llm changes end
-
-    if not return_dict:
-        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
-                     if v is not None)
-
-    return BaseModelOutputWithPast(
-        last_hidden_state=hidden_states,
-        past_key_values=presents,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attentions,
-    )
-
-
-def codegeex_attention_forward(
-    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
-):
-    q_len, bsz, _ = hidden_states.size()
-    n_head = self.num_attention_heads_per_partition
-    n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head
-    head_dim = self.hidden_size_per_attention_head
-
-    past_key_value = None if kv_cache is None else (kv_cache[0].permute(1, 2, 0, 3),
-                                                    kv_cache[1].permute(1, 2, 0, 3))
-    qkv = self.query_key_value(hidden_states)
-    qkv = qkv.view(q_len, bsz, n_head + 2 * n_kv_head, head_dim)
-    # [seq_len, bsz, n_head, head_dim] -> [bsz, n_head, seq_len, head_dim]
-    qkv = qkv.permute(1, 2, 0, 3)
-    query_layer, key_layer, value_layer = qkv.split([n_head,
-                                                     n_kv_head,
-                                                     n_kv_head], dim=1)
-    kv_seq_len = key_layer.shape[2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[2]
-
-    # apply relative positional encoding (rotary embedding)
-    if len(rotary_pos_emb) == 2 and isinstance(rotary_pos_emb, tuple):
-        cos, sin = rotary_pos_emb
-        rot_dim = cos.shape[-1]
-        query_layer = query_layer.transpose(1, 2)
-        key_layer = key_layer.transpose(1, 2)
-        query_layer_cur = query_layer[..., :rot_dim]
-        key_layer_cur = key_layer[..., :rot_dim]
-        # ipex_llm's apply_rotary_embedding can change the origin storage,
-        # so query_layer will get the result directly.
-        torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
-        torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
-        query_layer = query_layer.transpose(1, 2)
-        key_layer = key_layer.transpose(1, 2)
-    else:
-        query_layer = apply_rotary_pos_emb_original(query_layer, rotary_pos_emb)
-        key_layer = apply_rotary_pos_emb_original(key_layer, rotary_pos_emb)
-
-    key_layer, value_layer = update_past_key_value(
-        past_key_value, key_layer, value_layer,
-        kv_seq_len, False, hidden_states.device
-    )
-    # past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
-    past_key_value = (key_layer.permute(2, 0, 1, 3),
-                      value_layer.permute(2, 0, 1, 3)) if use_cache else None
-
-    # =================
-    # Output. [sq, b, h]
-    # =================
-    context_layer = scaled_dot_product_attention(
-        query_layer, key_layer, value_layer,
-        attention_mask, q_len == kv_seq_len
-    )
-
-    context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(q_len,
-                                                                        bsz,
-                                                                        n_head * head_dim)
-    output = self.dense(context_layer)
-
-    return output, past_key_value
-
 import torch.nn.functional as F


ipex_llm/transformers/models/minicpmv.py
CHANGED
@@ -53,10 +53,10 @@ def siglip_attention_forward(
     qkv = qkv.transpose(1, 2)
     query_states, key_states, value_states = qkv.chunk(3, dim=1)

-    from ipex_llm.transformers.utils import
+    from ipex_llm.transformers.utils import get_xpu_device_name
     if (
         self.head_dim == 72
-        and
+        and get_xpu_device_name(query_states.device) == "arc" and
         query_states.dtype in [torch.float, torch.half]
     ):
         n_heads, kv_length = query_states.size(1), key_states.size(2)
ipex_llm/transformers/models/sd.py
CHANGED
@@ -36,7 +36,7 @@ import math
 import torch
 from typing import Optional

-from ipex_llm.transformers.utils import
+from ipex_llm.transformers.utils import get_xpu_device_name
 from ipex_llm.transformers.models.common import padding_qkv_hd
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from diffusers.models.attention_processor import Attention
@@ -144,7 +144,7 @@ class AttnProcessor2_0:

 def upcast_vae(self):
     # workaround overflow and ipex's bugs
-    if
+    if get_xpu_device_name(self.vae.post_quant_conv.weight.device) == "arc":
         self.vae.to(torch.bfloat16)
     else:
         self.vae.decoder.up_blocks.to(torch.bfloat16)