xinference 1.10.1__py3-none-any.whl → 1.11.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (39)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +462 -3
  3. xinference/client/restful/async_restful_client.py +158 -5
  4. xinference/client/restful/restful_client.py +131 -0
  5. xinference/core/supervisor.py +12 -0
  6. xinference/model/audio/model_spec.json +20 -20
  7. xinference/model/image/model_spec.json +159 -159
  8. xinference/model/llm/__init__.py +2 -2
  9. xinference/model/llm/llm_family.json +843 -180
  10. xinference/model/llm/mlx/distributed_models/core.py +41 -0
  11. xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
  12. xinference/model/llm/sglang/core.py +20 -6
  13. xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
  14. xinference/model/llm/transformers/chatglm.py +3 -0
  15. xinference/model/llm/transformers/core.py +93 -16
  16. xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
  17. xinference/model/llm/transformers/utils.py +3 -0
  18. xinference/model/llm/utils.py +37 -24
  19. xinference/model/llm/vllm/core.py +128 -69
  20. xinference/model/utils.py +74 -31
  21. xinference/thirdparty/audiotools/core/audio_signal.py +6 -6
  22. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
  23. xinference/thirdparty/melo/text/chinese_mix.py +2 -2
  24. xinference/types.py +9 -0
  25. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  26. xinference/ui/web/ui/build/index.html +1 -1
  27. xinference/ui/web/ui/build/static/js/{main.d192c4f3.js → main.e4d9a9e1.js} +3 -3
  28. xinference/ui/web/ui/build/static/js/{main.d192c4f3.js.map → main.e4d9a9e1.js.map} +1 -1
  29. xinference/ui/web/ui/node_modules/.cache/babel-loader/e6770a05771952175c9fbf48fce283c9bb1bc8b5763e39edc36d099d1fe16b4a.json +1 -0
  30. xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
  31. {xinference-1.10.1.dist-info → xinference-1.11.0.post1.dist-info}/METADATA +8 -5
  32. {xinference-1.10.1.dist-info → xinference-1.11.0.post1.dist-info}/RECORD +37 -36
  33. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +0 -1
  34. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +0 -1
  35. /xinference/ui/web/ui/build/static/js/{main.d192c4f3.js.LICENSE.txt → main.e4d9a9e1.js.LICENSE.txt} +0 -0
  36. {xinference-1.10.1.dist-info → xinference-1.11.0.post1.dist-info}/WHEEL +0 -0
  37. {xinference-1.10.1.dist-info → xinference-1.11.0.post1.dist-info}/entry_points.txt +0 -0
  38. {xinference-1.10.1.dist-info → xinference-1.11.0.post1.dist-info}/licenses/LICENSE +0 -0
  39. {xinference-1.10.1.dist-info → xinference-1.11.0.post1.dist-info}/top_level.txt +0 -0
xinference/model/llm/mlx/distributed_models/core.py
@@ -162,3 +162,44 @@ class DistributedModelMixin:
         self.layers = self.layers[: self.end_idx]
         self.layers[: self.start_idx] = [None] * self.start_idx
         self.num_layers = len(self.layers) - self.start_idx
+
+
+class SafeKVCache:
+    """
+    A safe wrapper around mlx_lm's KVCache that handles None keys gracefully.
+    This is needed because mlx_lm's generate function accesses cache.state
+    before the cache is properly initialized.
+    """
+
+    def __init__(self):
+        from mlx_lm.models.cache import KVCache
+
+        self._cache = KVCache()
+
+    @property
+    def state(self):
+        # Safe access to state property
+        if self._cache.keys is None:
+            return None, None
+        if self._cache.offset == self._cache.keys.shape[2]:
+            return self._cache.keys, self._cache.values
+        else:
+            return (
+                self._cache.keys[..., : self._cache.offset, :],
+                self._cache.values[..., : self._cache.offset, :],
+            )
+
+    @state.setter
+    def state(self, v):
+        # Safe setter for state property
+        if v is None or v[0] is None:
+            self._cache.keys = None
+            self._cache.values = None
+            self._cache.offset = 0
+        else:
+            self._cache.keys, self._cache.values = v
+            self._cache.offset = self._cache.keys.shape[2]
+
+    def __getattr__(self, name):
+        # Delegate all other attributes and methods to the underlying cache
+        return getattr(self._cache, name)
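
For illustration only (not part of the diff), a minimal sketch of how SafeKVCache is expected to behave as a drop-in for mlx_lm's KVCache, based on the code above:

    cache = SafeKVCache()
    keys, values = cache.state   # (None, None) on a fresh cache, instead of failing on keys.shape
    cache.state = (None, None)   # resets the wrapped cache: keys/values cleared, offset back to 0
    print(cache.offset)          # any other attribute is delegated to the wrapped KVCache via __getattr__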
xinference/model/llm/mlx/distributed_models/qwen2.py
@@ -46,11 +46,10 @@ class Qwen2Model(_Qwen2Model, DistributedModelMixin):
 
         pipeline_rank = self.rank
         pipeline_size = self.world_size
-        if mask is None:
-            mask = create_attention_mask(h, cache)
 
         if cache is None:
            cache = [None] * self.num_layers
+        mask = create_attention_mask(h, cache[0])
 
         # Receive from the previous process in the pipeline
 
xinference/model/llm/sglang/core.py
@@ -362,9 +362,16 @@ class SGLANGModel(LLM):
     def _convert_state_to_completion_chunk(
         request_id: str, model: str, output_text: str, meta_info: Dict
     ) -> CompletionChunk:
-        finish_reason = meta_info.get("finish_reason", None)
-        if isinstance(finish_reason, dict) and "type" in finish_reason:
-            finish_reason = finish_reason["type"]
+        finish_reason_raw = meta_info.get("finish_reason", None)
+        finish_reason: Optional[str] = None
+        if isinstance(finish_reason_raw, dict) and "type" in finish_reason_raw:
+            finish_reason = (
+                str(finish_reason_raw["type"])
+                if finish_reason_raw["type"] is not None
+                else None
+            )
+        elif isinstance(finish_reason_raw, str):
+            finish_reason = finish_reason_raw
         choices: List[CompletionChoice] = [
             CompletionChoice(
                 text=output_text,
@@ -392,9 +399,16 @@ class SGLANGModel(LLM):
     def _convert_state_to_completion(
         request_id: str, model: str, output_text: str, meta_info: Dict
     ) -> Completion:
-        finish_reason = meta_info.get("finish_reason", None)
-        if isinstance(finish_reason, dict) and "type" in finish_reason:
-            finish_reason = finish_reason["type"]
+        finish_reason_raw = meta_info.get("finish_reason", None)
+        finish_reason: Optional[str] = None
+        if isinstance(finish_reason_raw, dict) and "type" in finish_reason_raw:
+            finish_reason = (
+                str(finish_reason_raw["type"])
+                if finish_reason_raw["type"] is not None
+                else None
+            )
+        elif isinstance(finish_reason_raw, str):
+            finish_reason = finish_reason_raw
         choices = [
             CompletionChoice(
                 text=output_text,
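
For illustration only (not part of the diff), the normalization above maps SGLang's meta_info into an Optional[str] finish reason:

    meta_info = {"finish_reason": {"type": "stop"}}   # -> finish_reason == "stop"
    meta_info = {"finish_reason": "length"}           # -> finish_reason == "length"
    meta_info = {}                                    # -> finish_reason is None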
xinference/model/llm/tool_parsers/qwen_tool_parser.py
@@ -59,10 +59,28 @@ class QwenToolParser(ToolParser):
         Returns:
             str: Extracted JSON string or original string if no match found.
         """
+        # First try to find complete tool calls
         function_calls = self.tool_call_complete_regex.findall(function_call_str)
-        if len(function_calls) == 0:
-            return function_call_str
-        return function_calls[-1]
+        if len(function_calls) > 0:
+            return function_calls[-1]
+
+        # If no complete tool calls found, try to extract from incomplete tool calls
+        # Handle cases like <tool_call><tool_call>_city
+        if self.tool_call_start_token in function_call_str:
+            # Extract content between the last tool_call start token and end of string
+            last_start = function_call_str.rfind(self.tool_call_start_token)
+            potential_json = function_call_str[
+                last_start + len(self.tool_call_start_token) :
+            ]
+            # Remove any trailing tool_call end tokens
+            if self.tool_call_end_token in potential_json:
+                potential_json = potential_json.split(self.tool_call_end_token)[0]
+            # Clean up any extra whitespace
+            potential_json = potential_json.strip()
+            if potential_json:
+                return potential_json
+
+        return function_call_str
 
     def _parse_json_function_call_stream(
         self,
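
For illustration only (not part of the diff), and assuming Qwen's default <tool_call> / </tool_call> marker tokens, the fallback added above lets a truncated call still yield its JSON payload:

    # complete call: the last <tool_call>...</tool_call> block is extracted
    '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
    #   -> '{"name": "get_weather", "arguments": {"city": "Paris"}}'
    # truncated call: everything after the last start token is returned, stripped
    '<tool_call>{"name": "get_weather"'
    #   -> '{"name": "get_weather"'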
@@ -229,7 +247,14 @@ class QwenToolParser(ToolParser):
             try:
                 parsed_json = self._parse_json_function_call(function_call)
                 res = json.loads(parsed_json, strict=False)
-                results.append((None, res["name"], res["arguments"]))
+                # Validate that we have the required fields
+                if "name" in res and "arguments" in res:
+                    results.append((None, res["name"], res["arguments"]))
+                else:
+                    logger.warning(
+                        "Invalid tool call format, missing required fields: %s", res
+                    )
+                    results.append((function_call, None, None))
             except Exception as e:
                 logger.error(
                     "Can't parse single qwen tool call output: %s. Error: %s",
xinference/model/llm/transformers/chatglm.py
@@ -472,6 +472,9 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             r.prompt = self._process_messages(
                 r.prompt, tools=tools, tool_choice=tool_choice
             )
+            assert isinstance(
+                r.prompt, list
+            ), "r.prompt must be a list after processing"
             r.full_prompt = self.get_full_context(
                 r.prompt,
                 self.model_family.chat_template,  # type: ignore
xinference/model/llm/transformers/core.py
@@ -48,6 +48,7 @@ from ..utils import (
 )
 from .utils import (
     _get_pad_param,
+    convert_to_cache_cls,
     get_context_length,
     get_max_src_len,
     pad_prefill_tokens,
@@ -573,6 +574,7 @@ class PytorchModel(LLM):
                 ]
             )
             data.append(x)
+
         return torch.stack(data).to(self._device)
 
     def build_prefill_position_ids(
@@ -713,30 +715,105 @@ class PytorchModel(LLM):
         from torch.nn.functional import pad
         from transformers import DynamicCache
 
+        # Handle case where past_cache is None
+        if past_cache is None:
+            return new_cache
+
+        # Convert both caches to DynamicCache if not already
+        if not isinstance(past_cache, DynamicCache):
+            past_cache = convert_to_cache_cls(past_cache)
+        if not isinstance(new_cache, DynamicCache):
+            new_cache = convert_to_cache_cls(new_cache)
+
         _, seq_len_idx = self.get_batch_size_and_seq_len_indexes_from_kv()
-        past_seq_len = past_cache[0][0].shape[seq_len_idx]
-        new_seq_len = new_cache[0][0].shape[seq_len_idx]
+
+        # Handle empty caches
+        if len(past_cache) == 0:
+            return new_cache
+        if len(new_cache) == 0:
+            return past_cache
+
+        # Get first layer seq_len safely
+        past_first = past_cache[0] if len(past_cache) > 0 else (None, None)
+        new_first = new_cache[0] if len(new_cache) > 0 else (None, None)
+
+        if past_first[0] is None or past_first[1] is None:
+            return new_cache
+        if new_first[0] is None or new_first[1] is None:
+            return past_cache
+
+        past_seq_len = past_first[0].shape[seq_len_idx]
+        new_seq_len = new_first[0].shape[seq_len_idx]
+
+        # Pad the shorter cache
         if past_seq_len != new_seq_len:
-            padding_target = new_cache if past_seq_len > new_seq_len else past_cache
-            padding_len = abs(past_seq_len - new_seq_len)
+            if past_seq_len > new_seq_len:
+                padding_target = new_cache
+                padding_len = past_seq_len - new_seq_len
+            else:
+                padding_target = past_cache
+                padding_len = new_seq_len - past_seq_len
+
             pad_param = _get_pad_param(seq_len_idx, padding_len)
             for idx in range(len(padding_target)):
                 k = padding_target.key_cache[idx]
                 v = padding_target.value_cache[idx]
-                _k = pad(k, pad_param)
-                _v = pad(v, pad_param)
-                padding_target.key_cache[idx] = _k
-                padding_target.value_cache[idx] = _v
+                if k is not None and v is not None:
+                    padding_target.key_cache[idx] = pad(k, pad_param)
+                    padding_target.value_cache[idx] = pad(v, pad_param)
 
+        # Merge caches
         ret_kv = DynamicCache()
-        for idx in range(len(past_cache)):
-            k1, k2 = new_cache.key_cache[idx], past_cache.key_cache[idx]
-            v1, v2 = new_cache.value_cache[idx], past_cache.value_cache[idx]
-            ret_kv.update(
-                torch.cat((k1, k2), 0).contiguous(),
-                torch.cat((v1, v2), 0).contiguous(),
-                idx,
-            )
+        max_layers = max(len(past_cache), len(new_cache))
+
+        for idx in range(max_layers):
+            past_k = past_cache.key_cache[idx] if idx < len(past_cache) else None
+            past_v = past_cache.value_cache[idx] if idx < len(past_cache) else None
+            new_k = new_cache.key_cache[idx] if idx < len(new_cache) else None
+            new_v = new_cache.value_cache[idx] if idx < len(new_cache) else None
+
+            if past_k is not None and new_k is not None:
+                # Both layers exist - validate tensor dimensions before concatenation
+                if past_k.dim() != new_k.dim():
+                    logger.error(
+                        f"KV cache tensor dimension mismatch at layer {idx}: "
+                        f"past_k.dim()={past_k.dim()}, new_k.dim()={new_k.dim()}"
+                    )
+                    # Use the cache with higher batch size
+                    if past_k.shape[0] >= new_k.shape[0]:
+                        ret_kv.update(past_k, past_v, idx)
+                    else:
+                        ret_kv.update(new_k, new_v, idx)
+                    continue
+
+                if past_k.shape[1:] == new_k.shape[1:]:
+                    # Shapes are compatible, concatenate along batch dimension
+                    ret_kv.update(
+                        torch.cat((new_k, past_k), 0).contiguous(),
+                        torch.cat((new_v, past_v), 0).contiguous(),
+                        idx,
+                    )
+                else:
+                    # Detailed logging for shape mismatch
+                    logger.warning(
+                        f"KV cache shape mismatch at layer {idx}: "
+                        f"past_k.shape={past_k.shape}, new_k.shape={new_k.shape}. "
+                        f"This may be due to inconsistent batch sizes in continuous batching."
+                    )
+
+                    # Choose the cache with larger batch size to preserve more data
+                    if past_k.shape[0] >= new_k.shape[0]:
+                        ret_kv.update(past_k, past_v, idx)
+                    else:
+                        ret_kv.update(new_k, new_v, idx)
+            elif past_k is not None:
+                ret_kv.update(past_k, past_v, idx)
+            elif new_k is not None:
+                ret_kv.update(new_k, new_v, idx)
+            else:
+                # both None, fill with None
+                ret_kv.update(None, None, idx)
+
         return ret_kv
 
     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
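
For illustration only (not part of the diff), a minimal sketch of the pad-then-concatenate step above, assuming KV tensors shaped [batch, heads, seq_len, head_dim] with seq_len_idx == 2:

    import torch
    from torch.nn.functional import pad

    past_k = torch.zeros(2, 8, 12, 64)        # decode cache, seq_len 12
    new_k = torch.zeros(1, 8, 7, 64)          # prefill cache, seq_len 7
    new_k = pad(new_k, (0, 0, 0, 12 - 7))     # pad the shorter cache along seq_len -> (1, 8, 12, 64)
    merged_k = torch.cat((new_k, past_k), 0)  # concatenate along batch -> (3, 8, 12, 64)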
xinference/model/llm/transformers/multimodal/minicpmv45.py (new file)
@@ -0,0 +1,340 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, Iterator, List, Optional, Tuple
+
+import torch
+from PIL import Image
+
+from .....core.model import register_batching_multimodal_models
+from .....model.utils import select_device
+from .....types import PytorchModelConfig
+from ....scheduler.request import InferenceRequest
+from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
+from ...utils import _decode_image, parse_messages
+from ..core import register_non_default_model
+from .core import PytorchMultiModalModel
+
+logger = logging.getLogger(__name__)
+
+
+@register_batching_multimodal_models("MiniCPM-V-4.5")
+@register_transformer
+@register_non_default_model("MiniCPM-V-4.5")
+class MiniCPMV45Model(PytorchMultiModalModel):
+    @classmethod
+    def match_json(
+        cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        family = model_family.model_family or model_family.model_name
+        if "MiniCPM-V-4.5".lower() in family.lower():
+            return True
+        return False
+
+    def _sanitize_model_config(
+        self, pytorch_model_config: Optional[PytorchModelConfig]
+    ) -> PytorchModelConfig:
+        pytorch_model_config = super()._sanitize_model_config(pytorch_model_config)
+        assert pytorch_model_config is not None
+        # Configure pixel parameters for MiniCPM-V-4.5
+        pytorch_model_config.setdefault("min_pixels", 256 * 28 * 28)
+        pytorch_model_config.setdefault("max_pixels", 1280 * 28 * 28)
+        return pytorch_model_config
+
+    def decide_device(self):
+        device = self._pytorch_model_config.get("device", "auto")
+        self._device = select_device(device)
+        self._device = (
+            "auto"
+            if self._device == "cuda" and self.quantization is None
+            else self._device
+        )
+
+    def load_processor(self):
+        from transformers import AutoProcessor, AutoTokenizer
+
+        min_pixels = self._pytorch_model_config.get("min_pixels")
+        max_pixels = self._pytorch_model_config.get("max_pixels")
+        self._processor = AutoProcessor.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+
+    def load_multimodal_model(self):
+        from transformers import AutoModel
+        from transformers.generation import GenerationConfig
+
+        if "int4" in self.model_path:
+            model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
+        else:
+            kwargs = self.apply_bnb_quantization()
+            model = AutoModel.from_pretrained(
+                self.model_path,
+                trust_remote_code=True,
+                torch_dtype=torch.float16,
+                device_map=self._device,
+                **kwargs,
+            )
+        self._model = model.eval()
+        # Specify hyperparameters for generation
+        self._model.generation_config = GenerationConfig.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+        )
+        self._device = self._model.device
+
+    def _message_content_to_chat(self, content):
+        MAX_NUM_FRAMES = 64
+
+        def encode_video(video_path):
+            from decord import VideoReader, cpu
+
+            def uniform_sample(l, n):
+                gap = len(l) / n
+                idxs = [int(i * gap + gap / 2) for i in range(n)]
+                return [l[i] for i in idxs]
+
+            vr = VideoReader(video_path, ctx=cpu(0))
+            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+            frame_idx = [i for i in range(0, len(vr), sample_fps)]
+            if len(frame_idx) > MAX_NUM_FRAMES:
+                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+            frames = vr.get_batch(frame_idx).asnumpy()
+            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+            logger.info(
+                f"Num frames: {len(frames)} when decoding video for {self.model_uid}"
+            )
+            return frames
+
+        def _load_video(_url):
+            frames = None
+            if _url.startswith("data:"):
+                raise RuntimeError("Only video url format is supported")
+            else:
+                frames = encode_video(_url)
+            return frames
+
+        if not isinstance(content, str):
+            texts = []
+            image_urls = []
+            video_urls = []
+            for c in content:
+                c_type = c.get("type")
+                if c_type == "text":
+                    texts.append(c["text"])
+                elif c_type == "image_url":
+                    image_urls.append(c["image_url"]["url"])
+                elif c_type == "video_url":
+                    video_urls.append(c["video_url"]["url"])
+            image_futures = []
+            with ThreadPoolExecutor() as executor:
+                for image_url in image_urls:
+                    fut = executor.submit(_decode_image, image_url)
+                    image_futures.append(fut)
+            images = [fut.result() for fut in image_futures]
+            frames = []
+            if len(video_urls) > 1:
+                raise RuntimeError("Only one video per message is supported")
+            for v in video_urls:
+                frames = _load_video(v)
+            text = " ".join(texts)
+            return text, images, frames
+        return content, [], []
+
+    def _convert_to_specific_style(self, messages: List[Dict]) -> Tuple:
+        video_existed = False
+        prompt, _, chat_history = parse_messages(messages)
+
+        content, images_chat, video_frames = self._message_content_to_chat(prompt)
+        if len(video_frames) > 0:
+            video_existed = True
+            images_chat = video_frames
+
+        msgs = []
+        query_to_response: List[Dict] = []
+        for h in chat_history or []:
+            images_history = []
+            role = h["role"]
+            content_h, images_tmp, video_frames_h = self._message_content_to_chat(
+                h["content"]
+            )
+            if images_tmp != []:
+                images_history = images_tmp
+            if len(video_frames_h) > 0:
+                video_existed = True
+                images_history = video_frames_h
+            if len(query_to_response) == 0 and role == "user":
+                query_to_response.append(
+                    {"role": "user", "content": images_history + [content_h]}
+                )
+            if len(query_to_response) == 1 and role == "assistant":
+                query_to_response.append(
+                    {"role": "assistant", "content": images_history + [content_h]}
+                )
+            if len(query_to_response) == 2:
+                msgs.extend(query_to_response)
+                query_to_response = []
+        msgs.append({"role": "user", "content": images_chat + [content]})
+        return msgs, video_existed
+
+    def build_inputs_from_messages(
+        self,
+        messages: List[Dict],
+        generate_config: Dict,
+    ):
+        msgs, video_existed = self._convert_to_specific_style(messages)
+        # Set decode params for video
+        params = {}
+        if video_existed:
+            params = {"use_image_id": False, "max_slice_nums": 1}
+        return dict(msgs=msgs, image=None, **params)
+
+    def build_generate_kwargs(
+        self,
+        generate_config: Dict,
+    ) -> Dict[str, Any]:
+        return dict(**generate_config)
+
+    def build_streaming_iter(
+        self,
+        messages: List[Dict],
+        generate_config: Dict,
+    ) -> Tuple[Iterator, int]:
+        inputs = self.build_inputs_from_messages(messages, generate_config)
+        config = self.build_generate_kwargs(generate_config)
+        chat_iter = self._model.chat(
+            **inputs, **config, tokenizer=self._tokenizer, sampling=True
+        )
+
+        return chat_iter, -1
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        Refer to MiniCPM-V-4.5 documentation for generation parameters
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.7
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.8
+        top_k = raw_config.get("top_k", None)
+        if top_k is None:
+            raw_config["top_k"] = 100
+        repetition_penalty = raw_config.get("repetition_penalty", None)
+        if repetition_penalty is None:
+            raw_config["repetition_penalty"] = 1.05
+        return raw_config
+
+    def _handle_input_ids_and_images(self, msgs: List[Dict]) -> Dict:
+        """
+        Handle input IDs and images for MiniCPM-V-4.5
+        Based on MiniCPM-V-2.6 implementation with adaptations for 4.5
+        """
+        from copy import deepcopy
+
+        copy_msgs = deepcopy(msgs)
+
+        images = []
+        for i, msg in enumerate(copy_msgs):
+            role = msg["role"]
+            content = msg["content"]
+            assert role in ["user", "assistant"]
+            if i == 0:
+                assert role == "user", "The role of first msg should be user"
+            if isinstance(content, str):
+                content = [content]
+            cur_msgs = []
+            for c in content:
+                if isinstance(c, Image.Image):
+                    images.append(c)
+                    cur_msgs.append("(<image>./</image>)")
+                elif isinstance(c, str):
+                    cur_msgs.append(c)
+            msg["content"] = "\n".join(cur_msgs)
+
+        return {
+            "prompt": self._processor.tokenizer.apply_chat_template(
+                copy_msgs, tokenize=False, add_generation_prompt=True
+            ),
+            "input_image": images,
+        }
+
+    def _get_full_prompt(self, messages: List[Dict], tools, generate_config: dict):  # type: ignore
+        msgs, video_existed = self._convert_to_specific_style(messages)
+        if video_existed:
+            raise RuntimeError(
+                f"Continuous batching does not support video inputs for this model: {self.model_uid}"
+            )
+        return self._handle_input_ids_and_images(msgs)
+
+    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+        prompts_lists = [x["prompt"] for x in prompts]
+        input_images_lists = [x["input_image"] for x in prompts]
+        inputs = self._processor(
+            prompts_lists,
+            input_images_lists,
+            max_slice_nums=None,
+            use_image_id=None,
+            return_tensors="pt",
+            max_length=8192,
+        ).to(self._model.device)
+        inputs.pop("image_sizes")
+
+        masked_input_ids = inputs["input_ids"] * inputs["attention_mask"]
+        for i in range(masked_input_ids.shape[0]):
+            non_zero_values = masked_input_ids[i][masked_input_ids[i] != 0].tolist()
+            req_list[i].prompt_tokens = non_zero_values
+            req_list[i].extra_kwargs["attention_mask_seq_len"] = len(non_zero_values)
+            req_list[i].padding_len = masked_input_ids.shape[1] - len(non_zero_values)
+
+        model_inputs = {
+            "input_ids": inputs["input_ids"],
+            "image_bound": inputs["image_bound"],
+            "pixel_values": inputs["pixel_values"],
+            "tgt_sizes": inputs["tgt_sizes"],
+        }
+        model_inputs["inputs_embeds"], _ = self._model.get_vllm_embedding(model_inputs)
+
+        return {
+            "inputs_embeds": model_inputs["inputs_embeds"],
+            "attention_mask": inputs["attention_mask"],
+        }
+
+    def build_decode_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        return None
+
+    def batch_inference(self, req_list: List[InferenceRequest]):
+        """
+        This method is rewritten
+        because the specific inference process is performed by `self._model.llm`,
+        not `self._model` itself
+        """
+        from ..utils import batch_inference_one_step
+
+        self.prepare_batch_inference(req_list)
+        batch_inference_one_step(
+            self, req_list, self.model_uid, self._model.llm, self._tokenizer
+        )
+        self.handle_batch_inference_results(req_list)
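
For context (not part of the diff), the new class registers MiniCPM-V-4.5 with the Transformers engine, so it can be launched like other built-in multimodal LLMs. A rough sketch via the RESTful client, where the endpoint and the exact engine spelling are assumptions rather than taken from this diff:

    from xinference.client import Client

    client = Client("http://localhost:9997")   # placeholder endpoint
    model_uid = client.launch_model(
        model_name="MiniCPM-V-4.5",
        model_type="LLM",
        model_engine="transformers",           # assumed engine identifier
    )
    model = client.get_model(model_uid)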
xinference/model/llm/transformers/utils.py
@@ -281,7 +281,10 @@ def _batch_inference_one_step_internal(
            r.append_new_token(token)
 
     if decode_reqs:
+        # Ensure all decode requests have the same kv_cache reference
+        # This prevents batch size mismatches during merging
         decode_kv = decode_reqs[0].kv_cache
+
         # prefill and decode kv cache need to be merged at `batch_size` and `seq_len` dimensions.
         merged_kv_cache = xinf_model_obj.merge_kv_cache(decode_kv, past_key_values)
         for r in valid_req_list: