xinference 1.10.1__py3-none-any.whl → 1.11.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +462 -3
- xinference/client/restful/async_restful_client.py +158 -5
- xinference/client/restful/restful_client.py +131 -0
- xinference/core/supervisor.py +12 -0
- xinference/model/audio/model_spec.json +20 -20
- xinference/model/image/model_spec.json +159 -159
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +843 -180
- xinference/model/llm/mlx/distributed_models/core.py +41 -0
- xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
- xinference/model/llm/sglang/core.py +20 -6
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
- xinference/model/llm/transformers/chatglm.py +3 -0
- xinference/model/llm/transformers/core.py +93 -16
- xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
- xinference/model/llm/transformers/utils.py +3 -0
- xinference/model/llm/utils.py +37 -24
- xinference/model/llm/vllm/core.py +128 -69
- xinference/model/utils.py +74 -31
- xinference/thirdparty/audiotools/core/audio_signal.py +6 -6
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
- xinference/thirdparty/melo/text/chinese_mix.py +2 -2
- xinference/types.py +9 -0
- xinference/ui/web/ui/build/asset-manifest.json +3 -3
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/js/{main.d192c4f3.js → main.e4d9a9e1.js} +3 -3
- xinference/ui/web/ui/build/static/js/{main.d192c4f3.js.map → main.e4d9a9e1.js.map} +1 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e6770a05771952175c9fbf48fce283c9bb1bc8b5763e39edc36d099d1fe16b4a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
- {xinference-1.10.1.dist-info → xinference-1.11.0.post1.dist-info}/METADATA +8 -5
- {xinference-1.10.1.dist-info → xinference-1.11.0.post1.dist-info}/RECORD +37 -36
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +0 -1
- /xinference/ui/web/ui/build/static/js/{main.d192c4f3.js.LICENSE.txt → main.e4d9a9e1.js.LICENSE.txt} +0 -0
- {xinference-1.10.1.dist-info → xinference-1.11.0.post1.dist-info}/WHEEL +0 -0
- {xinference-1.10.1.dist-info → xinference-1.11.0.post1.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.1.dist-info → xinference-1.11.0.post1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.1.dist-info → xinference-1.11.0.post1.dist-info}/top_level.txt +0 -0
xinference/model/llm/mlx/distributed_models/core.py

@@ -162,3 +162,44 @@ class DistributedModelMixin:
         self.layers = self.layers[: self.end_idx]
         self.layers[: self.start_idx] = [None] * self.start_idx
         self.num_layers = len(self.layers) - self.start_idx
+
+
+class SafeKVCache:
+    """
+    A safe wrapper around mlx_lm's KVCache that handles None keys gracefully.
+    This is needed because mlx_lm's generate function accesses cache.state
+    before the cache is properly initialized.
+    """
+
+    def __init__(self):
+        from mlx_lm.models.cache import KVCache
+
+        self._cache = KVCache()
+
+    @property
+    def state(self):
+        # Safe access to state property
+        if self._cache.keys is None:
+            return None, None
+        if self._cache.offset == self._cache.keys.shape[2]:
+            return self._cache.keys, self._cache.values
+        else:
+            return (
+                self._cache.keys[..., : self._cache.offset, :],
+                self._cache.values[..., : self._cache.offset, :],
+            )
+
+    @state.setter
+    def state(self, v):
+        # Safe setter for state property
+        if v is None or v[0] is None:
+            self._cache.keys = None
+            self._cache.values = None
+            self._cache.offset = 0
+        else:
+            self._cache.keys, self._cache.values = v
+            self._cache.offset = self._cache.keys.shape[2]
+
+    def __getattr__(self, name):
+        # Delegate all other attributes and methods to the underlying cache
+        return getattr(self._cache, name)
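A minimal usage sketch of the wrapper above, assuming xinference 1.11.0 and mlx_lm are installed; the assertions only illustrate the None-safe behaviour described in the docstring and are not part of the release.

# Illustration only: exercises SafeKVCache's None-safe state handling.
from xinference.model.llm.mlx.distributed_models.core import SafeKVCache

cache = SafeKVCache()

# Before any tokens have been written, mlx_lm's generate() may read cache.state;
# the wrapper returns (None, None) instead of failing on uninitialized keys.
assert cache.state == (None, None)

# Attributes the wrapper does not define (offset, keys, ...) are delegated to the
# wrapped mlx_lm KVCache via __getattr__.
assert cache.offset == 0

# Assigning None (or a pair whose first element is None) resets the cache.
cache.state = None
assert cache.state == (None, None)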
xinference/model/llm/mlx/distributed_models/qwen2.py

@@ -46,11 +46,10 @@ class Qwen2Model(_Qwen2Model, DistributedModelMixin):

         pipeline_rank = self.rank
         pipeline_size = self.world_size
-        if mask is None:
-            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * self.num_layers
+        mask = create_attention_mask(h, cache[0])

         # Receive from the previous process in the pipeline

xinference/model/llm/sglang/core.py

@@ -362,9 +362,16 @@ class SGLANGModel(LLM):
     def _convert_state_to_completion_chunk(
         request_id: str, model: str, output_text: str, meta_info: Dict
     ) -> CompletionChunk:
-
-
-
+        finish_reason_raw = meta_info.get("finish_reason", None)
+        finish_reason: Optional[str] = None
+        if isinstance(finish_reason_raw, dict) and "type" in finish_reason_raw:
+            finish_reason = (
+                str(finish_reason_raw["type"])
+                if finish_reason_raw["type"] is not None
+                else None
+            )
+        elif isinstance(finish_reason_raw, str):
+            finish_reason = finish_reason_raw
         choices: List[CompletionChoice] = [
             CompletionChoice(
                 text=output_text,

@@ -392,9 +399,16 @@ class SGLANGModel(LLM):
     def _convert_state_to_completion(
         request_id: str, model: str, output_text: str, meta_info: Dict
     ) -> Completion:
-
-
-
+        finish_reason_raw = meta_info.get("finish_reason", None)
+        finish_reason: Optional[str] = None
+        if isinstance(finish_reason_raw, dict) and "type" in finish_reason_raw:
+            finish_reason = (
+                str(finish_reason_raw["type"])
+                if finish_reason_raw["type"] is not None
+                else None
+            )
+        elif isinstance(finish_reason_raw, str):
+            finish_reason = finish_reason_raw
         choices = [
             CompletionChoice(
                 text=output_text,
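For context, a small illustration (values invented, not taken from the diff) of the two meta_info shapes the added normalization accepts:

# Hypothetical SGLang outputs: newer versions report finish_reason as a dict,
# older ones as a plain string. Both now reduce to an Optional[str].
meta_dict_style = {"finish_reason": {"type": "stop"}}  # -> "stop"
meta_str_style = {"finish_reason": "length"}           # -> "length"
meta_missing = {}                                      # -> None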
xinference/model/llm/tool_parsers/qwen_tool_parser.py

@@ -59,10 +59,28 @@ class QwenToolParser(ToolParser):
         Returns:
             str: Extracted JSON string or original string if no match found.
         """
+        # First try to find complete tool calls
         function_calls = self.tool_call_complete_regex.findall(function_call_str)
-        if len(function_calls)
-            return
-
+        if len(function_calls) > 0:
+            return function_calls[-1]
+
+        # If no complete tool calls found, try to extract from incomplete tool calls
+        # Handle cases like <tool_call><tool_call>_city
+        if self.tool_call_start_token in function_call_str:
+            # Extract content between the last tool_call start token and end of string
+            last_start = function_call_str.rfind(self.tool_call_start_token)
+            potential_json = function_call_str[
+                last_start + len(self.tool_call_start_token) :
+            ]
+            # Remove any trailing tool_call end tokens
+            if self.tool_call_end_token in potential_json:
+                potential_json = potential_json.split(self.tool_call_end_token)[0]
+            # Clean up any extra whitespace
+            potential_json = potential_json.strip()
+            if potential_json:
+                return potential_json
+
+        return function_call_str

     def _parse_json_function_call_stream(
         self,

@@ -229,7 +247,14 @@ class QwenToolParser(ToolParser):
             try:
                 parsed_json = self._parse_json_function_call(function_call)
                 res = json.loads(parsed_json, strict=False)
-
+                # Validate that we have the required fields
+                if "name" in res and "arguments" in res:
+                    results.append((None, res["name"], res["arguments"]))
+                else:
+                    logger.warning(
+                        "Invalid tool call format, missing required fields: %s", res
+                    )
+                    results.append((function_call, None, None))
             except Exception as e:
                 logger.error(
                     "Can't parse single qwen tool call output: %s. Error: %s",
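A hedged sketch of the inputs the extended extraction is meant to cover; the example strings are invented, and the token/regex attributes referenced (tool_call_start_token, tool_call_end_token, tool_call_complete_regex) come from the surrounding class.

# Hypothetical model outputs, for illustration only.
complete = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
truncated = '<tool_call>{"name": "get_weather", "arguments": {"city": "Par'

# With the change above, _parse_json_function_call:
#  - still returns the JSON body of the last complete <tool_call> block, and
#  - for output that opened but never closed a <tool_call>, returns the partial
#    JSON after the last start token instead of the raw string, so the later
#    json.loads + "name"/"arguments" validation decides whether to emit it.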
xinference/model/llm/transformers/chatglm.py

@@ -472,6 +472,9 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                 r.prompt = self._process_messages(
                     r.prompt, tools=tools, tool_choice=tool_choice
                 )
+                assert isinstance(
+                    r.prompt, list
+                ), "r.prompt must be a list after processing"
                 r.full_prompt = self.get_full_context(
                     r.prompt,
                     self.model_family.chat_template,  # type: ignore
xinference/model/llm/transformers/core.py

@@ -48,6 +48,7 @@ from ..utils import (
 )
 from .utils import (
     _get_pad_param,
+    convert_to_cache_cls,
     get_context_length,
     get_max_src_len,
     pad_prefill_tokens,

@@ -573,6 +574,7 @@ class PytorchModel(LLM):
                 ]
             )
             data.append(x)
+
         return torch.stack(data).to(self._device)

     def build_prefill_position_ids(

@@ -713,30 +715,105 @@ class PytorchModel(LLM):
         from torch.nn.functional import pad
         from transformers import DynamicCache

+        # Handle case where past_cache is None
+        if past_cache is None:
+            return new_cache
+
+        # Convert both caches to DynamicCache if not already
+        if not isinstance(past_cache, DynamicCache):
+            past_cache = convert_to_cache_cls(past_cache)
+        if not isinstance(new_cache, DynamicCache):
+            new_cache = convert_to_cache_cls(new_cache)
+
         _, seq_len_idx = self.get_batch_size_and_seq_len_indexes_from_kv()
-
-
+
+        # Handle empty caches
+        if len(past_cache) == 0:
+            return new_cache
+        if len(new_cache) == 0:
+            return past_cache
+
+        # Get first layer seq_len safely
+        past_first = past_cache[0] if len(past_cache) > 0 else (None, None)
+        new_first = new_cache[0] if len(new_cache) > 0 else (None, None)
+
+        if past_first[0] is None or past_first[1] is None:
+            return new_cache
+        if new_first[0] is None or new_first[1] is None:
+            return past_cache
+
+        past_seq_len = past_first[0].shape[seq_len_idx]
+        new_seq_len = new_first[0].shape[seq_len_idx]
+
+        # Pad the shorter cache
         if past_seq_len != new_seq_len:
-
-
+            if past_seq_len > new_seq_len:
+                padding_target = new_cache
+                padding_len = past_seq_len - new_seq_len
+            else:
+                padding_target = past_cache
+                padding_len = new_seq_len - past_seq_len
+
             pad_param = _get_pad_param(seq_len_idx, padding_len)
             for idx in range(len(padding_target)):
                 k = padding_target.key_cache[idx]
                 v = padding_target.value_cache[idx]
-
-
-
-                padding_target.value_cache[idx] = _v
+                if k is not None and v is not None:
+                    padding_target.key_cache[idx] = pad(k, pad_param)
+                    padding_target.value_cache[idx] = pad(v, pad_param)

+        # Merge caches
         ret_kv = DynamicCache()
-
-
-
-
-
-
-
+        max_layers = max(len(past_cache), len(new_cache))
+
+        for idx in range(max_layers):
+            past_k = past_cache.key_cache[idx] if idx < len(past_cache) else None
+            past_v = past_cache.value_cache[idx] if idx < len(past_cache) else None
+            new_k = new_cache.key_cache[idx] if idx < len(new_cache) else None
+            new_v = new_cache.value_cache[idx] if idx < len(new_cache) else None
+
+            if past_k is not None and new_k is not None:
+                # Both layers exist - validate tensor dimensions before concatenation
+                if past_k.dim() != new_k.dim():
+                    logger.error(
+                        f"KV cache tensor dimension mismatch at layer {idx}: "
+                        f"past_k.dim()={past_k.dim()}, new_k.dim()={new_k.dim()}"
+                    )
+                    # Use the cache with higher batch size
+                    if past_k.shape[0] >= new_k.shape[0]:
+                        ret_kv.update(past_k, past_v, idx)
+                    else:
+                        ret_kv.update(new_k, new_v, idx)
+                    continue
+
+                if past_k.shape[1:] == new_k.shape[1:]:
+                    # Shapes are compatible, concatenate along batch dimension
+                    ret_kv.update(
+                        torch.cat((new_k, past_k), 0).contiguous(),
+                        torch.cat((new_v, past_v), 0).contiguous(),
+                        idx,
+                    )
+                else:
+                    # Detailed logging for shape mismatch
+                    logger.warning(
+                        f"KV cache shape mismatch at layer {idx}: "
+                        f"past_k.shape={past_k.shape}, new_k.shape={new_k.shape}. "
+                        f"This may be due to inconsistent batch sizes in continuous batching."
+                    )
+
+                    # Choose the cache with larger batch size to preserve more data
+                    if past_k.shape[0] >= new_k.shape[0]:
+                        ret_kv.update(past_k, past_v, idx)
+                    else:
+                        ret_kv.update(new_k, new_v, idx)
+            elif past_k is not None:
+                ret_kv.update(past_k, past_v, idx)
+            elif new_k is not None:
+                ret_kv.update(new_k, new_v, idx)
+            else:
+                # both None, fill with None
+                ret_kv.update(None, None, idx)
+
         return ret_kv

     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
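A self-contained sketch of the pad-then-concatenate idea the rewritten merge_kv_cache implements, using transformers' DynamicCache directly. The (batch, heads, seq_len, head_dim) layout, the left-padding choice, and the literal pad parameter are assumptions standing in for get_batch_size_and_seq_len_indexes_from_kv() and _get_pad_param(); it is an illustration, not xinference code.

# Standalone illustration of padding two KV caches to a common seq_len and
# merging them along the batch dimension.
import torch
from torch.nn.functional import pad
from transformers import DynamicCache

past, new = DynamicCache(), DynamicCache()
past.update(torch.zeros(2, 4, 7, 16), torch.zeros(2, 4, 7, 16), 0)  # running decode batch
new.update(torch.zeros(1, 4, 5, 16), torch.zeros(1, 4, 5, 16), 0)   # freshly prefilled request

# Left-pad the shorter cache along seq_len (dim 2) so both caches line up.
diff = past.key_cache[0].shape[2] - new.key_cache[0].shape[2]
pad_param = (0, 0, diff, 0)  # stands in for _get_pad_param(seq_len_idx, diff)
padded_k = pad(new.key_cache[0], pad_param)
padded_v = pad(new.value_cache[0], pad_param)

# Concatenate along the batch dimension, new requests first, into a merged cache.
merged = DynamicCache()
merged.update(
    torch.cat((padded_k, past.key_cache[0]), 0).contiguous(),
    torch.cat((padded_v, past.value_cache[0]), 0).contiguous(),
    0,
)
assert merged.key_cache[0].shape == (3, 4, 7, 16)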
xinference/model/llm/transformers/multimodal/minicpmv45.py (new file)

@@ -0,0 +1,340 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, Iterator, List, Optional, Tuple
+
+import torch
+from PIL import Image
+
+from .....core.model import register_batching_multimodal_models
+from .....model.utils import select_device
+from .....types import PytorchModelConfig
+from ....scheduler.request import InferenceRequest
+from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
+from ...utils import _decode_image, parse_messages
+from ..core import register_non_default_model
+from .core import PytorchMultiModalModel
+
+logger = logging.getLogger(__name__)
+
+
+@register_batching_multimodal_models("MiniCPM-V-4.5")
+@register_transformer
+@register_non_default_model("MiniCPM-V-4.5")
+class MiniCPMV45Model(PytorchMultiModalModel):
+    @classmethod
+    def match_json(
+        cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        family = model_family.model_family or model_family.model_name
+        if "MiniCPM-V-4.5".lower() in family.lower():
+            return True
+        return False
+
+    def _sanitize_model_config(
+        self, pytorch_model_config: Optional[PytorchModelConfig]
+    ) -> PytorchModelConfig:
+        pytorch_model_config = super()._sanitize_model_config(pytorch_model_config)
+        assert pytorch_model_config is not None
+        # Configure pixel parameters for MiniCPM-V-4.5
+        pytorch_model_config.setdefault("min_pixels", 256 * 28 * 28)
+        pytorch_model_config.setdefault("max_pixels", 1280 * 28 * 28)
+        return pytorch_model_config
+
+    def decide_device(self):
+        device = self._pytorch_model_config.get("device", "auto")
+        self._device = select_device(device)
+        self._device = (
+            "auto"
+            if self._device == "cuda" and self.quantization is None
+            else self._device
+        )
+
+    def load_processor(self):
+        from transformers import AutoProcessor, AutoTokenizer
+
+        min_pixels = self._pytorch_model_config.get("min_pixels")
+        max_pixels = self._pytorch_model_config.get("max_pixels")
+        self._processor = AutoProcessor.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+
+    def load_multimodal_model(self):
+        from transformers import AutoModel
+        from transformers.generation import GenerationConfig
+
+        if "int4" in self.model_path:
+            model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
+        else:
+            kwargs = self.apply_bnb_quantization()
+            model = AutoModel.from_pretrained(
+                self.model_path,
+                trust_remote_code=True,
+                torch_dtype=torch.float16,
+                device_map=self._device,
+                **kwargs,
+            )
+        self._model = model.eval()
+        # Specify hyperparameters for generation
+        self._model.generation_config = GenerationConfig.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+        )
+        self._device = self._model.device
+
+    def _message_content_to_chat(self, content):
+        MAX_NUM_FRAMES = 64
+
+        def encode_video(video_path):
+            from decord import VideoReader, cpu
+
+            def uniform_sample(l, n):
+                gap = len(l) / n
+                idxs = [int(i * gap + gap / 2) for i in range(n)]
+                return [l[i] for i in idxs]
+
+            vr = VideoReader(video_path, ctx=cpu(0))
+            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+            frame_idx = [i for i in range(0, len(vr), sample_fps)]
+            if len(frame_idx) > MAX_NUM_FRAMES:
+                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+            frames = vr.get_batch(frame_idx).asnumpy()
+            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+            logger.info(
+                f"Num frames: {len(frames)} when decoding video for {self.model_uid}"
+            )
+            return frames
+
+        def _load_video(_url):
+            frames = None
+            if _url.startswith("data:"):
+                raise RuntimeError("Only video url format is supported")
+            else:
+                frames = encode_video(_url)
+            return frames
+
+        if not isinstance(content, str):
+            texts = []
+            image_urls = []
+            video_urls = []
+            for c in content:
+                c_type = c.get("type")
+                if c_type == "text":
+                    texts.append(c["text"])
+                elif c_type == "image_url":
+                    image_urls.append(c["image_url"]["url"])
+                elif c_type == "video_url":
+                    video_urls.append(c["video_url"]["url"])
+            image_futures = []
+            with ThreadPoolExecutor() as executor:
+                for image_url in image_urls:
+                    fut = executor.submit(_decode_image, image_url)
+                    image_futures.append(fut)
+            images = [fut.result() for fut in image_futures]
+            frames = []
+            if len(video_urls) > 1:
+                raise RuntimeError("Only one video per message is supported")
+            for v in video_urls:
+                frames = _load_video(v)
+            text = " ".join(texts)
+            return text, images, frames
+        return content, [], []
+
+    def _convert_to_specific_style(self, messages: List[Dict]) -> Tuple:
+        video_existed = False
+        prompt, _, chat_history = parse_messages(messages)
+
+        content, images_chat, video_frames = self._message_content_to_chat(prompt)
+        if len(video_frames) > 0:
+            video_existed = True
+            images_chat = video_frames
+
+        msgs = []
+        query_to_response: List[Dict] = []
+        for h in chat_history or []:
+            images_history = []
+            role = h["role"]
+            content_h, images_tmp, video_frames_h = self._message_content_to_chat(
+                h["content"]
+            )
+            if images_tmp != []:
+                images_history = images_tmp
+            if len(video_frames_h) > 0:
+                video_existed = True
+                images_history = video_frames_h
+            if len(query_to_response) == 0 and role == "user":
+                query_to_response.append(
+                    {"role": "user", "content": images_history + [content_h]}
+                )
+            if len(query_to_response) == 1 and role == "assistant":
+                query_to_response.append(
+                    {"role": "assistant", "content": images_history + [content_h]}
+                )
+            if len(query_to_response) == 2:
+                msgs.extend(query_to_response)
+                query_to_response = []
+        msgs.append({"role": "user", "content": images_chat + [content]})
+        return msgs, video_existed
+
+    def build_inputs_from_messages(
+        self,
+        messages: List[Dict],
+        generate_config: Dict,
+    ):
+        msgs, video_existed = self._convert_to_specific_style(messages)
+        # Set decode params for video
+        params = {}
+        if video_existed:
+            params = {"use_image_id": False, "max_slice_nums": 1}
+        return dict(msgs=msgs, image=None, **params)
+
+    def build_generate_kwargs(
+        self,
+        generate_config: Dict,
+    ) -> Dict[str, Any]:
+        return dict(**generate_config)
+
+    def build_streaming_iter(
+        self,
+        messages: List[Dict],
+        generate_config: Dict,
+    ) -> Tuple[Iterator, int]:
+        inputs = self.build_inputs_from_messages(messages, generate_config)
+        config = self.build_generate_kwargs(generate_config)
+        chat_iter = self._model.chat(
+            **inputs, **config, tokenizer=self._tokenizer, sampling=True
+        )
+
+        return chat_iter, -1
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        Refer to MiniCPM-V-4.5 documentation for generation parameters
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.7
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.8
+        top_k = raw_config.get("top_k", None)
+        if top_k is None:
+            raw_config["top_k"] = 100
+        repetition_penalty = raw_config.get("repetition_penalty", None)
+        if repetition_penalty is None:
+            raw_config["repetition_penalty"] = 1.05
+        return raw_config
+
+    def _handle_input_ids_and_images(self, msgs: List[Dict]) -> Dict:
+        """
+        Handle input IDs and images for MiniCPM-V-4.5
+        Based on MiniCPM-V-2.6 implementation with adaptations for 4.5
+        """
+        from copy import deepcopy
+
+        copy_msgs = deepcopy(msgs)
+
+        images = []
+        for i, msg in enumerate(copy_msgs):
+            role = msg["role"]
+            content = msg["content"]
+            assert role in ["user", "assistant"]
+            if i == 0:
+                assert role == "user", "The role of first msg should be user"
+            if isinstance(content, str):
+                content = [content]
+            cur_msgs = []
+            for c in content:
+                if isinstance(c, Image.Image):
+                    images.append(c)
+                    cur_msgs.append("(<image>./</image>)")
+                elif isinstance(c, str):
+                    cur_msgs.append(c)
+            msg["content"] = "\n".join(cur_msgs)
+
+        return {
+            "prompt": self._processor.tokenizer.apply_chat_template(
+                copy_msgs, tokenize=False, add_generation_prompt=True
+            ),
+            "input_image": images,
+        }
+
+    def _get_full_prompt(self, messages: List[Dict], tools, generate_config: dict):  # type: ignore
+        msgs, video_existed = self._convert_to_specific_style(messages)
+        if video_existed:
+            raise RuntimeError(
+                f"Continuous batching does not support video inputs for this model: {self.model_uid}"
+            )
+        return self._handle_input_ids_and_images(msgs)
+
+    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+        prompts_lists = [x["prompt"] for x in prompts]
+        input_images_lists = [x["input_image"] for x in prompts]
+        inputs = self._processor(
+            prompts_lists,
+            input_images_lists,
+            max_slice_nums=None,
+            use_image_id=None,
+            return_tensors="pt",
+            max_length=8192,
+        ).to(self._model.device)
+        inputs.pop("image_sizes")
+
+        masked_input_ids = inputs["input_ids"] * inputs["attention_mask"]
+        for i in range(masked_input_ids.shape[0]):
+            non_zero_values = masked_input_ids[i][masked_input_ids[i] != 0].tolist()
+            req_list[i].prompt_tokens = non_zero_values
+            req_list[i].extra_kwargs["attention_mask_seq_len"] = len(non_zero_values)
+            req_list[i].padding_len = masked_input_ids.shape[1] - len(non_zero_values)
+
+        model_inputs = {
+            "input_ids": inputs["input_ids"],
+            "image_bound": inputs["image_bound"],
+            "pixel_values": inputs["pixel_values"],
+            "tgt_sizes": inputs["tgt_sizes"],
+        }
+        model_inputs["inputs_embeds"], _ = self._model.get_vllm_embedding(model_inputs)
+
+        return {
+            "inputs_embeds": model_inputs["inputs_embeds"],
+            "attention_mask": inputs["attention_mask"],
+        }
+
+    def build_decode_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        return None
+
+    def batch_inference(self, req_list: List[InferenceRequest]):
+        """
+        This method is rewritten
+        because the specific inference process is performed by `self._model.llm`,
+        not `self._model` itself
+        """
+        from ..utils import batch_inference_one_step
+
+        self.prepare_batch_inference(req_list)
+        batch_inference_one_step(
+            self, req_list, self.model_uid, self._model.llm, self._tokenizer
+        )
+        self.handle_batch_inference_results(req_list)
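Since the new class registers MiniCPM-V-4.5 as a multimodal transformers model, a hedged client-side sketch of how it could be exercised through Xinference's OpenAI-compatible endpoint; the host, port, model UID, and image URL are placeholders, not values from the diff.

# Hypothetical usage: chat with an image via the OpenAI-compatible API of a
# running Xinference server that has launched MiniCPM-V-4.5.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-needed")
response = client.chat.completions.create(
    model="MiniCPM-V-4.5",  # the launched model's UID
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
)
print(response.choices[0].message.content)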
xinference/model/llm/transformers/utils.py

@@ -281,7 +281,10 @@ def _batch_inference_one_step_internal(
             r.append_new_token(token)

     if decode_reqs:
+        # Ensure all decode requests have the same kv_cache reference
+        # This prevents batch size mismatches during merging
         decode_kv = decode_reqs[0].kv_cache
+
         # prefill and decode kv cache need to be merged at `batch_size` and `seq_len` dimensions.
         merged_kv_cache = xinf_model_obj.merge_kv_cache(decode_kv, past_key_values)
         for r in valid_req_list: