xinference 1.5.0.post2__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (137)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +107 -11
  3. xinference/client/restful/restful_client.py +51 -11
  4. xinference/constants.py +5 -1
  5. xinference/core/media_interface.py +758 -0
  6. xinference/core/model.py +49 -9
  7. xinference/core/supervisor.py +1 -1
  8. xinference/core/utils.py +1 -1
  9. xinference/core/worker.py +33 -39
  10. xinference/deploy/cmdline.py +17 -0
  11. xinference/deploy/utils.py +0 -3
  12. xinference/model/audio/__init__.py +16 -27
  13. xinference/model/audio/core.py +2 -1
  14. xinference/model/audio/cosyvoice.py +4 -2
  15. xinference/model/audio/model_spec.json +63 -46
  16. xinference/model/audio/model_spec_modelscope.json +31 -14
  17. xinference/model/embedding/__init__.py +16 -24
  18. xinference/model/image/__init__.py +15 -25
  19. xinference/model/llm/__init__.py +40 -115
  20. xinference/model/llm/core.py +29 -6
  21. xinference/model/llm/llama_cpp/core.py +30 -347
  22. xinference/model/llm/llm_family.json +1674 -2203
  23. xinference/model/llm/llm_family.py +71 -7
  24. xinference/model/llm/llm_family_csghub.json +0 -32
  25. xinference/model/llm/llm_family_modelscope.json +1838 -2016
  26. xinference/model/llm/llm_family_openmind_hub.json +19 -325
  27. xinference/model/llm/lmdeploy/core.py +7 -2
  28. xinference/model/llm/mlx/core.py +23 -7
  29. xinference/model/llm/reasoning_parser.py +281 -5
  30. xinference/model/llm/sglang/core.py +39 -11
  31. xinference/model/llm/transformers/chatglm.py +9 -2
  32. xinference/model/llm/transformers/cogagent.py +10 -12
  33. xinference/model/llm/transformers/cogvlm2.py +6 -3
  34. xinference/model/llm/transformers/cogvlm2_video.py +3 -6
  35. xinference/model/llm/transformers/core.py +58 -60
  36. xinference/model/llm/transformers/deepseek_v2.py +4 -2
  37. xinference/model/llm/transformers/deepseek_vl.py +10 -4
  38. xinference/model/llm/transformers/deepseek_vl2.py +9 -4
  39. xinference/model/llm/transformers/gemma3.py +4 -5
  40. xinference/model/llm/transformers/glm4v.py +3 -21
  41. xinference/model/llm/transformers/glm_edge_v.py +3 -20
  42. xinference/model/llm/transformers/intern_vl.py +3 -6
  43. xinference/model/llm/transformers/internlm2.py +1 -1
  44. xinference/model/llm/transformers/minicpmv25.py +4 -2
  45. xinference/model/llm/transformers/minicpmv26.py +5 -3
  46. xinference/model/llm/transformers/omnilmm.py +1 -1
  47. xinference/model/llm/transformers/opt.py +1 -1
  48. xinference/model/llm/transformers/ovis2.py +302 -0
  49. xinference/model/llm/transformers/qwen-omni.py +8 -1
  50. xinference/model/llm/transformers/qwen2_audio.py +3 -1
  51. xinference/model/llm/transformers/qwen2_vl.py +5 -1
  52. xinference/model/llm/transformers/qwen_vl.py +5 -2
  53. xinference/model/llm/utils.py +96 -45
  54. xinference/model/llm/vllm/core.py +108 -24
  55. xinference/model/llm/vllm/distributed_executor.py +8 -7
  56. xinference/model/llm/vllm/xavier/allocator.py +1 -1
  57. xinference/model/llm/vllm/xavier/block_manager.py +1 -1
  58. xinference/model/llm/vllm/xavier/block_tracker.py +3 -3
  59. xinference/model/llm/vllm/xavier/executor.py +1 -1
  60. xinference/model/llm/vllm/xavier/test/test_xavier.py +2 -11
  61. xinference/model/rerank/__init__.py +13 -24
  62. xinference/model/video/__init__.py +15 -25
  63. xinference/model/video/core.py +3 -3
  64. xinference/model/video/diffusers.py +157 -13
  65. xinference/model/video/model_spec.json +100 -0
  66. xinference/model/video/model_spec_modelscope.json +104 -0
  67. xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
  68. xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
  69. xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
  70. xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
  71. xinference/thirdparty/cosyvoice/bin/train.py +7 -2
  72. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
  73. xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
  74. xinference/thirdparty/cosyvoice/cli/model.py +140 -155
  75. xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
  76. xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
  77. xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
  78. xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
  79. xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
  80. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
  81. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
  82. xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
  83. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
  84. xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
  85. xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
  86. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
  87. xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
  88. xinference/thirdparty/cosyvoice/utils/common.py +1 -1
  89. xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
  90. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
  91. xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
  92. xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
  93. xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
  94. xinference/types.py +2 -71
  95. xinference/web/ui/build/asset-manifest.json +6 -6
  96. xinference/web/ui/build/index.html +1 -1
  97. xinference/web/ui/build/static/css/{main.0f6523be.css → main.337afe76.css} +2 -2
  98. xinference/web/ui/build/static/css/main.337afe76.css.map +1 -0
  99. xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
  100. xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
  101. xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
  102. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
  103. xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
  104. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
  105. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
  106. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +1 -0
  107. xinference/web/ui/node_modules/.cache/babel-loader/6798e126f3bc5f95a4c16a9c2ad52ffe77970c62406d83e20604dfda7ffd2247.json +1 -0
  108. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
  109. xinference/web/ui/node_modules/.cache/babel-loader/b617f7d21a95045fc57b26a9373551740f1978a826134cbf705c3a1bf8714a93.json +1 -0
  110. xinference/web/ui/node_modules/.cache/babel-loader/c1506cb142151366074975f30fa1ff9cd6e5e978b62a4b074dfc16fe08d70d75.json +1 -0
  111. xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +1 -0
  112. xinference/web/ui/src/locales/en.json +7 -4
  113. xinference/web/ui/src/locales/zh.json +7 -4
  114. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/METADATA +56 -36
  115. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/RECORD +120 -121
  116. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/WHEEL +1 -1
  117. xinference/core/image_interface.py +0 -377
  118. xinference/model/llm/transformers/compression.py +0 -258
  119. xinference/model/llm/transformers/yi_vl.py +0 -239
  120. xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
  121. xinference/web/ui/build/static/css/main.0f6523be.css.map +0 -1
  122. xinference/web/ui/build/static/js/main.4b67a723.js +0 -3
  123. xinference/web/ui/build/static/js/main.4b67a723.js.map +0 -1
  124. xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
  125. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +0 -1
  126. xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/8f9af2979e45d4648f0cfae108363e58ee421c29a9d4e7329b6f06d9adfd4133.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/9c8b1a86e7c65b2b2599a205e30920652d6c2105f926508ef5bcf29a3ef4ce76.json +0 -1
  129. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +0 -1
  130. xinference/web/ui/node_modules/.cache/babel-loader/e4ba658c6b3b0490910acdae0c535a892257efb61539a24adf8038fc653bd22f.json +0 -1
  131. xinference/web/ui/node_modules/.cache/babel-loader/efe7cd132c27a8f9fd5352a394c491fd5fb0da0348cf9fcbd923164a32365eab.json +0 -1
  132. xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
  133. xinference/web/ui/node_modules/.cache/babel-loader/f199e8173f6409a5802ed44acb95f218388131136504b2e9132129e150c92f9a.json +0 -1
  134. /xinference/web/ui/build/static/js/{main.4b67a723.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
  135. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/entry_points.txt +0 -0
  136. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/licenses/LICENSE +0 -0
  137. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/top_level.txt +0 -0

xinference/model/llm/transformers/ovis2.py (new file)
@@ -0,0 +1,302 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import uuid
+from typing import Dict, Iterator, List, Optional, Union
+
+import torch
+from PIL import Image
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    CompletionChunk,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import generate_chat_completion, generate_completion_chunk
+from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean
+
+logger = logging.getLogger(__name__)
+
+
+class Ovis2ChatModel(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._tokenizer = None
+        self._model = None
+        self._device = None
+        self._processor = None
+
+    @classmethod
+    def match_json(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+            return False
+        llm_family = model_family.model_family or model_family.model_name
+        if "ovis2".lower() in llm_family.lower():
+            return True
+        return False
+
+    def load(self):
+        from transformers import AutoModelForCausalLM
+
+        # load model
+        self._model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            torch_dtype=torch.bfloat16,
+            multimodal_max_length=32768,
+            trust_remote_code=True,
+        ).cuda()
+        self._text_tokenizer = self._model.get_text_tokenizer()
+        self._visual_tokenizer = self._model.get_visual_tokenizer()
+
+    @cache_clean
+    def chat(
+        self,
+        messages: List[ChatCompletionMessage],  # type: ignore
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        messages = self._transform_messages(messages)
+
+        generate_config = generate_config if generate_config else {}
+
+        stream = generate_config.get("stream", False) if generate_config else False
+
+        if stream:
+            # raise NotImplementedError("Stream is not supported for Ovis2 model.")
+            it = self._generate_stream(messages, generate_config)
+            return self._to_chat_completion_chunks(it)
+        else:
+            c = self._generate(messages, generate_config)
+            return c
+
+    def _generate(
+        self, messages: List, config: PytorchGenerateConfig = {}
+    ) -> ChatCompletion:
+        input_ids, attention_mask, pixel_values, gen_kwargs = self._generate_chat_data(
+            messages, config
+        )
+
+        # generate output
+        with torch.inference_mode():
+            gen_kwargs.update(
+                dict(
+                    pixel_values=pixel_values,
+                    attention_mask=attention_mask,
+                )
+            )
+
+            output_ids = self._model.generate(
+                input_ids,
+                **gen_kwargs,
+            )[0]
+            output = self._text_tokenizer.decode(output_ids, skip_special_tokens=True)
+            return generate_chat_completion(self.model_uid, output)
+
+    def _generate_stream(
+        self, messages: List, config: PytorchGenerateConfig = {}
+    ) -> Iterator[CompletionChunk]:
+        from threading import Thread
+
+        from transformers import TextIteratorStreamer
+
+        input_ids, attention_mask, pixel_values, gen_kwargs = self._generate_chat_data(
+            messages, config
+        )
+
+        _, inputs_embeds, _, attention_mask = self._model.merge_multimodal(
+            text_input_ids=input_ids,
+            text_attention_masks=attention_mask,
+            text_labels=None,
+            pixel_values=pixel_values,
+            left_padding=True,
+        )
+
+        streamer = TextIteratorStreamer(
+            self._text_tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
+        )
+
+        gen_kwargs.update(
+            dict(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                streamer=streamer,
+            )
+        )
+
+        inputs_embeds = inputs_embeds.detach()
+        torch.cuda.empty_cache()
+
+        thread = Thread(target=self._model.llm.generate, kwargs=gen_kwargs)
+        thread.start()
+
+        completion_id = str(uuid.uuid1())
+
+        for new_text in streamer:
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=-1,
+                completion_tokens=-1,
+                total_tokens=-1,
+                has_choice=True,
+                has_content=True,
+            )
+
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=-1,
+            completion_tokens=-1,
+            total_tokens=-1,
+            has_choice=True,
+            has_content=False,
+        )
+
+    def parse_messages_ovis(self, messages: List[Dict]) -> List[Dict]:
+        ovis_msgs = []
+        for mess in messages:
+            contents = mess["content"]
+            role = mess["role"]
+            if role == "user":
+                role = "human"
+            elif role == "assistant":
+                role = "gpt"
+            elif role == "system":
+                role = "system"
+
+            for content in contents:
+                if content["type"] == "text":
+                    ovis_msgs.append({"from": role, "value": content["text"]})
+
+        return ovis_msgs
+
+    def _generate_chat_data(
+        self, messages: List[Dict], config: PytorchGenerateConfig = {}
+    ):
+        from qwen_vl_utils import process_vision_info
+
+        messages_ovis = self.parse_messages_ovis(messages)
+        max_partition = None
+        prompt = messages_ovis[-1]["value"]
+
+        # Preparation for inference
+        image_inputs, video_inputs = process_vision_info(messages)
+
+        image_inputs = image_inputs if image_inputs else []
+
+        if image_inputs and len(image_inputs) > 0:
+            if len(image_inputs) == 1:
+                max_partition = 9
+                prompt = f"<image>\n{prompt}"
+            else:
+                max_partition = len(image_inputs) + 1
+                prompt = (
+                    "\n".join(
+                        [f"Image {i+1}: <image>" for i in range(len(image_inputs))]
+                    )
+                    + "\n"
+                    + prompt
+                )
+        elif video_inputs and len(video_inputs) > 0:
+            if isinstance(video_inputs[0], torch.Tensor):
+                # Convert from list[Tensor] to list[Image]
+                pil_images = self._convert_video_tensors_to_pil(video_inputs)
+
+                video_inputs = pil_images  # Update video_inputs to PIL image list
+
+            max_partition = 1
+            image_inputs = video_inputs
+            prompt = "\n".join(["<image>"] * len(video_inputs)) + "\n" + prompt
+        else:
+            max_partition = 0
+            prompt = prompt
+
+        messages_ovis[-1]["value"] = prompt
+
+        # format conversation
+        prompt, input_ids, pixel_values = self._model.preprocess_inputs(
+            messages_ovis, image_inputs, max_partition=max_partition
+        )
+
+        attention_mask = torch.ne(input_ids, self._text_tokenizer.pad_token_id)
+        input_ids = input_ids.unsqueeze(0).to(device=self._model.device)
+        attention_mask = attention_mask.unsqueeze(0).to(device=self._model.device)
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(
+                dtype=self._visual_tokenizer.dtype, device=self._visual_tokenizer.device
+            )
+        pixel_values = [pixel_values]
+
+        gen_kwargs = dict(
+            max_new_tokens=config.get("max_tokens", 1024),
+            do_sample=False,
+            top_p=None,
+            top_k=None,
+            temperature=config.get("temperature", None),
+            repetition_penalty=None,
+            eos_token_id=self._model.generation_config.eos_token_id,
+            pad_token_id=self._text_tokenizer.pad_token_id,
+            use_cache=True,
+        )
+
+        return input_ids, attention_mask, pixel_values, gen_kwargs
+
+    def _convert_video_tensors_to_pil(self, video_inputs: List) -> List[Image.Image]:
+        """Convert video tensors to a list of PIL images"""
+        from torchvision import transforms
+
+        to_pil = transforms.ToPILImage()
+        pil_images = []
+
+        for video_tensor_4d in video_inputs:
+            if isinstance(video_tensor_4d, torch.Tensor):
+                # Verify it's a 4D tensor
+                if video_tensor_4d.ndim == 4:
+                    # Iterate through the first dimension (frames) of 4D tensor
+                    for i in range(video_tensor_4d.size(0)):
+                        frame_tensor_3d = video_tensor_4d[
+                            i
+                        ]  # Get 3D frame tensor [C, H, W]
+                        # Ensure tensor is on CPU before conversion
+                        if frame_tensor_3d.is_cuda:
+                            frame_tensor_3d = frame_tensor_3d.cpu()
+                        try:
+                            pil_image = to_pil(frame_tensor_3d)
+                            pil_images.append(pil_image)
+                        except Exception as e:
+                            logger.error(
+                                f"Error converting frame {i} to PIL Image: {e}"
+                            )
+                            # Can choose to skip this frame or handle error differently
+                else:
+                    logger.warning(
+                        f"Expected 4D tensor in video_inputs, but got {video_tensor_4d.ndim}D. Skipping this tensor."
+                    )
+            elif isinstance(video_tensor_4d, Image.Image):
+                # If fetch_video returns Image list, add directly
+                pil_images.append(video_tensor_4d)
+            else:
+                logger.warning(
+                    f"Unexpected type in video_inputs: {type(video_tensor_4d)}. Skipping."
+                )

+        return pil_images
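
The new Ovis2ChatModel reuses the common Transformers chat path (message transformation, chat-completion chunk helpers), so once an Ovis2 model is registered and launched it can be driven through the regular RESTful client. A minimal sketch, assuming a local endpoint at http://localhost:9997 and a model served under the name "Ovis2" on the Transformers engine; the endpoint, model name, and launch parameters are illustrative and not taken from this diff:

from xinference.client import Client

client = Client("http://localhost:9997")  # illustrative endpoint

# Launch on the Transformers engine; name/format values here are assumptions.
model_uid = client.launch_model(
    model_name="Ovis2",
    model_type="LLM",
    model_engine="transformers",
    model_format="pytorch",
)
model = client.get_model(model_uid)

# Content parts use the OpenAI-style schema that parse_messages_ovis and
# _generate_chat_data expect: text parts plus image_url parts.
completion = model.chat(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
    generate_config={"max_tokens": 256},
)
print(completion["choices"][0]["message"]["content"])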

xinference/model/llm/transformers/qwen-omni.py
@@ -56,7 +56,7 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
         self._processor = None
 
     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -67,6 +67,12 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
         return False
 
     def load(self):
+        logger.debug(
+            "Try to load model, current python: %s, sys path: %s",
+            sys.executable,
+            sys.path,
+        )
+
         from transformers import (
             Qwen2_5OmniForConditionalGeneration,
             Qwen2_5OmniProcessor,
@@ -83,6 +89,7 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
             if not flash_attn_installed
             else {"attn_implementation": "flash_attention_2"}
         )
+        kwargs = self.apply_bnb_quantization(kwargs)
         logger.debug("Loading model with extra kwargs: %s", kwargs)
 
         self._processor = Qwen2_5OmniProcessor.from_pretrained(
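
Alongside the Qwen2.5-Omni changes above, this release renames the `match` classmethod to `match_json` across the Transformers chat models (same `(model_family, model_spec, quantization)` signature; the new ovis2.py defines it as well). A minimal sketch of the selection pattern such a predicate enables; the helper and candidate list below are illustrative, not the actual registry code in xinference.model.llm:

from typing import Optional, Sequence

def pick_chat_model_cls(
    candidates: Sequence[type], model_family, model_spec, quantization: str
) -> Optional[type]:
    # Return the first class whose match_json accepts this family/spec/quantization.
    for cls in candidates:
        if cls.match_json(model_family, model_spec, quantization):
            return cls
    return None

# e.g. pick_chat_model_cls([Ovis2ChatModel, Qwen2VLChatModel], family, spec, "none")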

xinference/model/llm/transformers/qwen2_audio.py
@@ -42,7 +42,7 @@ class Qwen2AudioChatModel(PytorchChatModel):
         self._device = None
 
     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         llm_family = model_family.model_family or model_family.model_name
@@ -58,6 +58,7 @@ class Qwen2AudioChatModel(PytorchChatModel):
         # for multiple GPU, set back to auto to make multiple devices work
         device = "auto" if device == "cuda" else device
         self._device = device
+        kwargs = self.apply_bnb_quantization()
 
         self._processor = AutoProcessor.from_pretrained(
             self.model_path,
@@ -70,6 +71,7 @@ class Qwen2AudioChatModel(PytorchChatModel):
             device_map=device,
             # trust_remote_code=True,
             revision=self.model_spec.model_revision,
+            **kwargs,
         )
 
     def _transform_messages(
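
The kwargs returned by `apply_bnb_quantization()` are forwarded into the model's `from_pretrained` call via `**kwargs`. The helper's body is not shown in this diff; as an illustration of the bitsandbytes pattern it presumably wraps, a Transformers loader can be handed a `quantization_config` like the following (the model path and the specific 4-bit options are placeholders):

import torch
from transformers import BitsAndBytesConfig, Qwen2AudioForConditionalGeneration

# Typical 4-bit bitsandbytes setup; the exact options xinference builds inside
# apply_bnb_quantization() are not visible in this diff.
kwargs = {
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
}

model = Qwen2AudioForConditionalGeneration.from_pretrained(
    "/path/to/qwen2-audio",  # placeholder path
    device_map="auto",
    **kwargs,  # same forwarding pattern as the **kwargs added in this hunk
)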

xinference/model/llm/transformers/qwen2_vl.py
@@ -54,7 +54,7 @@ class Qwen2VLChatModel(PytorchChatModel):
         return pytorch_model_config
 
     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -81,6 +81,8 @@ class Qwen2VLChatModel(PytorchChatModel):
         self._device = device
         # for multiple GPU, set back to auto to make multiple devices work
         device = "auto" if device == "cuda" else device
+        kwargs = self.apply_bnb_quantization()
+
         min_pixels = self._pytorch_model_config.get("min_pixels")
         max_pixels = self._pytorch_model_config.get("max_pixels")
         self._processor = AutoProcessor.from_pretrained(
@@ -106,6 +108,7 @@ class Qwen2VLChatModel(PytorchChatModel):
                 device_map=device,
                 attn_implementation="flash_attention_2",
                 trust_remote_code=True,
+                **kwargs,
             ).eval()
         elif is_npu_available():
             # Ascend do not support bf16
@@ -114,6 +117,7 @@ class Qwen2VLChatModel(PytorchChatModel):
                 device_map="auto",
                 trust_remote_code=True,
                 torch_dtype="float16",
+                **kwargs,
             ).eval()
         else:
             self._model = model_cls.from_pretrained(
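
Qwen2VLChatModel follows the same pattern as the audio model: `match_json` for dispatch plus bitsandbytes kwargs threaded through `from_pretrained`. On the client side, the streaming path of these vision chat models is driven with generate_config={"stream": True}; a hedged sketch, with the endpoint, model UID, and image URL as placeholders:

from xinference.client import Client

client = Client("http://localhost:9997")   # illustrative endpoint
model = client.get_model("my-qwen2-vl")    # placeholder model UID

chunks = model.chat(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/chart.png"}},
                {"type": "text", "text": "What does this chart show?"},
            ],
        }
    ],
    generate_config={"stream": True, "max_tokens": 256},
)
# Each chunk is a ChatCompletionChunk dict; print the incremental deltas.
for chunk in chunks:
    delta = chunk["choices"][0]["delta"]
    if delta.get("content"):
        print(delta["content"], end="", flush=True)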

xinference/model/llm/transformers/qwen_vl.py
@@ -41,7 +41,7 @@ class QwenVLChatModel(PytorchChatModel):
         self._device = None
 
     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         llm_family = model_family.model_family or model_family.model_name
@@ -66,6 +66,8 @@ class QwenVLChatModel(PytorchChatModel):
         # for multiple GPU, set back to auto to make multiple devices work
         device = "auto" if device == "cuda" else device
 
+        kwargs = self.apply_bnb_quantization()
+
         self._tokenizer = AutoTokenizer.from_pretrained(
             self.model_path,
             trust_remote_code=True,
@@ -76,6 +78,7 @@ class QwenVLChatModel(PytorchChatModel):
             device_map=device,
             trust_remote_code=True,
             code_revision=self.model_spec.model_revision,
+            **kwargs,
         ).eval()
 
         # Specify hyperparameters for generation
@@ -310,7 +313,7 @@ class QwenVLChatModel(PytorchChatModel):
 
         return raw_text, context_tokens
 
-    def _get_full_prompt(self, messages: List[Dict], tools):
+    def _get_full_prompt(self, messages: List[Dict], tools, generate_config: dict):  # type: ignore
         prompt, qwen_history = self._get_prompt_and_chat_history(messages)
         _, context_tokens = self.make_context(self._tokenizer, prompt, qwen_history)
         return context_tokens
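
`_get_full_prompt` on QwenVLChatModel now also receives the per-request generation config, so prompt construction can take generation options into account. A hedged sketch of a subclass using the new three-argument signature; the subclass and its history-trimming policy are purely illustrative and not part of the package:

from typing import Dict, List

from xinference.model.llm.transformers.qwen_vl import QwenVLChatModel


class TrimmedQwenVL(QwenVLChatModel):  # illustrative subclass, not in xinference
    def _get_full_prompt(self, messages: List[Dict], tools, generate_config: dict):  # type: ignore
        # Example use of the new argument: keep only the latest exchange when the
        # caller requests a very small completion budget (made-up policy).
        max_tokens = (generate_config or {}).get("max_tokens") or 0
        if 0 < max_tokens < 64:
            messages = messages[-2:]
        return super()._get_full_prompt(messages, tools, generate_config)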