ipex-llm 2.2.0b20250106__py3-none-win_amd64.whl → 2.2.0b20250106.post1__py3-none-win_amd64.whl
This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- ipex_llm/libs/bloom-api.dll +0 -0
- ipex_llm/libs/bloom.dll +0 -0
- ipex_llm/libs/gptneox-api.dll +0 -0
- ipex_llm/libs/gptneox.dll +0 -0
- ipex_llm/libs/libbloom_avx.dll +0 -0
- ipex_llm/libs/libbloom_vnni.dll +0 -0
- ipex_llm/libs/libgptneox_avx.dll +0 -0
- ipex_llm/libs/libgptneox_vnni.dll +0 -0
- ipex_llm/libs/libllama_avx.dll +0 -0
- ipex_llm/libs/libllama_vnni.dll +0 -0
- ipex_llm/libs/libstarcoder_avx.dll +0 -0
- ipex_llm/libs/libstarcoder_vnni.dll +0 -0
- ipex_llm/libs/llama-api.dll +0 -0
- ipex_llm/libs/llama.dll +0 -0
- ipex_llm/libs/main-bloom.exe +0 -0
- ipex_llm/libs/main-gptneox.exe +0 -0
- ipex_llm/libs/main-llama.exe +0 -0
- ipex_llm/libs/main-starcoder.exe +0 -0
- ipex_llm/libs/pipeline.dll +0 -0
- ipex_llm/libs/quantize-bloom.exe +0 -0
- ipex_llm/libs/quantize-bloom_vnni.exe +0 -0
- ipex_llm/libs/quantize-gptneox.exe +0 -0
- ipex_llm/libs/quantize-gptneox_vnni.exe +0 -0
- ipex_llm/libs/quantize-llama.exe +0 -0
- ipex_llm/libs/quantize-llama_vnni.exe +0 -0
- ipex_llm/libs/quantize-starcoder.exe +0 -0
- ipex_llm/libs/quantize-starcoder_vnni.exe +0 -0
- ipex_llm/libs/starcoder-api.dll +0 -0
- ipex_llm/libs/starcoder.dll +0 -0
- ipex_llm/transformers/convert.py +19 -158
- ipex_llm/transformers/loader.py +1 -1
- ipex_llm/transformers/lookup.py +2 -2
- ipex_llm/transformers/low_bit_linear.py +15 -29
- ipex_llm/transformers/model.py +0 -7
- ipex_llm/transformers/models/chatglm2.py +1 -192
- ipex_llm/transformers/models/minicpmv.py +2 -2
- ipex_llm/transformers/models/sd.py +2 -2
- ipex_llm/transformers/models/utils.py +16 -104
- ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +5 -8
- ipex_llm/transformers/speculative.py +2 -14
- ipex_llm/transformers/utils.py +7 -20
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250106.post1.dist-info}/METADATA +40 -19
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250106.post1.dist-info}/RECORD +49 -53
- ipex_llm/transformers/models/cohere.py +0 -589
- ipex_llm/transformers/models/falcon.py +0 -829
- ipex_llm/transformers/models/gptj.py +0 -441
- ipex_llm/transformers/models/mixtral.py +0 -576
- {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250106.post1.data}/scripts/ipex-llm-init.bat +0 -0
- {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250106.post1.data}/scripts/llm-chat.ps1 +0 -0
- {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250106.post1.data}/scripts/llm-cli.ps1 +0 -0
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250106.post1.dist-info}/WHEEL +0 -0
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250106.post1.dist-info}/entry_points.txt +0 -0
- {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250106.post1.dist-info}/top_level.txt +0 -0
ipex_llm/transformers/models/chatglm2.py
CHANGED
@@ -269,7 +269,7 @@ def chatglm2_attention_forward(
     # IPEX-LLM OPT: fuse rope
     inv_freq, position_ids = rotary_pos_emb
     rot_dim = inv_freq.size(-1) * 2
-    if should_use_fuse_rope(hidden_states,
+    if should_use_fuse_rope(hidden_states, position_ids, self.training):
         import xe_addons
         xe_addons.rotary_two_inplaced(inv_freq, position_ids,
                                       query_states[..., :rot_dim], key_states[..., :rot_dim])
@@ -321,197 +321,6 @@ def chatglm2_attention_forward(
     return output, past_key_value
 
 
-@torch.jit.script
-def apply_rotary_pos_emb_original(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
-    # x: [sq, b, np, hn]
-    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
-    rot_dim = rope_cache.shape[-2] * 2
-    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
-    # truncate to support variable sizes
-    rope_cache = rope_cache[:sq]
-    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
-    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
-    x_out2 = torch.stack(
-        [
-            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
-            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
-        ],
-        -1,
-    )
-    x_out2 = x_out2.flatten(3)
-    return torch.cat((x_out2, x_pass), dim=-1)
-
-
-def codegeex_model_forward(
-        self,
-        input_ids,
-        position_ids: Optional[torch.Tensor]=None,
-        attention_mask: Optional[torch.BoolTensor]=None,
-        full_attention_mask: Optional[torch.BoolTensor]=None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
-        inputs_embeds: Optional[torch.Tensor]=None,
-        use_cache: Optional[bool]=None,
-        output_hidden_states: Optional[bool]=None,
-        return_dict: Optional[bool]=None,
-):
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    if inputs_embeds is None:
-        batch_size, seq_length = input_ids.shape
-        inputs_embeds = self.embedding(input_ids)
-    else:
-        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
-        seq_length, batch_size, _ = inputs_embeds.shape
-        input_ids = torch.empty((batch_size, seq_length),
-                                dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-
-    if full_attention_mask is None:
-        if (attention_mask is not None and not attention_mask.all()) or (
-                past_key_values and seq_length != 1):
-            full_attention_mask = self.get_masks(input_ids,
-                                                 past_key_values,
-                                                 padding_mask=attention_mask)
-
-    # ipex-llm changes begin
-    # 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`
-    # 2. generate `causal_mask` and replace `full_attention_mask` with it
-    if position_ids is None:
-        if past_key_values is None:
-            position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
-        else:
-            if isinstance(past_key_values, DynamicCompressCache):
-                kv_length = past_key_values.get_seq_length()
-            else:
-                kv_length = past_key_values[0][0].size(0)
-            position_ids = torch.arange(kv_length, kv_length + seq_length,
-                                        dtype=torch.int64, device=inputs_embeds.device)
-        position_ids = position_ids.repeat(batch_size, 1)
-    use_fuse_rope = input_ids.device.type == "xpu" and not self.training
-
-    # Rotary positional embeddings
-    rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
-    if position_ids is not None:
-        rotary_pos_emb = rotary_pos_emb[position_ids]
-    else:
-        rotary_pos_emb = rotary_pos_emb[None, :seq_length]
-    if use_fuse_rope:
-        # Repeat cos sin here, call only once for each token.
-        # Chatglm2's rotary embedding is similar to gptj's, is rotate_every_two.
-        # If put this to attension forward, it will generate too many times.
-        cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1)
-        cos = cos.squeeze(-1)
-        sin = sin.squeeze(-1)
-        cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
-        sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
-        rotary_pos_emb = (cos, sin)
-    else:
-        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
-
-    # `full_attention_mask` is not None only when
-    # `past_key_values` is not None and `seq_length` > 1
-    if full_attention_mask is not None:
-        causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
-                                  dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-        mask_value = torch.finfo(inputs_embeds.dtype).min
-        causal_mask.masked_fill_(full_attention_mask, mask_value)
-    elif self.training or (inputs_embeds.device.type != "xpu" and past_key_values is None):
-        full_attention_mask = self.get_masks(input_ids,
-                                             past_key_values,
-                                             padding_mask=attention_mask)
-        causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
-                                  dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-        mask_value = torch.finfo(inputs_embeds.dtype).min
-        causal_mask.masked_fill_(full_attention_mask, mask_value)
-    else:
-        causal_mask = None
-
-    # Run encoder.
-    hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
-        inputs_embeds, causal_mask,
-        rotary_pos_emb=rotary_pos_emb,
-        kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
-    )
-    # ipex-llm changes end
-
-    if not return_dict:
-        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
-                     if v is not None)
-
-    return BaseModelOutputWithPast(
-        last_hidden_state=hidden_states,
-        past_key_values=presents,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attentions,
-    )
-
-
-def codegeex_attention_forward(
-    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
-):
-    q_len, bsz, _ = hidden_states.size()
-    n_head = self.num_attention_heads_per_partition
-    n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head
-    head_dim = self.hidden_size_per_attention_head
-
-    past_key_value = None if kv_cache is None else (kv_cache[0].permute(1, 2, 0, 3),
-                                                    kv_cache[1].permute(1, 2, 0, 3))
-    qkv = self.query_key_value(hidden_states)
-    qkv = qkv.view(q_len, bsz, n_head + 2 * n_kv_head, head_dim)
-    # [seq_len, bsz, n_head, head_dim] -> [bsz, n_head, seq_len, head_dim]
-    qkv = qkv.permute(1, 2, 0, 3)
-    query_layer, key_layer, value_layer = qkv.split([n_head,
-                                                     n_kv_head,
-                                                     n_kv_head], dim=1)
-    kv_seq_len = key_layer.shape[2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[2]
-
-    # apply relative positional encoding (rotary embedding)
-    if len(rotary_pos_emb) == 2 and isinstance(rotary_pos_emb, tuple):
-        cos, sin = rotary_pos_emb
-        rot_dim = cos.shape[-1]
-        query_layer = query_layer.transpose(1, 2)
-        key_layer = key_layer.transpose(1, 2)
-        query_layer_cur = query_layer[..., :rot_dim]
-        key_layer_cur = key_layer[..., :rot_dim]
-        # ipex_llm's apply_rotary_embedding can change the origin storage,
-        # so query_layer will get the result directly.
-        torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
-        torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
-        query_layer = query_layer.transpose(1, 2)
-        key_layer = key_layer.transpose(1, 2)
-    else:
-        query_layer = apply_rotary_pos_emb_original(query_layer, rotary_pos_emb)
-        key_layer = apply_rotary_pos_emb_original(key_layer, rotary_pos_emb)
-
-    key_layer, value_layer = update_past_key_value(
-        past_key_value, key_layer, value_layer,
-        kv_seq_len, False, hidden_states.device
-    )
-    # past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
-    past_key_value = (key_layer.permute(2, 0, 1, 3),
-                      value_layer.permute(2, 0, 1, 3)) if use_cache else None
-
-    # =================
-    # Output. [sq, b, h]
-    # =================
-    context_layer = scaled_dot_product_attention(
-        query_layer, key_layer, value_layer,
-        attention_mask, q_len == kv_seq_len
-    )
-
-    context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(q_len,
-                                                                        bsz,
-                                                                        n_head * head_dim)
-    output = self.dense(context_layer)
-
-    return output, past_key_value
-
 import torch.nn.functional as F
 
 
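The removed `apply_rotary_pos_emb_original` was the eager fallback for ChatGLM/CodeGeeX's interleaved rotary embedding; the retained path relies on the fused `xe_addons.rotary_two_inplaced` kernel shown in the first hunk. As a reference for what that kernel computes, here is a minimal eager-mode sketch of the interleaved ("rotate every two") rotary update; the helper names and shapes are illustrative assumptions, not ipex-llm's API.

```python
# Illustrative only: an eager-mode sketch of the interleaved ("rotate every
# two") rotary update that the fused xe_addons.rotary_two_inplaced call applies
# in place on XPU. Helper names and shapes are assumptions, not ipex-llm's API.
import torch


def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]
    return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)


def apply_interleaved_rope(q, k, cos, sin, rot_dim):
    # cos/sin are already repeat-interleaved to rot_dim channels, as in the
    # removed codegeex_model_forward; only the first rot_dim channels rotate.
    q_rot, k_rot = q[..., :rot_dim], k[..., :rot_dim]
    q[..., :rot_dim] = q_rot * cos + rotate_every_two(q_rot) * sin
    k[..., :rot_dim] = k_rot * cos + rotate_every_two(k_rot) * sin
    return q, k
```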
ipex_llm/transformers/models/minicpmv.py
CHANGED
@@ -53,10 +53,10 @@ def siglip_attention_forward(
     qkv = qkv.transpose(1, 2)
     query_states, key_states, value_states = qkv.chunk(3, dim=1)
 
-    from ipex_llm.transformers.utils import
+    from ipex_llm.transformers.utils import get_xpu_device_name
    if (
        self.head_dim == 72
-        and
+        and get_xpu_device_name(query_states.device) == "arc" and
        query_states.dtype in [torch.float, torch.half]
    ):
        n_heads, kv_length = query_states.size(1), key_states.size(2)
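The gate above enables the custom SDP path only for `head_dim == 72` on "arc" devices; otherwise the code pads the head dimension (via `padding_qkv_hd`) so a supported kernel can be used. The sketch below shows the general padding idea with stock PyTorch SDPA only; it is not ipex-llm's kernel path, and the 72→80 sizes are just an example.

```python
# Illustrative sketch (not ipex-llm's kernel path): zero-pad an unsupported
# head size such as 72 up to 80 before calling stock PyTorch SDPA, then slice
# the padding back off. Requires PyTorch >= 2.1 for the explicit `scale`.
import torch
import torch.nn.functional as F


def sdpa_with_padded_head_dim(q, k, v, target_head_dim=80):
    # q, k, v: [batch, heads, seq, head_dim]
    orig_dim = q.size(-1)
    pad = target_head_dim - orig_dim
    if pad > 0:
        q, k, v = (F.pad(t, (0, pad)) for t in (q, k, v))
    # keep the softmax scale of the original head size; zero-padded channels
    # contribute nothing to the dot products
    out = F.scaled_dot_product_attention(q, k, v, scale=orig_dim ** -0.5)
    return out[..., :orig_dim]
```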
ipex_llm/transformers/models/sd.py
CHANGED
@@ -36,7 +36,7 @@ import math
 import torch
 from typing import Optional
 
-from ipex_llm.transformers.utils import
+from ipex_llm.transformers.utils import get_xpu_device_name
 from ipex_llm.transformers.models.common import padding_qkv_hd
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from diffusers.models.attention_processor import Attention
@@ -144,7 +144,7 @@ class AttnProcessor2_0:
 
 def upcast_vae(self):
     # workaround overflow and ipex's bugs
-    if
+    if get_xpu_device_name(self.vae.post_quant_conv.weight.device) == "arc":
         self.vae.to(torch.bfloat16)
     else:
         self.vae.decoder.up_blocks.to(torch.bfloat16)
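`upcast_vae` now branches on the device name: on "arc" the whole VAE is moved to bfloat16, elsewhere only the decoder's up-blocks are upcast. The toy snippet below just illustrates that `Module.to(dtype)` can be applied to a single submodule, which is all the selective branch relies on; it does not use the diffusers VAE.

```python
# Toy illustration (not the diffusers VAE): Module.to(dtype) can be applied to
# a single submodule, which is all the selective branch above relies on.
import torch
import torch.nn as nn

vae = nn.ModuleDict({
    "encoder": nn.Linear(8, 8),
    "decoder": nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)),
}).to(torch.float16)

full_upcast = False  # stands in for the device-name check ("arc") in the diff
if full_upcast:
    vae.to(torch.bfloat16)             # cast every parameter
else:
    vae["decoder"].to(torch.bfloat16)  # cast only the overflow-prone part

print(vae["encoder"].weight.dtype, vae["decoder"][0].weight.dtype)
# -> torch.float16 torch.bfloat16 when full_upcast is False
```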
ipex_llm/transformers/models/utils.py
CHANGED
@@ -19,7 +19,7 @@ import torch
 import warnings
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
-from ipex_llm.transformers.utils import get_ipex_version,
+from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_name
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4,\
     FP6, ASYM_INT4
 
@@ -85,16 +85,14 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: in
         return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
     elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
         return os.environ["IPEX_LLM_LOW_MEM"] == "1"
+    elif linear.qtype in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
+        return False
     else:
-
-
-
-
-
-def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool:
-    return (get_xpu_device_type(x) in ["mtl", "lnl"] and kv_group <= 1) or \
-        ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and
-            1 < x.size(0) and x.size(0) <= 8)
+        device_name = get_xpu_device_name(x.device)
+        return (
+            device_name in ["mtl", "lnl", "arl"] and kv_group == 1
+            or device_name in ["arc", "bmg"] and x.size(0) > 1
+        )
 
 
 def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
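After this hunk, `use_quantize_kv_cache` returns `False` outright for fp16/bf16 linears and folds the old `kv_cache_device_check` helper into an inline device-name test. A condensed sketch of the resulting decision order is shown below; it imports the `get_xpu_device_name` helper this release adds to `ipex_llm/transformers/utils.py`, and the string-based qtype check is a simplification of the `ggml_tensor_qtype` comparison in the real code.

```python
# Condensed sketch of the decision order after this change; not the exact
# ipex-llm source. The qtype strings simplify the ggml_tensor_qtype comparison.
import os
import torch
from ipex_llm.transformers.utils import get_xpu_device_name


def should_quantize_kv_cache(linear_qtype: str, x: torch.Tensor, kv_group: int = 1) -> bool:
    # 1. explicit environment overrides win
    if os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE") is not None:
        return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
    if os.environ.get("IPEX_LLM_LOW_MEM") is not None:
        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
    # 2. fp16/bf16 weights never use the quantized cache
    if linear_qtype in ("fp16", "bf16"):
        return False
    # 3. otherwise a device/batch heuristic decides
    device_name = get_xpu_device_name(x.device)  # e.g. "mtl", "arc", "bmg"
    return (
        device_name in ["mtl", "lnl", "arl"] and kv_group == 1
        or device_name in ["arc", "bmg"] and x.size(0) > 1
    )
```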
@@ -170,7 +168,7 @@ def should_use_fuse_rope(hidden_states, position_ids, training):
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
     if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
-                        "
+                        "qwen2", "yuan", "stablelm", "qwen2_moe"]:
         # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
         cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
         sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -185,7 +183,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
         q_embed = (q * cos) + (rotate_half(q) * sin)
         k_embed = (k * cos) + (rotate_half(k) * sin)
         return q_embed, k_embed
-    elif model_family in ["
+    elif model_family in ["chatglm"]:
         q_embed = (q * cos) + (rotate_every_two(q) * sin)
         k_embed = (k * cos) + (rotate_every_two(k) * sin)
         return q_embed, k_embed
@@ -194,19 +192,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
                           f"{model_family} is not supported.")
 
 
-def apply_ipex_rotate_every_two(q, k, cos, sin):
-    # ipex's apply_rotary_embedding_two_qk can change the origin storage,
-    # so q/k will get the result directly.
-    from ipex_llm.transformers.utils import get_ipex_version
-    if get_ipex_version() >= "2.1.10+xpu":
-        torch.ops.torch_ipex.apply_rotary_embedding_two_qk(
-            q, k, sin, cos, q, k
-        )
-    else:
-        torch.ops.torch_ipex.apply_rotary_embedding(q, sin, cos, q)
-        torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)
-
-
 def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
     # to determinate if is enough kv cache room in transformers==4.36
     # seq_len for current seq len
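For contrast with the interleaved form used by the "chatglm" branch, the llama-family branch above applies the half-rotation form of RoPE. A standard reference implementation, following Hugging Face conventions rather than copied from ipex-llm, looks like this:

```python
# Reference sketch of the half-rotation RoPE used by the llama-family branch
# (q * cos + rotate_half(q) * sin), following Hugging Face conventions; it is
# not copied from ipex-llm.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_half_rotation_rope(q, k, cos, sin, position_ids):
    # cos/sin: [seq_len, dim] after the squeezes above; gather current positions
    cos = cos[position_ids].unsqueeze(1)  # [bsz, 1, q_len, dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```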
@@ -226,57 +211,6 @@ def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
         (past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3)
 
 
-def use_flash_attention(query, key, attention_mask=None):
-    # here we support query's shape is always [batch_size, head_num, q_len, head_dim],
-    # key's shape is always [batch_size, head_num, k_len, head_dim]
-    invalidInputError(query.dim() == 4,
-                      "Here query input of use_flash_attention should be [batch_size, "
-                      "head_num, q_len, head_dim]")
-    invalidInputError(key.dim() == 4,
-                      "Here key input of use_flash_attention should be [batch_size, "
-                      "head_num, k_len, head_dim]")
-    bsz, _, q_len, _ = query.size()
-    k_len = key.size()[2]
-    # check whether ipex flash attention can be used
-    if q_len != k_len:
-        # now only use flash attention for first token
-        # as it seems have no performance benifit for rest token now
-        return False
-    if query.device.type != "xpu":
-        # ipex flash attention only support for xpu
-        return False
-    ipex_version = get_ipex_version()
-    if ipex_version <= "2.0.110+xpu":
-        # ipex flash attention is supported from ipex 2.1
-        return False
-    if not torch.xpu.has_xetla():
-        # ipex flash attention is only supported for xetla
-        # may update this later
-        return False
-    elif get_xpu_device_type(query) != "pvc":
-        return False
-    if query.dtype not in [torch.float32, torch.float16]:
-        # only use flash attention for fp32/fp16 input
-        return False
-    if bsz > 1:
-        # as flash attention doesn't support attn_mask in ipex 2.1,
-        # so it will cause output error for padded batch input
-        if attention_mask is None:
-            return True
-        else:
-            # TODO: below logic may change for different model
-            # attention mask shape : [bsz, 1, q_len, k_len]
-            if attention_mask[0].squeeze()[0, 0].item() != 0:
-                # first batch contains padding
-                # otherwise we suppose it should be a upper triangular matrix
-                # at the same time, the diagonal is also 0
-                return False
-            elif not attention_mask.equal(attention_mask[0].repeat(bsz, 1, 1, 1)):
-                # check whether mask of every batch is the same
-                return False
-    return True
-
-
 def use_sdp(q_len, kv_len, head_dim, query_states):
     return (
         query_states.device.type == "xpu"
@@ -315,38 +249,16 @@ def mlp_fusion_check(x, qtype, training):
     if training or x.requires_grad:
         return False
     if qtype == FP6:
-        device =
-        if device in ["mtl", "lnl"]:
+        device = get_xpu_device_name(x.device)
+        if device in ["mtl", "lnl", "arl"]:
             return False
     return True
 
 
-def use_decoding_fast_path(proj,
-                           use_fuse_rope,
-                           enough_kv_room,
-                           bs,
-                           qtype_check=decoding_fast_path_qtype_check):
-    if proj is None:
-        return False
-    device = get_xpu_device_type(proj.weight)
-    if not qtype_check(proj):
-        return False
-    if not use_fuse_rope:
-        return False
-    if not enough_kv_room:
-        return False
-    if bs != 1:
-        return False
-
-    if device in ["uhd"]:
-        return False
-    return True
-
-
 def use_xmx(x: torch.Tensor, qtype: int):
-    device =
+    device = get_xpu_device_name(x.device)
     return (
-        device in ["arc", "
+        device in ["arc", "pvc"]
         and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5]
         and (
             (device == "pvc" and 1 < x.size(0) <= 16)
@@ -370,7 +282,7 @@ def fp16_fusion_check(proj, x, training):
         return False
     if x.requires_grad:
         return False
-    device_type =
+    device_type = get_xpu_device_name(x.device)
     if device_type != "pvc":
         return False
     return True
@@ -439,7 +351,7 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
     else:
         if use_compress_kv is None:
             return (
-
+                get_xpu_device_name(x.device) in ["mtl", "lnl", "arl"]
                 and prompt_len >= 1800
                 and prompt_len <= 4500
            )
ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
CHANGED
@@ -473,10 +473,6 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "n_splits_linear": n_splits_linear,
                        "n_splits_down_proj": n_splits_down_proj,
                        "lm_head_low_bit": lm_head_low_bit}
-        model.config.update(update_dict)
-        model.config.save_pretrained(save_directory)
-        if model.can_generate():
-            model.generation_config.save_pretrained(save_directory)
 
         from .qwen import convert_qwen_layer, convert_fused_qwen_layer
         from .qwen import convert_lm_head_and_embedding
@@ -537,8 +533,6 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "n_splits_linear": n_splits_linear,
                        "n_splits_down_proj": n_splits_down_proj,
                        "lm_head_low_bit": lm_head_low_bit}
-        model.config.update(update_dict)
-        model.config.save_pretrained(save_directory)
 
         from .llama import convert_llama_layer, convert_fused_llama_layer
         from .llama import convert_lm_head_and_embedding
@@ -577,8 +571,6 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "n_splits_linear": n_splits_linear,
                        "n_splits_down_proj": n_splits_down_proj,
                        "lm_head_low_bit": lm_head_low_bit}
-        model.config.update(update_dict)
-        model.config.save_pretrained(save_directory)
 
         from .minicpm import convert_minicpm_layer, convert_fused_minicpm_layer
         from .minicpm import convert_lm_head_and_embedding
@@ -595,3 +587,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                                        save_directory, weight_dir,
                                        convert_model=True,
                                        max_prompt_len=max_prompt_len)
+
+    model.config.update(update_dict)
+    model.config.save_pretrained(save_directory)
+    if model.can_generate():
+        model.generation_config.save_pretrained(save_directory)
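The net effect of these four hunks is that `convert_llm_for_deploy` no longer saves the updated config inside each per-model branch; a single consolidated block at the end updates the config, saves it, and also saves the generation config. A hedged sketch of the resulting order, using standard transformers APIs, with `convert_layers` standing in for the per-model conversion logic above:

```python
# Hedged sketch of the consolidated tail of convert_llm_for_deploy after this
# change: run the per-model conversion first, then persist the updated config
# exactly once. `convert_layers` and `update_dict` stand in for the per-model
# logic above; the config calls are standard transformers APIs.
def finish_conversion(model, update_dict, save_directory, convert_layers):
    convert_layers(model, save_directory)          # per-model NPU conversion
    model.config.update(update_dict)               # record the NPU export settings
    model.config.save_pretrained(save_directory)   # write config.json once
    if model.can_generate():
        model.generation_config.save_pretrained(save_directory)
```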
ipex_llm/transformers/speculative.py
CHANGED
@@ -432,8 +432,7 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
     from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
         extend_kv_cache
     enough_kv_room = True
-    if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral",
-                          "gptj", "opt"]:
+    if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral", "opt"]:
         return past_key_values, False
     cache_k = past_key_values[0][0]
     if model_type == "chatglm":
@@ -527,7 +526,7 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
              v[:-(new_cache_size), :, :, :])
            for k, v in past_key_values
        ]
-    elif self.config.model_type in ["baichuan"
+    elif self.config.model_type in ["baichuan"]:
        past_key_values = [
            (k[:, :, :-(new_cache_size), :],
             v[:, :, :-(new_cache_size), :])
@@ -796,13 +795,6 @@ def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_
                                    device=verify_input_ids.device)
        position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
        forward_args["position_ids"] = position_ids
-    elif self.config.model_type == "gptj":
-        past_length = past_key_values[0][0].size(2)
-        input_len = verify_input_ids.shape[1]
-        position_ids = torch.arange(past_length, input_len + past_length,
-                                    dtype=torch.long, device=verify_input_ids.device)
-        position_ids = position_ids.unsqueeze(0).view(-1, input_len)
-        forward_args["position_ids"] = position_ids
 
     return self(**forward_args)
 
@@ -971,10 +963,6 @@ def speculative_generate(self,
                past_key_value_len = past_key_values[0][0].shape[0]
                position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
                forward_args["position_ids"] = position_ids
-            elif self.config.model_type == "gptj":
-                past_length = draft_past_key_values[0][0].size(2)
-                position_ids = torch.Tensor([[past_length]]).long().to(self.device)
-                forward_args["position_ids"] = position_ids
 
            if _enable_ipex:
                if any(keyword in self.config.model_type
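With gptj dropped from the speculative-decoding paths, the remaining `_crop_past_key_values` branches handle caches in either [seq, batch, heads, head_dim] order or the usual [batch, heads, seq, head_dim] layout. A minimal sketch of the latter case (the "baichuan" branch above), trimming rejected draft tokens from every layer:

```python
# Minimal sketch for the usual [batch, heads, seq_len, head_dim] cache layout
# (the "baichuan" branch above): drop the last n_extra rejected draft tokens
# from every layer's key/value cache.
from typing import List, Tuple
import torch


def crop_kv_cache(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
                  n_extra: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    if n_extra <= 0:
        return past_key_values
    return [
        (k[:, :, :-n_extra, :], v[:, :, :-n_extra, :])
        for k, v in past_key_values
    ]
```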
ipex_llm/transformers/utils.py
CHANGED
@@ -168,27 +168,14 @@ def get_ipex_version():
     return _ipex_version
 
 
-def
-    if
-        return
-    name = torch.xpu.get_device_name(x.device.index)
-    if name.startswith("Intel(R) Arc(TM) A"):
-        return "arc"
-    elif name.startswith("Intel(R) Graphics [0xe20b]"):
-        return "bmg"
-    elif name.startswith("Intel(R) Arc(TM)"):
-        if 'V' in name:
-            return "lnl"
-        else:
-            return "mtl"
-    elif name.startswith("Intel(R) Data Center GPU Flex"):
-        return "flex"
-    elif name.startswith("Intel(R) Data Center GPU Max"):
-        return "pvc"
-    elif name.startswith("Intel(R) UHD"):
-        return "uhd"
+def get_xpu_device_name(device: torch.device):
+    if device.type != "xpu":
+        return device.type
     else:
-
+        # possiable device name:
+        # ["arc", "pvc", "mtl", "lnl", "bmg", "arl", "legacy", "unknown"]
+        import xe_linear
+        return xe_linear.get_xpu_device_name(device)
 
 
 def load_imatrix_data(imatrix_file):
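The Python-side marketing-name mapping of the removed `get_xpu_device_type` is replaced by `get_xpu_device_name`, which defers to the native `xe_linear` extension on XPU and simply returns `device.type` elsewhere. For reference, a hedged pure-Python fallback reproducing the old prefix mapping shown in the removed lines:

```python
# Reference only: a pure-Python fallback reproducing the marketing-name
# prefixes the removed get_xpu_device_type matched. The new get_xpu_device_name
# defers to the native xe_linear extension instead, which can also report
# names such as "arl", "legacy" and "unknown".
import torch

_PREFIX_TO_NAME = [
    ("Intel(R) Arc(TM) A", "arc"),
    ("Intel(R) Graphics [0xe20b]", "bmg"),
    ("Intel(R) Arc(TM)", None),  # old code: "lnl" if 'V' in name else "mtl"
    ("Intel(R) Data Center GPU Flex", "flex"),
    ("Intel(R) Data Center GPU Max", "pvc"),
    ("Intel(R) UHD", "uhd"),
]


def xpu_device_name_fallback(device: torch.device) -> str:
    if device.type != "xpu":
        return device.type
    name = torch.xpu.get_device_name(device.index)
    for prefix, short in _PREFIX_TO_NAME:
        if name.startswith(prefix):
            if short is None:
                return "lnl" if "V" in name else "mtl"
            return short
    return "unknown"
```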