xinference 1.5.0.post1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (89)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +10 -3
  3. xinference/constants.py +5 -1
  4. xinference/core/supervisor.py +12 -3
  5. xinference/core/utils.py +1 -1
  6. xinference/core/worker.py +2 -2
  7. xinference/deploy/cmdline.py +17 -0
  8. xinference/model/audio/core.py +1 -1
  9. xinference/model/audio/model_spec.json +43 -43
  10. xinference/model/audio/model_spec_modelscope.json +13 -13
  11. xinference/model/llm/__init__.py +3 -5
  12. xinference/model/llm/core.py +14 -0
  13. xinference/model/llm/llama_cpp/core.py +15 -4
  14. xinference/model/llm/llm_family.json +3251 -4304
  15. xinference/model/llm/llm_family.py +62 -6
  16. xinference/model/llm/llm_family_csghub.json +0 -32
  17. xinference/model/llm/llm_family_modelscope.json +1161 -1789
  18. xinference/model/llm/llm_family_openmind_hub.json +19 -325
  19. xinference/model/llm/lmdeploy/core.py +7 -2
  20. xinference/model/llm/mlx/core.py +19 -6
  21. xinference/model/llm/sglang/core.py +25 -10
  22. xinference/model/llm/transformers/chatglm.py +8 -1
  23. xinference/model/llm/transformers/cogagent.py +10 -12
  24. xinference/model/llm/transformers/cogvlm2.py +6 -3
  25. xinference/model/llm/transformers/cogvlm2_video.py +3 -6
  26. xinference/model/llm/transformers/core.py +50 -58
  27. xinference/model/llm/transformers/deepseek_v2.py +4 -2
  28. xinference/model/llm/transformers/deepseek_vl.py +10 -4
  29. xinference/model/llm/transformers/deepseek_vl2.py +9 -4
  30. xinference/model/llm/transformers/gemma3.py +15 -7
  31. xinference/model/llm/transformers/glm4v.py +2 -20
  32. xinference/model/llm/transformers/glm_edge_v.py +3 -20
  33. xinference/model/llm/transformers/intern_vl.py +3 -6
  34. xinference/model/llm/transformers/internlm2.py +1 -1
  35. xinference/model/llm/transformers/minicpmv25.py +4 -2
  36. xinference/model/llm/transformers/minicpmv26.py +5 -3
  37. xinference/model/llm/transformers/omnilmm.py +1 -1
  38. xinference/model/llm/transformers/opt.py +1 -1
  39. xinference/model/llm/transformers/ovis2.py +302 -0
  40. xinference/model/llm/transformers/qwen-omni.py +2 -1
  41. xinference/model/llm/transformers/qwen2_audio.py +3 -1
  42. xinference/model/llm/transformers/qwen2_vl.py +5 -1
  43. xinference/model/llm/transformers/qwen_vl.py +5 -2
  44. xinference/model/llm/utils.py +28 -0
  45. xinference/model/llm/vllm/core.py +73 -9
  46. xinference/model/llm/vllm/distributed_executor.py +8 -7
  47. xinference/model/llm/vllm/xavier/allocator.py +1 -1
  48. xinference/model/llm/vllm/xavier/block_manager.py +1 -1
  49. xinference/model/llm/vllm/xavier/block_tracker.py +3 -3
  50. xinference/model/llm/vllm/xavier/executor.py +1 -1
  51. xinference/model/llm/vllm/xavier/test/test_xavier.py +1 -1
  52. xinference/model/video/diffusers.py +30 -3
  53. xinference/model/video/model_spec.json +46 -0
  54. xinference/model/video/model_spec_modelscope.json +48 -0
  55. xinference/types.py +2 -0
  56. xinference/web/ui/build/asset-manifest.json +6 -6
  57. xinference/web/ui/build/index.html +1 -1
  58. xinference/web/ui/build/static/css/{main.0f6523be.css → main.337afe76.css} +2 -2
  59. xinference/web/ui/build/static/css/main.337afe76.css.map +1 -0
  60. xinference/web/ui/build/static/js/main.91e77b5c.js +3 -0
  61. xinference/web/ui/build/static/js/main.91e77b5c.js.map +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/5e6edb0fb87e3798f142e9abf8dd2dc46bab33a60d31dff525797c0c99887097.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/6087820be1bd5c02c42dff797e7df365448ef35ab26dd5d6bd33e967e05cbfd4.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/6798e126f3bc5f95a4c16a9c2ad52ffe77970c62406d83e20604dfda7ffd2247.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/b617f7d21a95045fc57b26a9373551740f1978a826134cbf705c3a1bf8714a93.json +1 -0
  67. xinference/web/ui/node_modules/.cache/babel-loader/c1506cb142151366074975f30fa1ff9cd6e5e978b62a4b074dfc16fe08d70d75.json +1 -0
  68. xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +1 -0
  69. xinference/web/ui/src/locales/en.json +1 -0
  70. xinference/web/ui/src/locales/zh.json +1 -0
  71. {xinference-1.5.0.post1.dist-info → xinference-1.5.1.dist-info}/METADATA +1 -1
  72. {xinference-1.5.0.post1.dist-info → xinference-1.5.1.dist-info}/RECORD +77 -78
  73. {xinference-1.5.0.post1.dist-info → xinference-1.5.1.dist-info}/WHEEL +1 -1
  74. xinference/model/llm/transformers/compression.py +0 -258
  75. xinference/model/llm/transformers/yi_vl.py +0 -239
  76. xinference/web/ui/build/static/css/main.0f6523be.css.map +0 -1
  77. xinference/web/ui/build/static/js/main.58bd483c.js +0 -3
  78. xinference/web/ui/build/static/js/main.58bd483c.js.map +0 -1
  79. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +0 -1
  81. xinference/web/ui/node_modules/.cache/babel-loader/8f9af2979e45d4648f0cfae108363e58ee421c29a9d4e7329b6f06d9adfd4133.json +0 -1
  82. xinference/web/ui/node_modules/.cache/babel-loader/9c8b1a86e7c65b2b2599a205e30920652d6c2105f926508ef5bcf29a3ef4ce76.json +0 -1
  83. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +0 -1
  84. xinference/web/ui/node_modules/.cache/babel-loader/efe7cd132c27a8f9fd5352a394c491fd5fb0da0348cf9fcbd923164a32365eab.json +0 -1
  85. xinference/web/ui/node_modules/.cache/babel-loader/f199e8173f6409a5802ed44acb95f218388131136504b2e9132129e150c92f9a.json +0 -1
  86. /xinference/web/ui/build/static/js/{main.58bd483c.js.LICENSE.txt → main.91e77b5c.js.LICENSE.txt} +0 -0
  87. {xinference-1.5.0.post1.dist-info → xinference-1.5.1.dist-info}/entry_points.txt +0 -0
  88. {xinference-1.5.0.post1.dist-info → xinference-1.5.1.dist-info}/licenses/LICENSE +0 -0
  89. {xinference-1.5.0.post1.dist-info → xinference-1.5.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/gemma3.py

@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)

 class Gemma3TextChatModel(PytorchChatModel):
     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -56,7 +56,7 @@ class Gemma3ChatModel(PytorchChatModel):
         self._processor = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -85,6 +85,7 @@ class Gemma3ChatModel(PytorchChatModel):
         device = "auto" if device == "cuda" else device
         min_pixels = self._pytorch_model_config.get("min_pixels")
         max_pixels = self._pytorch_model_config.get("max_pixels")
+        kwargs = self.apply_bnb_quantization()
         self._processor = AutoProcessor.from_pretrained(
             self.model_path,
             min_pixels=min_pixels,
@@ -92,9 +93,7 @@ class Gemma3ChatModel(PytorchChatModel):
         )
         self._tokenizer = self._processor.tokenizer
         self._model = Gemma3ForConditionalGeneration.from_pretrained(
-            self.model_path,
-            device_map="auto",
-            torch_dtype="bfloat16",
+            self.model_path, device_map="auto", torch_dtype="bfloat16", **kwargs
         )

     @cache_clean
@@ -128,7 +127,12 @@ class Gemma3ChatModel(PytorchChatModel):
         ).to(self._device)
         input_len = inputs["input_ids"].shape[-1]

-        generation = self._model.generate(**inputs, do_sample=False)
+        generation = self._model.generate(
+            **inputs,
+            do_sample=False,
+            max_new_tokens=config.get("max_tokens", 512),
+            temperature=config.get("temperature", 1),
+        )
         generation = generation[0][input_len:]

         decoded = self._processor.decode(generation, skip_special_tokens=True)
@@ -159,7 +163,11 @@ class Gemma3ChatModel(PytorchChatModel):

         def model_generate():
             try:
-                return self._model.generate(**gen_kwargs)
+                return self._model.generate(
+                    **gen_kwargs,
+                    max_new_tokens=config.get("max_tokens", 512),
+                    temperature=config.get("temperature", 1),
+                )
             except Exception:
                 nonlocal error
                 error = sys.exc_info()
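
Note: the match -> match_json rename above is not specific to gemma3.py; the same one-line change recurs in every transformers backend touched by this release (glm4v, glm_edge_v, intern_vl, internlm2, minicpmv25/26, omnilmm, opt, qwen-omni, qwen2_audio, qwen2_vl, and the new ovis2). The sketch below only illustrates how a classmethod hook of this shape is typically consumed when picking an implementation class; the function name, candidate list, and error message are assumptions for illustration, not xinference's actual dispatcher.

    # Hypothetical sketch, not taken from this diff: pick an implementation
    # class by probing each candidate's match_json() classmethod in turn.
    def select_transformers_impl(model_family, model_spec, quantization, candidates):
        for cls in candidates:  # e.g. [Gemma3ChatModel, Qwen2VLChatModel, PytorchChatModel]
            if cls.match_json(model_family, model_spec, quantization):
                return cls
        raise ValueError("no transformers implementation matches this model spec")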
xinference/model/llm/transformers/glm4v.py

@@ -39,7 +39,7 @@ class Glm4VModel(PytorchChatModel):
         self._model = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -54,25 +54,7 @@ class Glm4VModel(PytorchChatModel):
         self._device = select_device(device)

         kwargs = {"device_map": self._device}
-        quantization = self.quantization
-
-        # referenced from PytorchModel.load
-        if quantization != "none":
-            if self._device == "cuda" and self._is_linux():
-                kwargs["device_map"] = "auto"
-                if quantization == "4-bit":
-                    kwargs["load_in_4bit"] = True
-                elif quantization == "8-bit":
-                    kwargs["load_in_8bit"] = True
-                else:
-                    raise ValueError(
-                        f"Quantization {quantization} is not supported in temporary"
-                    )
-            else:
-                if quantization != "8-bit":
-                    raise ValueError(
-                        f"Only 8-bit quantization is supported if it is not linux system or cuda device"
-                    )
+        kwargs = self.apply_bnb_quantization(kwargs)

         if self._check_tensorizer_integrity():
             self._model, self._tokenizer = self._load_tensorizer()
xinference/model/llm/transformers/glm_edge_v.py

@@ -42,7 +42,7 @@ class GlmEdgeVModel(PytorchChatModel):
         self._processor = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -57,25 +57,7 @@ class GlmEdgeVModel(PytorchChatModel):
         self._device = select_device(device)

         kwargs = {"device_map": self._device}
-        quantization = self.quantization
-
-        # referenced from PytorchModel.load
-        if quantization != "none":
-            if self._device == "cuda" and self._is_linux():
-                kwargs["device_map"] = "auto"
-                if quantization == "4-bit":
-                    kwargs["load_in_4bit"] = True
-                elif quantization == "8-bit":
-                    kwargs["load_in_8bit"] = True
-                else:
-                    raise ValueError(
-                        f"Quantization {quantization} is not supported in temporary"
-                    )
-            else:
-                if quantization != "8-bit":
-                    raise ValueError(
-                        f"Only 8-bit quantization is supported if it is not linux system or cuda device"
-                    )
+        kwargs = self.apply_bnb_quantization(kwargs)

         processor = AutoImageProcessor.from_pretrained(
             self.model_path, trust_remote_code=True
@@ -87,6 +69,7 @@ class GlmEdgeVModel(PytorchChatModel):
             trust_remote_code=True,
             torch_dtype=torch.bfloat16,
             device_map="auto",
+            **kwargs
         )

         self._model = model
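
Note: the two hunks above (and the intern_vl.py hunk that follows) replace per-model bitsandbytes branching with a shared apply_bnb_quantization helper. The helper's implementation is not shown in this excerpt; it presumably lives in the reworked transformers/core.py. The sketch below is a hedged reconstruction based only on the inline logic being removed, written as a free function for illustration rather than as the method it actually is.

    # Hedged sketch, assuming the helper mirrors the deleted inline flags.
    from typing import Dict, Optional

    def apply_bnb_quantization(quantization: str, kwargs: Optional[Dict] = None) -> Dict:
        kwargs = dict(kwargs) if kwargs else {}
        if quantization == "4-bit":
            kwargs["load_in_4bit"] = True  # bitsandbytes 4-bit loading
        elif quantization == "8-bit":
            kwargs["load_in_8bit"] = True  # bitsandbytes 8-bit loading
        elif quantization not in (None, "none"):
            raise ValueError(f"Quantization {quantization} is not supported")
        return kwargs  # merged into transformers from_pretrained(**kwargs)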
xinference/model/llm/transformers/intern_vl.py

@@ -243,7 +243,7 @@ class InternVLChatModel(PytorchChatModel):
         self._model = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -311,7 +311,7 @@ class InternVLChatModel(PytorchChatModel):
         device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
         return device_map

-    def load(self, **kwargs):
+    def load(self):
         from transformers import AutoModel, AutoTokenizer

         if self._check_tensorizer_integrity():
@@ -329,10 +329,7 @@ class InternVLChatModel(PytorchChatModel):
         if device is not None:
             kwargs["device_map"] = device

-        if "8-bit" in self.quantization.lower():
-            kwargs["load_in_8bit"] = True
-        elif "4-bit" in self.quantization.lower():
-            kwargs["load_in_4bit"] = True
+        kwargs = self.apply_bnb_quantization(kwargs)

         self._model = AutoModel.from_pretrained(self.model_path, **kwargs).eval()

xinference/model/llm/transformers/internlm2.py

@@ -71,7 +71,7 @@ class Internlm2PytorchChatModel(PytorchChatModel):
         return model, tokenizer

     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         model_family = llm_family.model_family or llm_family.model_name
xinference/model/llm/transformers/minicpmv25.py

@@ -42,7 +42,7 @@ class MiniCPMV25Model(PytorchChatModel):
         self._model = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -55,7 +55,7 @@ class MiniCPMV25Model(PytorchChatModel):

         return AutoModel

-    def load(self, **kwargs):
+    def load(self):
         from transformers import AutoModel, AutoTokenizer
         from transformers.generation import GenerationConfig

@@ -76,11 +76,13 @@ class MiniCPMV25Model(PytorchChatModel):
         if "int4" in self.model_path:
             model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
         else:
+            kwargs = self.apply_bnb_quantization()
             model = AutoModel.from_pretrained(
                 self.model_path,
                 trust_remote_code=True,
                 torch_dtype=torch.float16,
                 device_map=self._device,
+                **kwargs
             )
         tokenizer = AutoTokenizer.from_pretrained(
             self.model_path, trust_remote_code=True
xinference/model/llm/transformers/minicpmv26.py

@@ -49,7 +49,7 @@ class MiniCPMV26Model(PytorchChatModel):
         self._processor = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -71,7 +71,7 @@ class MiniCPMV26Model(PytorchChatModel):

         return AutoModel

-    def load(self, **kwargs):
+    def load(self):
         from transformers import AutoModel, AutoProcessor, AutoTokenizer
         from transformers.generation import GenerationConfig

@@ -96,11 +96,13 @@ class MiniCPMV26Model(PytorchChatModel):
         if "int4" in self.model_path:
             model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
         else:
+            kwargs = self.apply_bnb_quantization()
             model = AutoModel.from_pretrained(
                 self.model_path,
                 trust_remote_code=True,
                 torch_dtype=torch.float16,
                 device_map=self._device,
+                **kwargs,
             )
         tokenizer = AutoTokenizer.from_pretrained(
             self.model_path, trust_remote_code=True
@@ -322,7 +324,7 @@ class MiniCPMV26Model(PytorchChatModel):
             "input_image": images,
         }

-    def _get_full_prompt(self, messages: List[Dict], tools):
+    def _get_full_prompt(self, messages: List[Dict], tools):  # type: ignore
         msgs, video_existed = self._convert_to_specific_style(messages)
         if video_existed:
             raise RuntimeError(
xinference/model/llm/transformers/omnilmm.py

@@ -35,7 +35,7 @@ class OmniLMMModel(PytorchChatModel):
         self._model = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         llm_family = model_family.model_family or model_family.model_name
xinference/model/llm/transformers/opt.py

@@ -42,7 +42,7 @@ class OptPytorchModel(PytorchModel):
         )

     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format != "pytorch":
xinference/model/llm/transformers/ovis2.py (new file)

@@ -0,0 +1,302 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import uuid
+from typing import Dict, Iterator, List, Optional, Union
+
+import torch
+from PIL import Image
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    CompletionChunk,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import generate_chat_completion, generate_completion_chunk
+from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean
+
+logger = logging.getLogger(__name__)
+
+
+class Ovis2ChatModel(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._tokenizer = None
+        self._model = None
+        self._device = None
+        self._processor = None
+
+    @classmethod
+    def match_json(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+            return False
+        llm_family = model_family.model_family or model_family.model_name
+        if "ovis2".lower() in llm_family.lower():
+            return True
+        return False
+
+    def load(self):
+        from transformers import AutoModelForCausalLM
+
+        # load model
+        self._model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            torch_dtype=torch.bfloat16,
+            multimodal_max_length=32768,
+            trust_remote_code=True,
+        ).cuda()
+        self._text_tokenizer = self._model.get_text_tokenizer()
+        self._visual_tokenizer = self._model.get_visual_tokenizer()
+
+    @cache_clean
+    def chat(
+        self,
+        messages: List[ChatCompletionMessage],  # type: ignore
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        messages = self._transform_messages(messages)
+
+        generate_config = generate_config if generate_config else {}
+
+        stream = generate_config.get("stream", False) if generate_config else False
+
+        if stream:
+            # raise NotImplementedError("Stream is not supported for Ovis2 model.")
+            it = self._generate_stream(messages, generate_config)
+            return self._to_chat_completion_chunks(it)
+        else:
+            c = self._generate(messages, generate_config)
+            return c
+
+    def _generate(
+        self, messages: List, config: PytorchGenerateConfig = {}
+    ) -> ChatCompletion:
+        input_ids, attention_mask, pixel_values, gen_kwargs = self._generate_chat_data(
+            messages, config
+        )
+
+        # generate output
+        with torch.inference_mode():
+            gen_kwargs.update(
+                dict(
+                    pixel_values=pixel_values,
+                    attention_mask=attention_mask,
+                )
+            )
+
+            output_ids = self._model.generate(
+                input_ids,
+                **gen_kwargs,
+            )[0]
+            output = self._text_tokenizer.decode(output_ids, skip_special_tokens=True)
+            return generate_chat_completion(self.model_uid, output)
+
+    def _generate_stream(
+        self, messages: List, config: PytorchGenerateConfig = {}
+    ) -> Iterator[CompletionChunk]:
+        from threading import Thread
+
+        from transformers import TextIteratorStreamer
+
+        input_ids, attention_mask, pixel_values, gen_kwargs = self._generate_chat_data(
+            messages, config
+        )
+
+        _, inputs_embeds, _, attention_mask = self._model.merge_multimodal(
+            text_input_ids=input_ids,
+            text_attention_masks=attention_mask,
+            text_labels=None,
+            pixel_values=pixel_values,
+            left_padding=True,
+        )
+
+        streamer = TextIteratorStreamer(
+            self._text_tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
+        )
+
+        gen_kwargs.update(
+            dict(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                streamer=streamer,
+            )
+        )
+
+        inputs_embeds = inputs_embeds.detach()
+        torch.cuda.empty_cache()
+
+        thread = Thread(target=self._model.llm.generate, kwargs=gen_kwargs)
+        thread.start()
+
+        completion_id = str(uuid.uuid1())
+
+        for new_text in streamer:
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=-1,
+                completion_tokens=-1,
+                total_tokens=-1,
+                has_choice=True,
+                has_content=True,
+            )
+
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=-1,
+            completion_tokens=-1,
+            total_tokens=-1,
+            has_choice=True,
+            has_content=False,
+        )
+
+    def parse_messages_ovis(self, messages: List[Dict]) -> List[Dict]:
+        ovis_msgs = []
+        for mess in messages:
+            contents = mess["content"]
+            role = mess["role"]
+            if role == "user":
+                role = "human"
+            elif role == "assistant":
+                role = "gpt"
+            elif role == "system":
+                role = "system"
+
+            for content in contents:
+                if content["type"] == "text":
+                    ovis_msgs.append({"from": role, "value": content["text"]})
+
+        return ovis_msgs
+
+    def _generate_chat_data(
+        self, messages: List[Dict], config: PytorchGenerateConfig = {}
+    ):
+        from qwen_vl_utils import process_vision_info
+
+        messages_ovis = self.parse_messages_ovis(messages)
+        max_partition = None
+        prompt = messages_ovis[-1]["value"]
+
+        # Preparation for inference
+        image_inputs, video_inputs = process_vision_info(messages)
+
+        image_inputs = image_inputs if image_inputs else []
+
+        if image_inputs and len(image_inputs) > 0:
+            if len(image_inputs) == 1:
+                max_partition = 9
+                prompt = f"<image>\n{prompt}"
+            else:
+                max_partition = len(image_inputs) + 1
+                prompt = (
+                    "\n".join(
+                        [f"Image {i+1}: <image>" for i in range(len(image_inputs))]
+                    )
+                    + "\n"
+                    + prompt
+                )
+        elif video_inputs and len(video_inputs) > 0:
+            if isinstance(video_inputs[0], torch.Tensor):
+                # Convert from list[Tensor] to list[Image]
+                pil_images = self._convert_video_tensors_to_pil(video_inputs)
+
+                video_inputs = pil_images  # Update video_inputs to PIL image list
+
+            max_partition = 1
+            image_inputs = video_inputs
+            prompt = "\n".join(["<image>"] * len(video_inputs)) + "\n" + prompt
+        else:
+            max_partition = 0
+            prompt = prompt
+
+        messages_ovis[-1]["value"] = prompt
+
+        # format conversation
+        prompt, input_ids, pixel_values = self._model.preprocess_inputs(
+            messages_ovis, image_inputs, max_partition=max_partition
+        )
+
+        attention_mask = torch.ne(input_ids, self._text_tokenizer.pad_token_id)
+        input_ids = input_ids.unsqueeze(0).to(device=self._model.device)
+        attention_mask = attention_mask.unsqueeze(0).to(device=self._model.device)
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(
+                dtype=self._visual_tokenizer.dtype, device=self._visual_tokenizer.device
+            )
+        pixel_values = [pixel_values]
+
+        gen_kwargs = dict(
+            max_new_tokens=config.get("max_tokens", 1024),
+            do_sample=False,
+            top_p=None,
+            top_k=None,
+            temperature=config.get("temperature", None),
+            repetition_penalty=None,
+            eos_token_id=self._model.generation_config.eos_token_id,
+            pad_token_id=self._text_tokenizer.pad_token_id,
+            use_cache=True,
+        )
+
+        return input_ids, attention_mask, pixel_values, gen_kwargs
+
+    def _convert_video_tensors_to_pil(self, video_inputs: List) -> List[Image.Image]:
+        """Convert video tensors to a list of PIL images"""
+        from torchvision import transforms
+
+        to_pil = transforms.ToPILImage()
+        pil_images = []
+
+        for video_tensor_4d in video_inputs:
+            if isinstance(video_tensor_4d, torch.Tensor):
+                # Verify it's a 4D tensor
+                if video_tensor_4d.ndim == 4:
+                    # Iterate through the first dimension (frames) of 4D tensor
+                    for i in range(video_tensor_4d.size(0)):
+                        frame_tensor_3d = video_tensor_4d[
+                            i
+                        ]  # Get 3D frame tensor [C, H, W]
+                        # Ensure tensor is on CPU before conversion
+                        if frame_tensor_3d.is_cuda:
+                            frame_tensor_3d = frame_tensor_3d.cpu()
+                        try:
+                            pil_image = to_pil(frame_tensor_3d)
+                            pil_images.append(pil_image)
+                        except Exception as e:
+                            logger.error(
+                                f"Error converting frame {i} to PIL Image: {e}"
+                            )
+                            # Can choose to skip this frame or handle error differently
+                else:
+                    logger.warning(
+                        f"Expected 4D tensor in video_inputs, but got {video_tensor_4d.ndim}D. Skipping this tensor."
+                    )
+            elif isinstance(video_tensor_4d, Image.Image):
+                # If fetch_video returns Image list, add directly
+                pil_images.append(video_tensor_4d)
+            else:
+                logger.warning(
+                    f"Unexpected type in video_inputs: {type(video_tensor_4d)}. Skipping."
+                )
+
+        return pil_images
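
Note: for readers of the new ovis2.py above, the snippet below is illustration only and not part of the diff. It shows the OpenAI-style message shape that Ovis2ChatModel.chat receives and the from/value turns that parse_messages_ovis derives from it; image and video parts are resolved separately through qwen_vl_utils.process_vision_info and re-injected into the last prompt as <image> placeholders inside _generate_chat_data.

    # Illustration only: input messages and the conversation parse_messages_ovis builds.
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [{"type": "text", "text": "Describe the attached image."}]},
    ]
    # parse_messages_ovis(messages) keeps only the text parts and renames the roles:
    # [
    #     {"from": "system", "value": "You are a helpful assistant."},
    #     {"from": "human", "value": "Describe the attached image."},
    # ]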
xinference/model/llm/transformers/qwen-omni.py

@@ -56,7 +56,7 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
         self._processor = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -83,6 +83,7 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
             if not flash_attn_installed
             else {"attn_implementation": "flash_attention_2"}
         )
+        kwargs = self.apply_bnb_quantization(kwargs)
         logger.debug("Loading model with extra kwargs: %s", kwargs)

         self._processor = Qwen2_5OmniProcessor.from_pretrained(
xinference/model/llm/transformers/qwen2_audio.py

@@ -42,7 +42,7 @@ class Qwen2AudioChatModel(PytorchChatModel):
         self._device = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         llm_family = model_family.model_family or model_family.model_name
@@ -58,6 +58,7 @@ class Qwen2AudioChatModel(PytorchChatModel):
         # for multiple GPU, set back to auto to make multiple devices work
         device = "auto" if device == "cuda" else device
         self._device = device
+        kwargs = self.apply_bnb_quantization()

         self._processor = AutoProcessor.from_pretrained(
             self.model_path,
@@ -70,6 +71,7 @@ class Qwen2AudioChatModel(PytorchChatModel):
             device_map=device,
             # trust_remote_code=True,
             revision=self.model_spec.model_revision,
+            **kwargs,
         )

     def _transform_messages(
xinference/model/llm/transformers/qwen2_vl.py

@@ -54,7 +54,7 @@ class Qwen2VLChatModel(PytorchChatModel):
         return pytorch_model_config

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -81,6 +81,8 @@ class Qwen2VLChatModel(PytorchChatModel):
         self._device = device
         # for multiple GPU, set back to auto to make multiple devices work
         device = "auto" if device == "cuda" else device
+        kwargs = self.apply_bnb_quantization()
+
         min_pixels = self._pytorch_model_config.get("min_pixels")
         max_pixels = self._pytorch_model_config.get("max_pixels")
         self._processor = AutoProcessor.from_pretrained(
@@ -106,6 +108,7 @@ class Qwen2VLChatModel(PytorchChatModel):
                 device_map=device,
                 attn_implementation="flash_attention_2",
                 trust_remote_code=True,
+                **kwargs,
             ).eval()
         elif is_npu_available():
             # Ascend do not support bf16
@@ -114,6 +117,7 @@ class Qwen2VLChatModel(PytorchChatModel):
                 device_map="auto",
                 trust_remote_code=True,
                 torch_dtype="float16",
+                **kwargs,
             ).eval()
         else:
             self._model = model_cls.from_pretrained(