ipex-llm 2.2.0b20250101__py3-none-manylinux2010_x86_64.whl → 2.2.0b20250103__py3-none-manylinux2010_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ipex_llm/optimize.py CHANGED
@@ -254,7 +254,9 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
  torch_dtype=torch_dtype,
  optimize_model=optimize_llm,
  modules_to_not_convert=modules_to_not_convert,
- cpu_embedding=cpu_embedding)
+ cpu_embedding=cpu_embedding,
+ disable_optimize_pre=kwargs.pop("disable_optimize_pre",
+ False))
  # add save_low_bit to pretrained model dynamically
  import types
  model._bigdl_config = dict()
@@ -1081,7 +1081,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
  torch_dtype="auto",
  imatrix_data=None,
  embedding_qtype=None,
- mixed_precision=False):
+ mixed_precision=False,
+ disable_optimize_pre=False):
  if qtype in ggml_tensor_qtype.values():
  index = list(ggml_tensor_qtype.values()).index(qtype)
  logger.info(f"Converting the current model to "
@@ -1104,7 +1105,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
  model = _optimize_ipex(model, qtype)
  return model

- if optimize_model:
+ if optimize_model and not disable_optimize_pre:
  model = _optimize_pre(model, qtype)

  act_order = False
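The new disable_optimize_pre flag is popped from **kwargs in optimize_model (first hunk) and forwarded into ggml_convert_low_bit, where it skips the _optimize_pre graph rewrites while leaving low-bit conversion itself untouched. A minimal usage sketch, assuming the public ipex_llm.optimize_model entry point and default settings otherwise:

    from ipex_llm import optimize_model

    # model: any PyTorch LLM already loaded in fp16/fp32
    # hypothetical call: quantize to sym_int4 but skip the _optimize_pre rewrites
    model = optimize_model(model, low_bit="sym_int4", disable_optimize_pre=True)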
@@ -1983,16 +1984,9 @@ def _optimize_post(model):
  modeling_module_name = model.__class__.__module__
  module = importlib.import_module(modeling_module_name)
  from ipex_llm.transformers.models.yuan import yuan_attention_forward
- # from ipex_llm.transformers.models.yuan import yuan_mlp_forward
- convert_forward(model,
- module.YuanAttention,
- yuan_attention_forward
- )
- # disable able mlp_forward for quantize_kv on mtl.
- # convert_forward(model,
- # module.YuanMLP,
- # yuan_mlp_forward
- # )
+ convert_forward(model, module.YuanAttention, yuan_attention_forward)
+ # from ipex_llm.transformers.models.common import mlp_silu_forward
+ # convert_forward(model, module.YuanMLP, mlp_silu_forward)
  elif model.config.model_type == 'bert' and (
  not model.config.is_decoder and
  model.config.position_embedding_type == "absolute"
@@ -764,6 +764,7 @@ class FP16Linear(nn.Linear):
  # weigh_type = 3 means weight has been transposed by esimd method
  self.weight_type = 1
  self.optimize_lm_head = optimize_lm_head
+ self.disable_fp16_opt = False

  def forward(self, x: torch.Tensor):
  # only work for GPU
@@ -779,8 +780,11 @@ class FP16Linear(nn.Linear):
  self.weight.data = self.weight.data.to(x.dtype)

  if not self.use_esimd_kernel(x):
- if get_ipex_version() < "2.1.10+xpu" \
- or get_xpu_device_type(x) not in ["arc", "flex", "pvc"]:
+ if (
+ get_ipex_version() < "2.1.10+xpu"
+ or get_xpu_device_type(x) not in ["arc", "flex", "pvc"]
+ or self.disable_fp16_opt
+ ):
  if self.weight_type == 2:
  self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
  requires_grad=False)
@@ -845,6 +849,8 @@ class FP16Linear(nn.Linear):

  def use_esimd_kernel(self, x):
  gpu_type = get_xpu_device_type(x)
+ if self.disable_fp16_opt:
+ return False
  # esimd kernel can only be used for Arc and Flex
  if gpu_type not in ["arc", "flex"]:
  return False
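The new disable_fp16_opt flag lets a caller force FP16Linear back onto the plain PyTorch linear path, bypassing both the ESIMD kernel and the transposed-weight fast path. This diff only adds the attribute (default False); where it gets flipped is not shown here, so the following is a hypothetical sketch (the FP16Linear import path is an assumption):

    import torch
    from ipex_llm.transformers.low_bit_linear import FP16Linear  # assumed module path

    def disable_fp16_fast_paths(model: torch.nn.Module) -> None:
        # hypothetical helper: flip the new flag on every FP16Linear so that
        # forward() and use_esimd_kernel() take the non-optimized branch
        for module in model.modules():
            if isinstance(module, FP16Linear):
                module.disable_fp16_opt = True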
@@ -445,6 +445,7 @@ class _BaseAutoModelClass:
  mixed_precision = kwargs.pop("mixed_precision", False)
  if embedding_qtype is not None:
  embedding_qtype = ggml_tensor_qtype[embedding_qtype]
+ disable_optimize_pre = kwargs.pop("disable_optimize_pre", False)
  _args = copy.deepcopy(args)
  _kwargs = copy.deepcopy(kwargs)
  awq_config = None
@@ -513,7 +514,8 @@ class _BaseAutoModelClass:
  torch_dtype=kwargs.get("torch_dtype", 'auto'),
  imatrix_data=imatrix_data,
  embedding_qtype=embedding_qtype,
- mixed_precision=mixed_precision)
+ mixed_precision=mixed_precision,
+ disable_optimize_pre=disable_optimize_pre)

  if disk_embedding:
  from ipex_llm.transformers.embedding import DiskEmbedding
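The same flag is also accepted by the Hugging Face-style loader, since from_pretrained pops it from kwargs and forwards it to ggml_convert_low_bit as shown above. A minimal sketch, assuming the usual ipex_llm.transformers auto class and an illustrative model id:

    from ipex_llm.transformers import AutoModelForCausalLM

    # load in 4-bit but skip the pre-conversion model rewrites
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-7B-Instruct",   # illustrative model id
        load_in_4bit=True,
        trust_remote_code=True,
        disable_optimize_pre=True,
    )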
@@ -29,7 +29,7 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp
  should_use_compresskv
  from ipex_llm.transformers.models.utils import update_past_key_value
  from ipex_llm.transformers.models.utils import should_use_fuse_rope
- from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
+ from ipex_llm.transformers.models.utils import use_sdp
  from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
  from ipex_llm.transformers.models.utils import mlp_fusion_check
  from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
@@ -301,16 +301,10 @@ def baichuan_attention_forward_7b(

  # IPEX-LLM OPT: sdp
  attn_weights = None
- if use_flash_attention(query_states, key_states, attention_mask):
- attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
- key_states.to(dtype=torch.float16),
- value_states.to(dtype=torch.float16),
- is_causal=True).to(hidden_states.dtype)
- else:
- attn_output = scaled_dot_product_attention(
- query_states, key_states, value_states,
- attention_mask, q_len == kv_seq_len
- )
+ attn_output = scaled_dot_product_attention(
+ query_states, key_states, value_states,
+ attention_mask, q_len == kv_seq_len
+ )

  attn_output = attn_output.transpose(1, 2).contiguous()
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -23,7 +23,7 @@ import torch.utils.checkpoint
  import torch.nn.functional as F
  from typing import Optional, Tuple
  from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
- from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
+ from ipex_llm.transformers.models.utils import use_sdp


  def rotate_half(x):
@@ -41,7 +41,7 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):


  def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
- if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu':
+ if query.device.type == 'cpu':
  context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
  key,
  value,
@@ -33,7 +33,6 @@ from ipex_llm.transformers.models.utils import update_past_key_value, should_use
  from ipex_llm.transformers.models.utils import use_quantize_kv_cache
  from ipex_llm.transformers.models.utils import rotate_half, SILU
  from ipex_llm.transformers.models.utils import mlp_fusion_check
- from ipex_llm.transformers.models.utils import use_flash_attention
  from ipex_llm.utils.common import invalidInputError
  from transformers.modeling_outputs import BaseModelOutputWithPast

@@ -116,33 +115,28 @@ def qwen_attention_forward(
  past_key_value = (key_states.transpose(1, 2),
  value_states.transpose(1, 2)) if use_cache else None

- # IPEX-LLM OPT: sdp
+ # IPEX-LLM OPT: sdpa
  attn_weights = None
- if use_flash_attention(query_states, key_states, attention_mask):
- attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
- key_states.to(dtype=torch.float16),
- value_states.to(dtype=torch.float16),
- is_causal=True).to(hidden_states.dtype)
+
+ if q_len > 1 and q_len != kv_seq_len:
+ causal_mask = torch.tril(
+ torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query_states.device)
+ ).view(1, 1, kv_seq_len, kv_seq_len)
+ causal_mask = causal_mask[
+ :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len
+ ]
+ attention_mask = torch.zeros(causal_mask.shape, dtype=query_states.dtype,
+ device=query_states.device)
+ attention_mask.masked_fill_(causal_mask.logical_not(),
+ torch.finfo(attention_mask.dtype).min)
+ attention_mask = attention_mask.expand([bsz, -1, -1, -1])
  else:
- if q_len > 1 and q_len != kv_seq_len:
- causal_mask = torch.tril(
- torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query_states.device)
- ).view(1, 1, kv_seq_len, kv_seq_len)
- causal_mask = causal_mask[
- :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len
- ]
- attention_mask = torch.zeros(causal_mask.shape, dtype=query_states.dtype,
- device=query_states.device)
- attention_mask.masked_fill_(causal_mask.logical_not(),
- torch.finfo(attention_mask.dtype).min)
- attention_mask = attention_mask.expand([bsz, -1, -1, -1])
- else:
- attention_mask = None
+ attention_mask = None

- attn_output = scaled_dot_product_attention(
- query_states, key_states, value_states,
- attention_mask, q_len == kv_seq_len
- )
+ attn_output = scaled_dot_product_attention(
+ query_states, key_states, value_states,
+ attention_mask, q_len == kv_seq_len
+ )

  attn_output = attn_output.transpose(1, 2).contiguous()
  attn_output = attn_output.view(bsz, q_len, self.hidden_size)
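For reference, the branch added above builds a standard additive causal mask restricted to the last q_len query rows, then hands it to the common scaled_dot_product_attention helper. A standalone illustration of the same construction (toy shapes, independent of the Qwen code):

    import torch

    bsz, q_len, kv_seq_len = 1, 2, 5
    causal = torch.tril(torch.ones(kv_seq_len, kv_seq_len, dtype=torch.bool))
    causal = causal.view(1, 1, kv_seq_len, kv_seq_len)[:, :, kv_seq_len - q_len:, :]
    mask = torch.zeros(causal.shape)
    mask.masked_fill_(causal.logical_not(), torch.finfo(mask.dtype).min)
    mask = mask.expand(bsz, -1, -1, -1)
    # mask has shape (1, 1, 2, 5); key positions after each query's own
    # position hold the dtype's minimum value, everything else is 0.0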
@@ -219,31 +213,25 @@ def qwen_attention_forward_registered(
  past_key_value = (key_states.transpose(1, 2),
  value_states.transpose(1, 2)) if use_cache else None

- # IPEX-LLM OPT: sdp
+ # IPEX-LLM OPT: sdpa
  attn_weights = None

- if use_flash_attention(query_states, key_states, attention_mask):
- attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
- key_states.to(dtype=torch.float16),
- value_states.to(dtype=torch.float16),
- is_causal=True).to(hidden_states.dtype)
+ if q_len > 1 and q_len != kv_seq_len:
+ causal_mask = registered_causal_mask[
+ :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len
+ ]
+ attention_mask = torch.zeros(causal_mask.shape, dtype=query_states.dtype,
+ device=query_states.device)
+ attention_mask.masked_fill_(causal_mask.logical_not(),
+ torch.finfo(attention_mask.dtype).min)
+ attention_mask = attention_mask.expand([bsz, -1, -1, -1])
  else:
- if q_len > 1 and q_len != kv_seq_len:
- causal_mask = registered_causal_mask[
- :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len
- ]
- attention_mask = torch.zeros(causal_mask.shape, dtype=query_states.dtype,
- device=query_states.device)
- attention_mask.masked_fill_(causal_mask.logical_not(),
- torch.finfo(attention_mask.dtype).min)
- attention_mask = attention_mask.expand([bsz, -1, -1, -1])
- else:
- attention_mask = None
+ attention_mask = None

- attn_output = scaled_dot_product_attention(
- query_states, key_states, value_states,
- attention_mask, q_len == kv_seq_len
- )
+ attn_output = scaled_dot_product_attention(
+ query_states, key_states, value_states,
+ attention_mask, q_len == kv_seq_len
+ )

  attn_output = attn_output.transpose(1, 2).contiguous()
  attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@@ -38,12 +38,10 @@
  #

  import os
- import math
  from typing import Optional, Tuple, Union, List

  import torch
  from torch.nn import CrossEntropyLoss
- from torch.nn.functional import scaled_dot_product_attention as sdpa

  from ipex_llm.transformers.models.common import merge_qkv_base
  from ipex_llm.transformers.models.common import scaled_dot_product_attention
@@ -51,13 +49,12 @@ from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
  from ipex_llm.transformers.models.utils import should_use_fuse_rope
  from ipex_llm.transformers.models.utils import use_quantize_kv_cache, \
  should_use_compresskv, is_enough_kv_cache_room_4_36
- from ipex_llm.transformers.models.utils import use_flash_attention
  from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, \
  DynamicCompressCache, DynamicCompressFp8Cache
  from ipex_llm.utils.common import invalidInputError

  from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
- from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
+ from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  from transformers.cache_utils import Cache
  from transformers import logging
@@ -580,21 +577,10 @@ def qwen2_attention_forward(
  self.layer_idx, None)

  attn_weights = None
- if use_flash_attention(query_states, key_states, attention_mask):
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, :kv_seq_len]
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
- attn_output = sdpa(query_states.to(device, dtype=torch.float16),
- key_states.to(device, dtype=torch.float16),
- value_states.to(device, dtype=torch.float16),
- is_causal=True).to(hidden_states.dtype)
- else:
- attn_output = scaled_dot_product_attention(
- query_states, key_states, value_states,
- attention_mask, q_len == kv_seq_len
- )
+ attn_output = scaled_dot_product_attention(
+ query_states, key_states, value_states,
+ attention_mask, q_len == kv_seq_len
+ )

  attn_output = attn_output.transpose(1, 2).contiguous()
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -20,17 +20,15 @@
  # https://huggingface.co/IEITYuan/Yuan2-2B-hf/blob/7ab7b3c18eb8e5232ce2a3f720d4e6f4b53a2806/README.md#%E5%A3%B0%E6%98%8E%E4%B8%8E%E5%8D%8F%E8%AE%AEterms-and-conditions
  #

- import math
  from typing import Optional, Tuple

  import torch

  from ipex_llm.utils.common import invalidInputError
  from ipex_llm.transformers.models.common import scaled_dot_product_attention
- from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
- mlp_fusion_check, fp16_fusion_check
+ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
  from ipex_llm.transformers.models.utils import use_quantize_kv_cache
- from ipex_llm.transformers.models.utils import SILU, update_past_key_value
+ from ipex_llm.transformers.models.utils import update_past_key_value
  from ipex_llm.transformers.models.utils import should_use_fuse_rope


@@ -98,52 +96,6 @@ def yuan_localized_filtering_forward(
  return lf_output


- def yuan_mlp_forward(
- self,
- x: torch.Tensor,
- residual=None
- ) -> torch.Tensor:
- x_2d = x.view(-1, x.shape[-1])
- bsz, hidden_size = x_2d.shape
- qtype = getattr(self.up_proj, "qtype", None)
- if mlp_fusion_check(x_2d, qtype, self.training):
- import xe_linear
- if not x_2d.is_contiguous():
- x_2d = x_2d.contiguous()
- out = self.down_proj(xe_linear.mlp_forward_xpu(
- x_2d, self.up_proj.weight.data, self.gate_proj.weight.data,
- x_2d.shape[0], x_2d.shape[1], self.up_proj.out_len,
- SILU, qtype
- ))
- if residual is not None:
- return out + residual
- else:
- return out
- elif fp16_fusion_check(self.up_proj, x, self.training) and \
- hidden_size == 4096 and bsz == 1:
- hidden_states1 = torch.ops.torch_ipex.mm_silu(x, self.up_proj.weight)
- hidden_states = torch.ops.torch_ipex.mm_resmul(
- x, self.gate_proj.weight, hidden_states1
- )
- if residual is None:
- hidden_states = torch.matmul(hidden_states, self.down_proj.weight)
- else:
- attn_output = torch.addmm(
- residual.flatten(0, -2),
- hidden_states.flatten(0, -2),
- self.down_proj.weight,
- beta=1,
- )
- hidden_states = attn_output.view(x.shape)
- return hidden_states
- else:
- out = self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
- if residual is not None:
- return out + residual
- else:
- return out
-
-
  def yuan_attention_forward(
  self,
  hidden_states: torch.Tensor,
@@ -301,8 +301,7 @@ class _BaseAutoModelClass:
  model.share_memory()

  if not pipeline:
- if (not hasattr(model, 'llm') and
- model.config.model_type in ["qwen2", "llama", "minicpm"]):
+ if model.config.model_type in ["qwen2", "llama", "minicpm"]:
  from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
  optimize_llm_single_process(
  llm,
@@ -312,7 +311,8 @@ class _BaseAutoModelClass:
  group_size=quantization_group_size,
  qtype=qtype,
  save_directory=save_directory,
- fuse_layers=fuse_layers
+ fuse_layers=fuse_layers,
+ has_llm=hasattr(model, "llm")
  )
  else:
  optimize_llm(
@@ -449,7 +449,8 @@ def optimize_llm_single_process(
  group_size: int,
  qtype: str,
  save_directory: str,
- fuse_layers: int=None
+ fuse_layers: int=None,
+ has_llm: bool=False
  ):
  from ipex_llm.transformers.npu_pipeline_model.convert_pipeline import convert_llm
  from .npu_llm_cpp import load_model_from_file
@@ -468,8 +469,13 @@ def optimize_llm_single_process(
  model.kv_len = kv_len
  model.model_ptr = model_ptr
  model.save_directory = save_directory
- model.vocab_size = model.config.vocab_size
+ if model.config.vocab_size == 151666:
+ # for MiniCPM-V 2.6, 152064 is vocab_size of Qwen2-7B
+ model.vocab_size = 152064
+ else:
+ model.vocab_size = model.config.vocab_size
  model.logits_buffer = torch.empty(1, 1, model.vocab_size, dtype=torch.float32)
+ model.max_prompt_len = max_prompt_len
  except:
  invalidInputError(False,
  "False to InitLLMPipeline.")
@@ -478,9 +484,10 @@ def optimize_llm_single_process(
  general_convert(model, PreTrainedModel, prepare_input_ids, "prepare_inputs_for_generation")
  general_convert(model, PreTrainedModel, causal_lm_forward)
  # patch generate function
- import types
- model.original_generate = model.generate
- model.generate = types.MethodType(generate, model)
+ if not has_llm:
+ import types
+ model.original_generate = model.generate
+ model.generate = types.MethodType(generate, model)
  return model


@@ -491,9 +498,10 @@ def prepare_input_ids(
  else: # prefill, reset the model here
  from .npu_llm_cpp import reset
  reset(self.model_ptr)
- model_inputs = {
- "input_ids": input_ids
- }
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
  return model_inputs


@@ -511,17 +519,31 @@ def causal_lm_forward(
  return_dict: Optional[bool] = None,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  from .npu_llm_cpp import run_prefill_with_logits, run_decode_with_logits
- if isinstance(input_ids[0], torch.Tensor):
- input_list = input_ids[0].flatten().tolist()
- else:
- input_list = input_ids[0]
- input_length = len(input_list)
- if input_length > 1:
- logits = run_prefill_with_logits(self.model_ptr, input_list,
- self.logits_buffer, self.vocab_size)
+ if input_ids is not None:
+ if isinstance(input_ids[0], torch.Tensor):
+ input_list = input_ids[0].flatten().tolist()
+ else:
+ input_list = input_ids[0]
+ input_length = len(input_list)
+ if input_length > 1:
+ logits = run_prefill_with_logits(self.model_ptr, input_list,
+ self.logits_buffer, self.vocab_size)
+ else:
+ logits = run_decode_with_logits(self.model_ptr, input_list[0],
+ self.logits_buffer, self.vocab_size)
+ elif inputs_embeds is not None:
+ seq_len = inputs_embeds.shape[1]
+ pad_len = self.max_prompt_len - seq_len
+ inputs_embeds = torch.nn.functional.pad(inputs_embeds.to(torch.float16),
+ (0, 0, 0, pad_len), value=0.0)
+ logits = run_prefill_with_logits(self.model_ptr, None, self.logits_buffer,
+ self.vocab_size, inputs_embeds, seq_len)
  else:
- logits = run_decode_with_logits(self.model_ptr, input_list[0],
- self.logits_buffer, self.vocab_size)
+ invalidInputError(False, "Please specify either input_ids or inputs_embeds.")
+
+ if self.config.vocab_size == 151666:
+ # for MiniCPM-V 2.6
+ logits = logits[:, :, :151666]

  return CausalLMOutputWithPast(
  loss=None,
@@ -48,8 +48,8 @@ _lib = ctypes.cdll.LoadLibrary(_lib_path)
  _lib.load_model_from_file.argtypes = [ctypes.c_char_p]
  _lib.load_model_from_file.restype = ctypes.c_void_p

- _lib.run_prefill.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int,
- ctypes.c_float]
+ _lib.run_prefill.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int,
+ ctypes.c_float, ctypes.c_bool]
  _lib.run_prefill.restype = ctypes.POINTER(ctypes.c_float)

  _lib.run_decode.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_float]
@@ -61,8 +61,10 @@ _lib.llm_sample_token.restype = ctypes.c_int
  _lib.reset.argtypes = [ctypes.c_void_p]
  _lib.reset.restype = None

- _lib.run_prefill_with_logits.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int),
- ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_int]
+ _lib.run_prefill_with_logits.argtypes = [ctypes.c_void_p, ctypes.c_void_p,
+ ctypes.c_int, ctypes.POINTER(ctypes.c_float),
+ ctypes.c_int, ctypes.c_bool]
+
  _lib.run_prefill_with_logits.restype = None

  _lib.run_decode_with_logits.argtypes = [ctypes.c_void_p, ctypes.c_int,
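With these signature changes, run_prefill and run_prefill_with_logits now take a generic ctypes.c_void_p for the input buffer (token ids or raw fp16 embeddings) plus a trailing ctypes.c_bool telling the native side which one it received. A standalone sketch of the pointer handoff used by the Python wrappers below (toy shapes; not tied to the real library):

    import ctypes
    import torch

    # a contiguous fp16 tensor whose storage address is handed to C as a void*
    embeds = torch.zeros(1, 8, 3584, dtype=torch.float16).contiguous()
    input_ptr = ctypes.cast(embeds.data_ptr(), ctypes.c_void_p)
    # the tensor must stay alive and contiguous for as long as the C call uses input_ptr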
@@ -77,7 +79,7 @@ def load_model_from_file(model_dir: str):
  def run_prefill(model_ptr, input_ids, vocab_size, repetition_penalty=1.0):
  input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
  input_len = len(input_ids)
- plogits = _lib.run_prefill(model_ptr, input_ptr, input_len, repetition_penalty)
+ plogits = _lib.run_prefill(model_ptr, input_ptr, input_len, repetition_penalty, False)
  new_token = _lib.llm_sample_token(plogits, True, vocab_size)
  return new_token

@@ -88,12 +90,19 @@ def run_decode(model_ptr, input_id, vocab_size, repetition_penalty=1.0):
  return new_token


- def run_prefill_with_logits(model_ptr, input_ids, logits, vocab_size):
- input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
- input_len = len(input_ids)
+ def run_prefill_with_logits(model_ptr, input_ids, logits, vocab_size,
+ inputs_embeds=None, seq_len=None):
+ if input_ids is not None:
+ input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
+ input_len = len(input_ids)
+ else:
+ input_ptr = inputs_embeds.contiguous().data.data_ptr()
+ input_ptr = ctypes.cast(input_ptr, ctypes.c_void_p)
+ input_len = seq_len
  logits_ptr = logits.data.data_ptr()
  logits_ptr = ctypes.cast(logits_ptr, ctypes.POINTER(ctypes.c_float))
- _lib.run_prefill_with_logits(model_ptr, input_ptr, input_len, logits_ptr, vocab_size)
+ _lib.run_prefill_with_logits(model_ptr, input_ptr, input_len, logits_ptr,
+ vocab_size, (input_ids is None))
  return logits


@@ -34,6 +34,10 @@ def convert_lm_head_and_embedding(model, temp_dir, weight_dir,
  lm_head_n_splits = 1
  asym = getattr(model.config, "asym", False)

+ if vocab_size == 151666:
+ # for MiniCPM-V 2.6 lm_head on NPU
+ vocab_size = 152064
+
  if not isinstance(lm_head, SlicedLMHead):
  asym = lm_head.qtype == "asym_int4_rtn"
  if asym: