xinference 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -538,7 +538,10 @@ def _generate_model_file_names(
     )
     need_merge = False

-    if llm_spec.quantization_parts is None:
+    if (
+        llm_spec.quantization_parts is None
+        or quantization not in llm_spec.quantization_parts
+    ):
         file_names.append(final_file_name)
     elif quantization is not None and quantization in llm_spec.quantization_parts:
         parts = llm_spec.quantization_parts[quantization]
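For orientation, a minimal sketch of how this branch resolves download file names; `resolve_file_names` and its arguments are simplified stand-ins for the real LLMSpecV1 fields, and the example values are hypothetical:

    # Sketch only, not the actual xinference implementation.
    def resolve_file_names(quantization_parts, quantization, final_file_name, split_template):
        file_names = []
        if quantization_parts is None or quantization not in quantization_parts:
            # No multi-part entry registered for this quantization: one file.
            file_names.append(final_file_name)
        elif quantization is not None and quantization in quantization_parts:
            # Multi-part quantization: one file name per registered part.
            for part in quantization_parts[quantization]:
                file_names.append(split_template.format(quantization=quantization, part=part))
        return file_names

    # Hypothetical values shaped like the GGUF entries later in this diff:
    resolve_file_names(
        {"q8_0": ["00001-of-00002", "00002-of-00002"]},
        "q4_k_m",
        "qwen2.5-7b-instruct-q4_k_m.gguf",
        "qwen2.5-7b-instruct-{quantization}-{part}.gguf",
    )
    # -> ["qwen2.5-7b-instruct-q4_k_m.gguf"]  (falls back to the single file)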
@@ -4769,10 +4769,11 @@
         "model_format":"mlx",
         "model_size_in_billions":2,
         "quantizations":[
+          "4bit",
           "8bit"
         ],
         "model_hub": "modelscope",
-        "model_id":"okwinds/Qwen2-VL-2B-Instruct-MLX-8bit",
+        "model_id":"mlx-community/Qwen2-VL-2B-Instruct-{quantization}",
         "model_revision":"master"
       },
       {
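When a spec is declared this way, the `{quantization}` placeholder in `model_id` is filled with the quantization chosen at launch time before the repository is looked up on the hub. A hedged illustration using plain `str.format` (the real resolution goes through xinference's model-family machinery):

    # Illustration only.
    model_id_template = "mlx-community/Qwen2-VL-2B-Instruct-{quantization}"
    for q in ["4bit", "8bit"]:
        print(model_id_template.format(quantization=q))
    # mlx-community/Qwen2-VL-2B-Instruct-4bit
    # mlx-community/Qwen2-VL-2B-Instruct-8bit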
@@ -4825,6 +4826,97 @@
       "<|endoftext|>"
     ]
   },
+  {
+    "version":1,
+    "context_length":128000,
+    "model_name":"qwen2.5-vl-instruct",
+    "model_lang":[
+      "en",
+      "zh"
+    ],
+    "model_ability":[
+      "chat",
+      "vision"
+    ],
+    "model_description":"Qwen2.5-VL: Qwen2.5-VL is the latest version of the vision language models in the Qwen model familities.",
+    "model_specs":[
+      {
+        "model_format":"pytorch",
+        "model_size_in_billions":3,
+        "quantizations":[
+          "none"
+        ],
+        "model_hub": "modelscope",
+        "model_id":"qwen/Qwen2.5-VL-3B-Instruct"
+      },
+      {
+        "model_format":"pytorch",
+        "model_size_in_billions":7,
+        "quantizations":[
+          "none"
+        ],
+        "model_hub": "modelscope",
+        "model_id":"qwen/Qwen2.5-VL-7B-Instruct"
+      },
+      {
+        "model_format":"pytorch",
+        "model_size_in_billions":72,
+        "quantizations":[
+          "none"
+        ],
+        "model_hub": "modelscope",
+        "model_id":"qwen/Qwen2.5-VL-72B-Instruct"
+      },
+      {
+        "model_format":"mlx",
+        "model_size_in_billions":3,
+        "quantizations":[
+          "3bit",
+          "4bit",
+          "6bit",
+          "8bit",
+          "bf16"
+        ],
+        "model_hub": "modelscope",
+        "model_id":"mlx-community/Qwen2.5-VL-3B-Instruct-{quantization}"
+      },
+      {
+        "model_format":"mlx",
+        "model_size_in_billions":7,
+        "quantizations":[
+          "3bit",
+          "4bit",
+          "6bit",
+          "8bit",
+          "bf16"
+        ],
+        "model_hub": "modelscope",
+        "model_id":"mlx-community/Qwen2.5-VL-7B-Instruct-{quantization}"
+      },
+      {
+        "model_format":"mlx",
+        "model_size_in_billions":72,
+        "quantizations":[
+          "3bit",
+          "4bit",
+          "6bit",
+          "8bit",
+          "bf16"
+        ],
+        "model_hub": "modelscope",
+        "model_id":"mlx-community/Qwen2.5-VL-72B-Instruct-{quantization}"
+      }
+    ],
+    "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
+    "stop_token_ids": [
+      151645,
+      151643
+    ],
+    "stop": [
+      "<|im_end|>",
+      "<|endoftext|>"
+    ]
+  },
   {
     "version": 1,
     "context_length": 32768,
@@ -5558,7 +5650,7 @@
           "q8_0"
         ],
         "model_id": "qwen/Qwen2.5-7B-Instruct-GGUF",
-        "model_file_name_template": "qwen2_5-7b-instruct-{quantization}.gguf",
+        "model_file_name_template": "qwen2.5-7b-instruct-{quantization}.gguf",
         "model_hub": "modelscope",
         "model_file_name_split_template": "qwen2.5-7b-instruct-{quantization}-{part}.gguf",
         "quantization_parts": {
@@ -6473,6 +6565,19 @@
         "model_file_name_template": "DeepSeek-R1-Distill-Qwen-1.5B-{quantization}.gguf",
         "model_hub": "modelscope"
       },
+      {
+        "model_format": "mlx",
+        "model_size_in_billions": "1_5",
+        "quantizations": [
+          "3bit",
+          "4bit",
+          "6bit",
+          "8bit",
+          "bf16"
+        ],
+        "model_id": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-{quantization}",
+        "model_hub": "modelscope"
+      },
       {
         "model_format": "pytorch",
         "model_size_in_billions": 7,
@@ -6621,6 +6726,125 @@
       "<|end▁of▁sentence|>"
     ]
   },
+  {
+    "version": 1,
+    "context_length": 131072,
+    "model_name": "deepseek-r1-distill-llama",
+    "model_lang": [
+      "en",
+      "zh"
+    ],
+    "model_ability": [
+      "chat"
+    ],
+    "model_description": "deepseek-r1-distill-llama is distilled from DeepSeek-R1 based on Llama",
+    "model_specs": [
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 8,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "ggufv2",
+        "model_size_in_billions": 8,
+        "quantizations": [
+          "Q2_K",
+          "Q2_K_L",
+          "Q3_K_M",
+          "Q4_K_M",
+          "Q5_K_M",
+          "Q6_K",
+          "Q8_0",
+          "F16"
+        ],
+        "model_id": "unsloth/DeepSeek-R1-Distill-Llama-8B-GGUF",
+        "model_file_name_template": "DeepSeek-R1-Distill-Llama-8B-{quantization}.gguf",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "mlx",
+        "model_size_in_billions": 8,
+        "quantizations": [
+          "3bit",
+          "4bit",
+          "6bit",
+          "8bit",
+          "bf16"
+        ],
+        "model_id": "okwinds/DeepSeek-R1-Distill-Llama-8B-MLX-{quantization}",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 70,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "ggufv2",
+        "model_size_in_billions": 70,
+        "quantizations": [
+          "Q2_K",
+          "Q2_K_L",
+          "Q3_K_M",
+          "Q4_K_M",
+          "Q5_K_M",
+          "Q6_K",
+          "Q8_0",
+          "F16"
+        ],
+        "quantization_parts": {
+          "Q6_K": [
+            "00001-of-00002",
+            "00002-of-00002"
+          ],
+          "Q8_0": [
+            "00001-of-00002",
+            "00002-of-00002"
+          ],
+          "F16": [
+            "00001-of-00003",
+            "00002-of-00003",
+            "00003-of-00003"
+          ]
+        },
+        "model_id": "unsloth/DeepSeek-R1-Distill-Llama-70B-GGUF",
+        "model_file_name_template": "DeepSeek-R1-Distill-Qwen-7B-{quantization}.gguf",
+        "model_file_name_split_template": "DeepSeek-R1-Distill-Llama-70B-{quantization}/DeepSeek-R1-Distill-Llama-70B-{quantization}-{part}.gguf",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "mlx",
+        "model_size_in_billions": 70,
+        "quantizations": [
+          "3bit",
+          "4bit",
+          "6bit",
+          "8bit"
+        ],
+        "model_id": "okwinds/DeepSeek-R1-Distill-Llama-70B-MLX-{quantization}",
+        "model_hub": "modelscope"
+      }
+    ],
+    "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}",
+    "stop_token_ids": [
+      151643
+    ],
+    "stop": [
+      "<|end▁of▁sentence|>"
+    ]
+  },
   {
     "version": 1,
     "context_length": 8192,
@@ -6911,7 +7135,7 @@
       "<|endoftext|>"
     ]
   },
-   {
+  {
     "version": 1,
     "context_length": 32768,
     "model_name": "marco-o1",
@@ -7009,5 +7233,85 @@
       "<|user|>",
       "<|observation|>"
     ]
+  },
+  {
+    "version": 1,
+    "context_length": 32768,
+    "model_name": "internlm3-instruct",
+    "model_lang": [
+      "en",
+      "zh"
+    ],
+    "model_ability": [
+      "chat",
+      "tools"
+    ],
+    "model_description": "InternLM3 has open-sourced an 8-billion parameter instruction model, InternLM3-8B-Instruct, designed for general-purpose usage and advanced reasoning.",
+    "model_specs": [
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 8,
+        "quantizations": [
+          "4-bit",
+          "8-bit",
+          "none"
+        ],
+        "model_id": "Shanghai_AI_Laboratory/internlm3-8b-instruct",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "gptq",
+        "model_size_in_billions": 8,
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "Shanghai_AI_Laboratory/internlm3-8b-instruct-gptq-int4",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "awq",
+        "model_size_in_billions": 8,
+        "quantizations": [
+          "Int4"
+        ],
+        "model_id": "Shanghai_AI_Laboratory/internlm3-8b-instruct-awq",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format": "ggufv2",
+        "model_size_in_billions": 8,
+        "quantizations": [
+          "q2_k",
+          "q3_k_m",
+          "q4_0",
+          "q4_k_m",
+          "q5_0",
+          "q5_k_m",
+          "q6_k",
+          "q8_0"
+        ],
+        "model_id": "Shanghai_AI_Laboratory/internlm3-8b-instruct-gguf",
+        "model_file_name_template": "internlm3-8b-instruct-{quantization}.gguf",
+        "model_hub": "modelscope"
+      },
+      {
+        "model_format":"mlx",
+        "model_size_in_billions":8,
+        "quantizations":[
+          "4bit"
+        ],
+        "model_hub": "modelscope",
+        "model_id":"mlx-community/internlm3-8b-instruct-{quantization}"
+      }
+    ],
+    "chat_template": "{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+    "stop_token_ids": [
+      2,
+      128131
+    ],
+    "stop": [
+      "</s>",
+      "<|im_end|>"
+    ]
+  }
 ]
@@ -31,7 +31,12 @@ from ....types import (
 )
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin, generate_completion_chunk
+from ..utils import (
+    DEEPSEEK_TOOL_CALL_FAMILY,
+    QWEN_TOOL_CALL_FAMILY,
+    ChatModelMixin,
+    generate_completion_chunk,
+)

 logger = logging.getLogger(__name__)

@@ -424,8 +429,11 @@ class MLXChatModel(MLXModel, ChatModelMixin):
         model_family = self.model_family.model_family or self.model_family.model_name
         tools = generate_config.pop("tools", []) if generate_config else None
         full_context_kwargs = {}
-        if tools and model_family in QWEN_TOOL_CALL_FAMILY:
-            full_context_kwargs["tools"] = tools
+        if tools:
+            if model_family in QWEN_TOOL_CALL_FAMILY:
+                full_context_kwargs["tools"] = tools
+            elif model_family in DEEPSEEK_TOOL_CALL_FAMILY:
+                self._tools_to_messages_for_deepseek(messages, tools)
         assert self.model_family.chat_template is not None
         full_prompt = self.get_full_context(
             messages, self.model_family.chat_template, **full_context_kwargs
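Tool definitions now reach the prompt by two different routes depending on the model family. A rough sketch of the dispatch, where `fold_tools_into_messages` is a hypothetical stand-in for the `_tools_to_messages_for_deepseek` helper:

    # Sketch only; the family constants come from ..utils in the real code.
    def apply_tools(model_family, messages, tools, full_context_kwargs):
        if not tools:
            return
        if model_family in QWEN_TOOL_CALL_FAMILY:
            # Qwen chat templates accept the tool list as a template variable.
            full_context_kwargs["tools"] = tools
        elif model_family in DEEPSEEK_TOOL_CALL_FAMILY:
            # DeepSeek templates expect tool specs embedded in the conversation itself.
            fold_tools_into_messages(messages, tools)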
@@ -39,7 +39,12 @@ from ....types import (
 from ...utils import select_device
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import LLAMA3_TOOL_CALL_FAMILY, QWEN_TOOL_CALL_FAMILY, ChatModelMixin
+from ..utils import (
+    DEEPSEEK_TOOL_CALL_FAMILY,
+    LLAMA3_TOOL_CALL_FAMILY,
+    QWEN_TOOL_CALL_FAMILY,
+    ChatModelMixin,
+)
 from .utils import get_context_length, get_max_src_len, pad_prefill_tokens

 logger = logging.getLogger(__name__)
@@ -62,6 +67,7 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "MiniCPM-V-2.6",
     "glm-4v",
     "qwen2-vl-instruct",
+    "qwen2.5-vl-instruct",
     "qwen2-audio",
     "qwen2-audio-instruct",
     "deepseek-v2",
@@ -681,6 +687,8 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             or model_family in LLAMA3_TOOL_CALL_FAMILY
         ):
             full_context_kwargs["tools"] = tools
+        elif tools and model_family in DEEPSEEK_TOOL_CALL_FAMILY:
+            self._tools_to_messages_for_deepseek(messages, tools)
         assert self.model_family.chat_template is not None
         full_prompt = self.get_full_context(
             messages,
@@ -55,9 +55,9 @@ class Qwen2AudioChatModel(PytorchChatModel):

         device = self._pytorch_model_config.get("device", "auto")
         device = select_device(device)
-        self._device = device
         # for multiple GPU, set back to auto to make multiple devices work
         device = "auto" if device == "cuda" else device
+        self._device = device

         self._processor = AutoProcessor.from_pretrained(
             self.model_path,
@@ -105,6 +105,8 @@ class Qwen2AudioChatModel(PytorchChatModel):
         inputs = self._processor(
             text=text, audios=audios, return_tensors="pt", padding=True
         )
+        # Make sure that the inputs and the model are on the same device.
+        inputs.data = {k: v.to(self._device) for k, v in inputs.data.items()}
         inputs.input_ids = inputs.input_ids.to(self._device)
         generate_config = generate_config if generate_config else {}
         stream = generate_config.get("stream", False) if generate_config else False
@@ -45,9 +45,13 @@ class Qwen2VLChatModel(PytorchChatModel):
     def match(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+            return False
         llm_family = model_family.model_family or model_family.model_name
         if "qwen2-vl-instruct".lower() in llm_family.lower():
             return True
+        if "qwen2.5-vl-instruct".lower() in llm_family.lower():
+            return True
         if "qvq-72b-preview".lower() in llm_family.lower():
             return True
         return False
@@ -55,6 +59,11 @@ class Qwen2VLChatModel(PytorchChatModel):
     def load(self):
         from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

+        try:
+            from transformers import Qwen2_5_VLForConditionalGeneration
+        except ImportError:
+            Qwen2_5_VLForConditionalGeneration = None
+
         device = self._pytorch_model_config.get("device", "auto")
         device = select_device(device)
         self._device = device
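The guarded import above is the usual pattern for classes that only exist in newer transformers releases: fall back to None at import time and raise a clear error only when the class is actually required. A condensed sketch of the same idea:

    # Sketch of the optional-import pattern; mirrors the structure of the hunks above and below.
    try:
        from transformers import Qwen2_5_VLForConditionalGeneration  # newer transformers only
    except ImportError:
        Qwen2_5_VLForConditionalGeneration = None

    def pick_model_cls(llm_family: str):
        from transformers import Qwen2VLForConditionalGeneration

        cls = (
            Qwen2_5_VLForConditionalGeneration
            if "qwen2.5" in llm_family
            else Qwen2VLForConditionalGeneration
        )
        if cls is None:
            raise ImportError("`transformers` version is too old, please upgrade it")
        return cls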
@@ -66,8 +75,16 @@ class Qwen2VLChatModel(PytorchChatModel):
         )
         self._tokenizer = self._processor.tokenizer
         flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
+        llm_family = self.model_family.model_family or self.model_family.model_name
+        model_cls = (
+            Qwen2_5_VLForConditionalGeneration
+            if "qwen2.5" in llm_family
+            else Qwen2VLForConditionalGeneration
+        )
+        if model_cls is None:
+            raise ImportError("`transformers` version is too old, please upgrade it")
         if flash_attn_installed:
-            self._model = Qwen2VLForConditionalGeneration.from_pretrained(
+            self._model = model_cls.from_pretrained(
                 self.model_path,
                 torch_dtype="bfloat16",
                 device_map=device,
@@ -76,14 +93,14 @@ class Qwen2VLChatModel(PytorchChatModel):
             ).eval()
         elif is_npu_available():
             # Ascend do not support bf16
-            self._model = Qwen2VLForConditionalGeneration.from_pretrained(
+            self._model = model_cls.from_pretrained(
                 self.model_path,
                 device_map="auto",
                 trust_remote_code=True,
                 torch_dtype="float16",
             ).eval()
         else:
-            self._model = Qwen2VLForConditionalGeneration.from_pretrained(
+            self._model = model_cls.from_pretrained(
                 self.model_path, device_map=device, trust_remote_code=True
             ).eval()

@@ -193,16 +193,14 @@ def _get_pad_param(seq_len_idx: int, pad_len: int) -> Tuple:

 def _merge_kv_cache(
     xinf_model_obj: "PytorchModel",
-    past_kv: Tuple[Tuple[torch.Tensor]],
-    new_kv: Tuple[Tuple[torch.Tensor]],
-):
+    past_cache: DynamicCache,
+    new_cache: DynamicCache,
+) -> DynamicCache:
     from torch.nn.functional import pad

     _, seq_len_idx = xinf_model_obj.get_batch_size_and_seq_len_indexes_from_kv()
-    past_cache = DynamicCache.from_legacy_cache(past_kv)
-    new_cache = DynamicCache.from_legacy_cache(new_kv)
-    past_seq_len = past_kv[0][0].shape[seq_len_idx]
-    new_seq_len = new_kv[0][0].shape[seq_len_idx]
+    past_seq_len = past_cache[0][0].shape[seq_len_idx]
+    new_seq_len = new_cache[0][0].shape[seq_len_idx]
     if past_seq_len != new_seq_len:
         padding_target = new_cache if past_seq_len > new_seq_len else past_cache
         padding_len = abs(past_seq_len - new_seq_len)
@@ -219,8 +217,12 @@ def _merge_kv_cache(
     for idx in range(len(past_cache)):
         k1, k2 = new_cache.key_cache[idx], past_cache.key_cache[idx]
         v1, v2 = new_cache.value_cache[idx], past_cache.value_cache[idx]
-        ret_kv.update(torch.cat((k1, k2), 0), torch.cat((v1, v2), 0), idx)
-    return ret_kv.to_legacy_cache()
+        ret_kv.update(
+            torch.cat((k1, k2), 0).contiguous(),
+            torch.cat((v1, v2), 0).contiguous(),
+            idx,
+        )
+    return ret_kv


 def get_batch_size_and_seq_len_from_kv_cache(kv, xinf_model_obj: "PytorchModel"):
@@ -228,6 +230,15 @@ def get_batch_size_and_seq_len_from_kv_cache(kv, xinf_model_obj: "PytorchModel")
     return kv[0][0].shape[bs_idx], kv[0][0].shape[seq_len_idx] + 1


+def convert_to_cache_cls(cache) -> DynamicCache:
+    """
+    Compatible with some old models
+    """
+    if isinstance(cache, tuple):
+        return DynamicCache.from_legacy_cache(cache)
+    return cache
+
+
 @torch.inference_mode()
 def _batch_inference_one_step_internal(
     xinf_model_obj: "PytorchModel",
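`DynamicCache` here is the container from `transformers.cache_utils`; older model classes still return the legacy tuple-of-tuples KV cache, which is what this shim normalizes before the batched-inference code indexes into `key_cache`/`value_cache`. A small, self-contained illustration of the two representations (tensor shapes are made up):

    # Illustration only; shapes are arbitrary (batch, heads, seq_len, head_dim).
    import torch
    from transformers.cache_utils import DynamicCache

    legacy = tuple(
        (torch.zeros(1, 2, 5, 8), torch.zeros(1, 2, 5, 8))  # (key, value) per layer
        for _ in range(2)
    )
    cache = DynamicCache.from_legacy_cache(legacy)
    assert cache.key_cache[0].shape == (1, 2, 5, 8)
    # convert_to_cache_cls(legacy) and convert_to_cache_cls(cache) both yield a DynamicCache.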
@@ -269,7 +280,7 @@ def _batch_inference_one_step_internal(
         out = model(**prefill_kws, use_cache=True)

         logits = out.logits
-        past_key_values = out.past_key_values
+        past_key_values = convert_to_cache_cls(out.past_key_values)

         for i, r in enumerate(prefill_reqs):
             (
@@ -317,7 +328,7 @@ def _batch_inference_one_step_internal(
         )
         out = model(**inf_kws, use_cache=True, past_key_values=past_key_values)
         logits = out.logits
-        past_key_values = out.past_key_values
+        past_key_values = convert_to_cache_cls(out.past_key_values)

         for i, r in enumerate(valid_req_list):
             (