ipex-llm 2.2.0b20250106__py3-none-win_amd64.whl → 2.2.0b20250106.post1__py3-none-win_amd64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (53)
  1. ipex_llm/libs/bloom-api.dll +0 -0
  2. ipex_llm/libs/bloom.dll +0 -0
  3. ipex_llm/libs/gptneox-api.dll +0 -0
  4. ipex_llm/libs/gptneox.dll +0 -0
  5. ipex_llm/libs/libbloom_avx.dll +0 -0
  6. ipex_llm/libs/libbloom_vnni.dll +0 -0
  7. ipex_llm/libs/libgptneox_avx.dll +0 -0
  8. ipex_llm/libs/libgptneox_vnni.dll +0 -0
  9. ipex_llm/libs/libllama_avx.dll +0 -0
  10. ipex_llm/libs/libllama_vnni.dll +0 -0
  11. ipex_llm/libs/libstarcoder_avx.dll +0 -0
  12. ipex_llm/libs/libstarcoder_vnni.dll +0 -0
  13. ipex_llm/libs/llama-api.dll +0 -0
  14. ipex_llm/libs/llama.dll +0 -0
  15. ipex_llm/libs/main-bloom.exe +0 -0
  16. ipex_llm/libs/main-gptneox.exe +0 -0
  17. ipex_llm/libs/main-llama.exe +0 -0
  18. ipex_llm/libs/main-starcoder.exe +0 -0
  19. ipex_llm/libs/pipeline.dll +0 -0
  20. ipex_llm/libs/quantize-bloom.exe +0 -0
  21. ipex_llm/libs/quantize-bloom_vnni.exe +0 -0
  22. ipex_llm/libs/quantize-gptneox.exe +0 -0
  23. ipex_llm/libs/quantize-gptneox_vnni.exe +0 -0
  24. ipex_llm/libs/quantize-llama.exe +0 -0
  25. ipex_llm/libs/quantize-llama_vnni.exe +0 -0
  26. ipex_llm/libs/quantize-starcoder.exe +0 -0
  27. ipex_llm/libs/quantize-starcoder_vnni.exe +0 -0
  28. ipex_llm/libs/starcoder-api.dll +0 -0
  29. ipex_llm/libs/starcoder.dll +0 -0
  30. ipex_llm/transformers/convert.py +19 -158
  31. ipex_llm/transformers/loader.py +1 -1
  32. ipex_llm/transformers/lookup.py +2 -2
  33. ipex_llm/transformers/low_bit_linear.py +15 -29
  34. ipex_llm/transformers/model.py +0 -7
  35. ipex_llm/transformers/models/chatglm2.py +1 -192
  36. ipex_llm/transformers/models/minicpmv.py +2 -2
  37. ipex_llm/transformers/models/sd.py +2 -2
  38. ipex_llm/transformers/models/utils.py +16 -104
  39. ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +5 -8
  40. ipex_llm/transformers/speculative.py +2 -14
  41. ipex_llm/transformers/utils.py +7 -20
  42. {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250106.post1.dist-info}/METADATA +40 -19
  43. {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250106.post1.dist-info}/RECORD +49 -53
  44. ipex_llm/transformers/models/cohere.py +0 -589
  45. ipex_llm/transformers/models/falcon.py +0 -829
  46. ipex_llm/transformers/models/gptj.py +0 -441
  47. ipex_llm/transformers/models/mixtral.py +0 -576
  48. {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250106.post1.data}/scripts/ipex-llm-init.bat +0 -0
  49. {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250106.post1.data}/scripts/llm-chat.ps1 +0 -0
  50. {ipex_llm-2.2.0b20250106.data → ipex_llm-2.2.0b20250106.post1.data}/scripts/llm-cli.ps1 +0 -0
  51. {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250106.post1.dist-info}/WHEEL +0 -0
  52. {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250106.post1.dist-info}/entry_points.txt +0 -0
  53. {ipex_llm-2.2.0b20250106.dist-info → ipex_llm-2.2.0b20250106.post1.dist-info}/top_level.txt +0 -0
ipex_llm/transformers/models/chatglm2.py
@@ -269,7 +269,7 @@ def chatglm2_attention_forward(
  # IPEX-LLM OPT: fuse rope
  inv_freq, position_ids = rotary_pos_emb
  rot_dim = inv_freq.size(-1) * 2
- if should_use_fuse_rope(hidden_states, rotary_pos_emb[1], self.training):
+ if should_use_fuse_rope(hidden_states, position_ids, self.training):
  import xe_addons
  xe_addons.rotary_two_inplaced(inv_freq, position_ids,
  query_states[..., :rot_dim], key_states[..., :rot_dim])
@@ -321,197 +321,6 @@ def chatglm2_attention_forward(
  return output, past_key_value


- @torch.jit.script
- def apply_rotary_pos_emb_original(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
- # x: [sq, b, np, hn]
- sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
- rot_dim = rope_cache.shape[-2] * 2
- x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
- # truncate to support variable sizes
- rope_cache = rope_cache[:sq]
- xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
- rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
- x_out2 = torch.stack(
- [
- xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
- xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
- ],
- -1,
- )
- x_out2 = x_out2.flatten(3)
- return torch.cat((x_out2, x_pass), dim=-1)
-
-
- def codegeex_model_forward(
- self,
- input_ids,
- position_ids: Optional[torch.Tensor]=None,
- attention_mask: Optional[torch.BoolTensor]=None,
- full_attention_mask: Optional[torch.BoolTensor]=None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
- inputs_embeds: Optional[torch.Tensor]=None,
- use_cache: Optional[bool]=None,
- output_hidden_states: Optional[bool]=None,
- return_dict: Optional[bool]=None,
- ):
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if inputs_embeds is None:
- batch_size, seq_length = input_ids.shape
- inputs_embeds = self.embedding(input_ids)
- else:
- inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
- seq_length, batch_size, _ = inputs_embeds.shape
- input_ids = torch.empty((batch_size, seq_length),
- dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-
- if full_attention_mask is None:
- if (attention_mask is not None and not attention_mask.all()) or (
- past_key_values and seq_length != 1):
- full_attention_mask = self.get_masks(input_ids,
- past_key_values,
- padding_mask=attention_mask)
-
- # ipex-llm changes begin
- # 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`
- # 2. generate `causal_mask` and replace `full_attention_mask` with it
- if position_ids is None:
- if past_key_values is None:
- position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
- else:
- if isinstance(past_key_values, DynamicCompressCache):
- kv_length = past_key_values.get_seq_length()
- else:
- kv_length = past_key_values[0][0].size(0)
- position_ids = torch.arange(kv_length, kv_length + seq_length,
- dtype=torch.int64, device=inputs_embeds.device)
- position_ids = position_ids.repeat(batch_size, 1)
- use_fuse_rope = input_ids.device.type == "xpu" and not self.training
-
- # Rotary positional embeddings
- rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
- if position_ids is not None:
- rotary_pos_emb = rotary_pos_emb[position_ids]
- else:
- rotary_pos_emb = rotary_pos_emb[None, :seq_length]
- if use_fuse_rope:
- # Repeat cos sin here, call only once for each token.
- # Chatglm2's rotary embedding is similar to gptj's, is rotate_every_two.
- # If put this to attension forward, it will generate too many times.
- cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1)
- cos = cos.squeeze(-1)
- sin = sin.squeeze(-1)
- cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
- sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
- rotary_pos_emb = (cos, sin)
- else:
- rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
-
- # `full_attention_mask` is not None only when
- # `past_key_values` is not None and `seq_length` > 1
- if full_attention_mask is not None:
- causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
- dtype=inputs_embeds.dtype, device=inputs_embeds.device)
- mask_value = torch.finfo(inputs_embeds.dtype).min
- causal_mask.masked_fill_(full_attention_mask, mask_value)
- elif self.training or (inputs_embeds.device.type != "xpu" and past_key_values is None):
- full_attention_mask = self.get_masks(input_ids,
- past_key_values,
- padding_mask=attention_mask)
- causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
- dtype=inputs_embeds.dtype, device=inputs_embeds.device)
- mask_value = torch.finfo(inputs_embeds.dtype).min
- causal_mask.masked_fill_(full_attention_mask, mask_value)
- else:
- causal_mask = None
-
- # Run encoder.
- hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
- inputs_embeds, causal_mask,
- rotary_pos_emb=rotary_pos_emb,
- kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
- )
- # ipex-llm changes end
-
- if not return_dict:
- return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
- if v is not None)
-
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=presents,
- hidden_states=all_hidden_states,
- attentions=all_self_attentions,
- )
-
-
- def codegeex_attention_forward(
- self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
- ):
- q_len, bsz, _ = hidden_states.size()
- n_head = self.num_attention_heads_per_partition
- n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head
- head_dim = self.hidden_size_per_attention_head
-
- past_key_value = None if kv_cache is None else (kv_cache[0].permute(1, 2, 0, 3),
- kv_cache[1].permute(1, 2, 0, 3))
- qkv = self.query_key_value(hidden_states)
- qkv = qkv.view(q_len, bsz, n_head + 2 * n_kv_head, head_dim)
- # [seq_len, bsz, n_head, head_dim] -> [bsz, n_head, seq_len, head_dim]
- qkv = qkv.permute(1, 2, 0, 3)
- query_layer, key_layer, value_layer = qkv.split([n_head,
- n_kv_head,
- n_kv_head], dim=1)
- kv_seq_len = key_layer.shape[2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[2]
-
- # apply relative positional encoding (rotary embedding)
- if len(rotary_pos_emb) == 2 and isinstance(rotary_pos_emb, tuple):
- cos, sin = rotary_pos_emb
- rot_dim = cos.shape[-1]
- query_layer = query_layer.transpose(1, 2)
- key_layer = key_layer.transpose(1, 2)
- query_layer_cur = query_layer[..., :rot_dim]
- key_layer_cur = key_layer[..., :rot_dim]
- # ipex_llm's apply_rotary_embedding can change the origin storage,
- # so query_layer will get the result directly.
- torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
- torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
- query_layer = query_layer.transpose(1, 2)
- key_layer = key_layer.transpose(1, 2)
- else:
- query_layer = apply_rotary_pos_emb_original(query_layer, rotary_pos_emb)
- key_layer = apply_rotary_pos_emb_original(key_layer, rotary_pos_emb)
-
- key_layer, value_layer = update_past_key_value(
- past_key_value, key_layer, value_layer,
- kv_seq_len, False, hidden_states.device
- )
- # past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
- past_key_value = (key_layer.permute(2, 0, 1, 3),
- value_layer.permute(2, 0, 1, 3)) if use_cache else None
-
- # =================
- # Output. [sq, b, h]
- # =================
- context_layer = scaled_dot_product_attention(
- query_layer, key_layer, value_layer,
- attention_mask, q_len == kv_seq_len
- )
-
- context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(q_len,
- bsz,
- n_head * head_dim)
- output = self.dense(context_layer)
-
- return output, past_key_value
-
  import torch.nn.functional as F

ipex_llm/transformers/models/minicpmv.py
@@ -53,10 +53,10 @@ def siglip_attention_forward(
  qkv = qkv.transpose(1, 2)
  query_states, key_states, value_states = qkv.chunk(3, dim=1)

- from ipex_llm.transformers.utils import get_xpu_device_type
+ from ipex_llm.transformers.utils import get_xpu_device_name
  if (
  self.head_dim == 72
- and get_xpu_device_type(query_states) in ["arc", "flex"] and
+ and get_xpu_device_name(query_states.device) == "arc" and
  query_states.dtype in [torch.float, torch.half]
  ):
  n_heads, kv_length = query_states.size(1), key_states.size(2)
ipex_llm/transformers/models/sd.py
@@ -36,7 +36,7 @@ import math
  import torch
  from typing import Optional

- from ipex_llm.transformers.utils import get_xpu_device_type
+ from ipex_llm.transformers.utils import get_xpu_device_name
  from ipex_llm.transformers.models.common import padding_qkv_hd
  from ipex_llm.transformers.models.common import scaled_dot_product_attention
  from diffusers.models.attention_processor import Attention
@@ -144,7 +144,7 @@ class AttnProcessor2_0:

  def upcast_vae(self):
  # workaround overflow and ipex's bugs
- if get_xpu_device_type(self.vae.post_quant_conv.weight) in ["arc", "flex", "pvc"]:
+ if get_xpu_device_name(self.vae.post_quant_conv.weight.device) == "arc":
  self.vae.to(torch.bfloat16)
  else:
  self.vae.decoder.up_blocks.to(torch.bfloat16)
ipex_llm/transformers/models/utils.py
@@ -19,7 +19,7 @@ import torch
  import warnings
  from ipex_llm.utils.common import invalidInputError
  from ipex_llm.ggml.quantize import ggml_tensor_qtype
- from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
+ from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_name
  from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4,\
  FP6, ASYM_INT4

@@ -85,16 +85,14 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: in
  return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
  elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
  return os.environ["IPEX_LLM_LOW_MEM"] == "1"
+ elif linear.qtype in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
+ return False
  else:
- return x.device.type == 'xpu' and kv_cache_device_check(x, kv_group) \
- and hasattr(linear, "qtype") and \
- linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
-
-
- def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool:
- return (get_xpu_device_type(x) in ["mtl", "lnl"] and kv_group <= 1) or \
- ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and
- 1 < x.size(0) and x.size(0) <= 8)
+ device_name = get_xpu_device_name(x.device)
+ return (
+ device_name in ["mtl", "lnl", "arl"] and kv_group == 1
+ or device_name in ["arc", "bmg"] and x.size(0) > 1
+ )


  def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
@@ -170,7 +168,7 @@ def should_use_fuse_rope(hidden_states, position_ids, training):

  def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
  if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
- "mixtral", "qwen2", "yuan", "stablelm", "qwen2_moe"]:
+ "qwen2", "yuan", "stablelm", "qwen2_moe"]:
  # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
  cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
  sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
@@ -185,7 +183,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
  q_embed = (q * cos) + (rotate_half(q) * sin)
  k_embed = (k * cos) + (rotate_half(k) * sin)
  return q_embed, k_embed
- elif model_family in ["gptj", "chatglm"]:
+ elif model_family in ["chatglm"]:
  q_embed = (q * cos) + (rotate_every_two(q) * sin)
  k_embed = (k * cos) + (rotate_every_two(k) * sin)
  return q_embed, k_embed
@@ -194,19 +192,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
  f"{model_family} is not supported.")


- def apply_ipex_rotate_every_two(q, k, cos, sin):
- # ipex's apply_rotary_embedding_two_qk can change the origin storage,
- # so q/k will get the result directly.
- from ipex_llm.transformers.utils import get_ipex_version
- if get_ipex_version() >= "2.1.10+xpu":
- torch.ops.torch_ipex.apply_rotary_embedding_two_qk(
- q, k, sin, cos, q, k
- )
- else:
- torch.ops.torch_ipex.apply_rotary_embedding(q, sin, cos, q)
- torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)
-
-
  def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
  # to determinate if is enough kv cache room in transformers==4.36
  # seq_len for current seq len
@@ -226,57 +211,6 @@ def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
  (past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3)


- def use_flash_attention(query, key, attention_mask=None):
- # here we support query's shape is always [batch_size, head_num, q_len, head_dim],
- # key's shape is always [batch_size, head_num, k_len, head_dim]
- invalidInputError(query.dim() == 4,
- "Here query input of use_flash_attention should be [batch_size, "
- "head_num, q_len, head_dim]")
- invalidInputError(key.dim() == 4,
- "Here key input of use_flash_attention should be [batch_size, "
- "head_num, k_len, head_dim]")
- bsz, _, q_len, _ = query.size()
- k_len = key.size()[2]
- # check whether ipex flash attention can be used
- if q_len != k_len:
- # now only use flash attention for first token
- # as it seems have no performance benifit for rest token now
- return False
- if query.device.type != "xpu":
- # ipex flash attention only support for xpu
- return False
- ipex_version = get_ipex_version()
- if ipex_version <= "2.0.110+xpu":
- # ipex flash attention is supported from ipex 2.1
- return False
- if not torch.xpu.has_xetla():
- # ipex flash attention is only supported for xetla
- # may update this later
- return False
- elif get_xpu_device_type(query) != "pvc":
- return False
- if query.dtype not in [torch.float32, torch.float16]:
- # only use flash attention for fp32/fp16 input
- return False
- if bsz > 1:
- # as flash attention doesn't support attn_mask in ipex 2.1,
- # so it will cause output error for padded batch input
- if attention_mask is None:
- return True
- else:
- # TODO: below logic may change for different model
- # attention mask shape : [bsz, 1, q_len, k_len]
- if attention_mask[0].squeeze()[0, 0].item() != 0:
- # first batch contains padding
- # otherwise we suppose it should be a upper triangular matrix
- # at the same time, the diagonal is also 0
- return False
- elif not attention_mask.equal(attention_mask[0].repeat(bsz, 1, 1, 1)):
- # check whether mask of every batch is the same
- return False
- return True
-
-
  def use_sdp(q_len, kv_len, head_dim, query_states):
  return (
  query_states.device.type == "xpu"
@@ -315,38 +249,16 @@ def mlp_fusion_check(x, qtype, training):
  if training or x.requires_grad:
  return False
  if qtype == FP6:
- device = get_xpu_device_type(x)
- if device in ["mtl", "lnl"]:
+ device = get_xpu_device_name(x.device)
+ if device in ["mtl", "lnl", "arl"]:
  return False
  return True


- def use_decoding_fast_path(proj,
- use_fuse_rope,
- enough_kv_room,
- bs,
- qtype_check=decoding_fast_path_qtype_check):
- if proj is None:
- return False
- device = get_xpu_device_type(proj.weight)
- if not qtype_check(proj):
- return False
- if not use_fuse_rope:
- return False
- if not enough_kv_room:
- return False
- if bs != 1:
- return False
-
- if device in ["uhd"]:
- return False
- return True
-
-
  def use_xmx(x: torch.Tensor, qtype: int):
- device = get_xpu_device_type(x)
+ device = get_xpu_device_name(x.device)
  return (
- device in ["arc", "flex", "pvc"]
+ device in ["arc", "pvc"]
  and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5]
  and (
  (device == "pvc" and 1 < x.size(0) <= 16)
@@ -370,7 +282,7 @@ def fp16_fusion_check(proj, x, training):
  return False
  if x.requires_grad:
  return False
- device_type = get_xpu_device_type(x)
+ device_type = get_xpu_device_name(x.device)
  if device_type != "pvc":
  return False
  return True
@@ -439,7 +351,7 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
  else:
  if use_compress_kv is None:
  return (
- get_xpu_device_type(x) in ["mtl", "lnl"]
+ get_xpu_device_name(x.device) in ["mtl", "lnl", "arl"]
  and prompt_len >= 1800
  and prompt_len <= 4500
  )
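As an aside, the environment overrides that `use_quantize_kv_cache` checks before any dtype or device logic (shown as context at the top of the `@@ -85,16 +85,14 @@` hunk above) can be exercised as in the following illustrative sketch; the values are examples only, not defaults of the package.

# Illustrative sketch only: steering the KV-cache quantization decision through the
# environment variables read by use_quantize_kv_cache before its device heuristics.
import os

os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] = "1"  # "1" forces quantized KV cache; any other value disables it
# If the variable above is unset, IPEX_LLM_LOW_MEM is consulted next:
# os.environ["IPEX_LLM_LOW_MEM"] = "1"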
ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
@@ -473,10 +473,6 @@ def convert_llm_for_deploy(model: torch.nn.Module,
  "n_splits_linear": n_splits_linear,
  "n_splits_down_proj": n_splits_down_proj,
  "lm_head_low_bit": lm_head_low_bit}
- model.config.update(update_dict)
- model.config.save_pretrained(save_directory)
- if model.can_generate():
- model.generation_config.save_pretrained(save_directory)

  from .qwen import convert_qwen_layer, convert_fused_qwen_layer
  from .qwen import convert_lm_head_and_embedding
@@ -537,8 +533,6 @@ def convert_llm_for_deploy(model: torch.nn.Module,
  "n_splits_linear": n_splits_linear,
  "n_splits_down_proj": n_splits_down_proj,
  "lm_head_low_bit": lm_head_low_bit}
- model.config.update(update_dict)
- model.config.save_pretrained(save_directory)

  from .llama import convert_llama_layer, convert_fused_llama_layer
  from .llama import convert_lm_head_and_embedding
@@ -577,8 +571,6 @@ def convert_llm_for_deploy(model: torch.nn.Module,
  "n_splits_linear": n_splits_linear,
  "n_splits_down_proj": n_splits_down_proj,
  "lm_head_low_bit": lm_head_low_bit}
- model.config.update(update_dict)
- model.config.save_pretrained(save_directory)

  from .minicpm import convert_minicpm_layer, convert_fused_minicpm_layer
  from .minicpm import convert_lm_head_and_embedding
@@ -595,3 +587,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
  save_directory, weight_dir,
  convert_model=True,
  max_prompt_len=max_prompt_len)
+
+ model.config.update(update_dict)
+ model.config.save_pretrained(save_directory)
+ if model.can_generate():
+ model.generation_config.save_pretrained(save_directory)
ipex_llm/transformers/speculative.py
@@ -432,8 +432,7 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
  from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
  extend_kv_cache
  enough_kv_room = True
- if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral",
- "gptj", "opt"]:
+ if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral", "opt"]:
  return past_key_values, False
  cache_k = past_key_values[0][0]
  if model_type == "chatglm":
@@ -527,7 +526,7 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
  v[:-(new_cache_size), :, :, :])
  for k, v in past_key_values
  ]
- elif self.config.model_type in ["baichuan", "gptj"]:
+ elif self.config.model_type in ["baichuan"]:
  past_key_values = [
  (k[:, :, :-(new_cache_size), :],
  v[:, :, :-(new_cache_size), :])
@@ -796,13 +795,6 @@ def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_
  device=verify_input_ids.device)
  position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
  forward_args["position_ids"] = position_ids
- elif self.config.model_type == "gptj":
- past_length = past_key_values[0][0].size(2)
- input_len = verify_input_ids.shape[1]
- position_ids = torch.arange(past_length, input_len + past_length,
- dtype=torch.long, device=verify_input_ids.device)
- position_ids = position_ids.unsqueeze(0).view(-1, input_len)
- forward_args["position_ids"] = position_ids

  return self(**forward_args)

@@ -971,10 +963,6 @@ def speculative_generate(self,
  past_key_value_len = past_key_values[0][0].shape[0]
  position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
  forward_args["position_ids"] = position_ids
- elif self.config.model_type == "gptj":
- past_length = draft_past_key_values[0][0].size(2)
- position_ids = torch.Tensor([[past_length]]).long().to(self.device)
- forward_args["position_ids"] = position_ids

  if _enable_ipex:
  if any(keyword in self.config.model_type
ipex_llm/transformers/utils.py
@@ -168,27 +168,14 @@ def get_ipex_version():
  return _ipex_version


- def get_xpu_device_type(x):
- if x.device.type != "xpu":
- return x.device.type
- name = torch.xpu.get_device_name(x.device.index)
- if name.startswith("Intel(R) Arc(TM) A"):
- return "arc"
- elif name.startswith("Intel(R) Graphics [0xe20b]"):
- return "bmg"
- elif name.startswith("Intel(R) Arc(TM)"):
- if 'V' in name:
- return "lnl"
- else:
- return "mtl"
- elif name.startswith("Intel(R) Data Center GPU Flex"):
- return "flex"
- elif name.startswith("Intel(R) Data Center GPU Max"):
- return "pvc"
- elif name.startswith("Intel(R) UHD"):
- return "uhd"
+ def get_xpu_device_name(device: torch.device):
+ if device.type != "xpu":
+ return device.type
  else:
- return "others"
+ # possiable device name:
+ # ["arc", "pvc", "mtl", "lnl", "bmg", "arl", "legacy", "unknown"]
+ import xe_linear
+ return xe_linear.get_xpu_device_name(device)


  def load_imatrix_data(imatrix_file):
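For downstream code still written against the removed tensor-based helper, a minimal compatibility wrapper is sketched below; the wrapper itself is hypothetical and not part of this diff, only `get_xpu_device_name(device)` exists after the change. Note also that the set of returned names shifts: "flex", "uhd" and "others" no longer appear, while names such as "arl", "legacy" and "unknown" can, so callers matching on device names may need updating.

# Hypothetical compatibility shim, not part of ipex-llm: adapts old call sites that
# passed a tensor to the new device-based helper shown in the hunk above.
import torch
from ipex_llm.transformers.utils import get_xpu_device_name


def get_xpu_device_type_compat(x: torch.Tensor) -> str:
    # The new helper expects a torch.device rather than a tensor.
    return get_xpu_device_name(x.device)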