xinference 1.5.0.post2__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +107 -11
- xinference/client/restful/restful_client.py +51 -11
- xinference/constants.py +5 -1
- xinference/core/media_interface.py +758 -0
- xinference/core/model.py +49 -9
- xinference/core/supervisor.py +1 -1
- xinference/core/utils.py +1 -1
- xinference/core/worker.py +33 -39
- xinference/deploy/cmdline.py +17 -0
- xinference/deploy/utils.py +0 -3
- xinference/model/audio/__init__.py +16 -27
- xinference/model/audio/core.py +2 -1
- xinference/model/audio/cosyvoice.py +4 -2
- xinference/model/audio/model_spec.json +63 -46
- xinference/model/audio/model_spec_modelscope.json +31 -14
- xinference/model/embedding/__init__.py +16 -24
- xinference/model/image/__init__.py +15 -25
- xinference/model/llm/__init__.py +40 -115
- xinference/model/llm/core.py +29 -6
- xinference/model/llm/llama_cpp/core.py +30 -347
- xinference/model/llm/llm_family.json +1674 -2203
- xinference/model/llm/llm_family.py +71 -7
- xinference/model/llm/llm_family_csghub.json +0 -32
- xinference/model/llm/llm_family_modelscope.json +1838 -2016
- xinference/model/llm/llm_family_openmind_hub.json +19 -325
- xinference/model/llm/lmdeploy/core.py +7 -2
- xinference/model/llm/mlx/core.py +23 -7
- xinference/model/llm/reasoning_parser.py +281 -5
- xinference/model/llm/sglang/core.py +39 -11
- xinference/model/llm/transformers/chatglm.py +9 -2
- xinference/model/llm/transformers/cogagent.py +10 -12
- xinference/model/llm/transformers/cogvlm2.py +6 -3
- xinference/model/llm/transformers/cogvlm2_video.py +3 -6
- xinference/model/llm/transformers/core.py +58 -60
- xinference/model/llm/transformers/deepseek_v2.py +4 -2
- xinference/model/llm/transformers/deepseek_vl.py +10 -4
- xinference/model/llm/transformers/deepseek_vl2.py +9 -4
- xinference/model/llm/transformers/gemma3.py +4 -5
- xinference/model/llm/transformers/glm4v.py +3 -21
- xinference/model/llm/transformers/glm_edge_v.py +3 -20
- xinference/model/llm/transformers/intern_vl.py +3 -6
- xinference/model/llm/transformers/internlm2.py +1 -1
- xinference/model/llm/transformers/minicpmv25.py +4 -2
- xinference/model/llm/transformers/minicpmv26.py +5 -3
- xinference/model/llm/transformers/omnilmm.py +1 -1
- xinference/model/llm/transformers/opt.py +1 -1
- xinference/model/llm/transformers/ovis2.py +302 -0
- xinference/model/llm/transformers/qwen-omni.py +8 -1
- xinference/model/llm/transformers/qwen2_audio.py +3 -1
- xinference/model/llm/transformers/qwen2_vl.py +5 -1
- xinference/model/llm/transformers/qwen_vl.py +5 -2
- xinference/model/llm/utils.py +96 -45
- xinference/model/llm/vllm/core.py +108 -24
- xinference/model/llm/vllm/distributed_executor.py +8 -7
- xinference/model/llm/vllm/xavier/allocator.py +1 -1
- xinference/model/llm/vllm/xavier/block_manager.py +1 -1
- xinference/model/llm/vllm/xavier/block_tracker.py +3 -3
- xinference/model/llm/vllm/xavier/executor.py +1 -1
- xinference/model/llm/vllm/xavier/test/test_xavier.py +2 -11
- xinference/model/rerank/__init__.py +13 -24
- xinference/model/video/__init__.py +15 -25
- xinference/model/video/core.py +3 -3
- xinference/model/video/diffusers.py +157 -13
- xinference/model/video/model_spec.json +100 -0
- xinference/model/video/model_spec_modelscope.json +104 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
- xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
- xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
- xinference/thirdparty/cosyvoice/bin/train.py +7 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
- xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
- xinference/thirdparty/cosyvoice/cli/model.py +140 -155
- xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
- xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
- xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
- xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
- xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
- xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
- xinference/thirdparty/cosyvoice/utils/common.py +1 -1
- xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
- xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
- xinference/types.py +2 -71
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.0f6523be.css → main.337afe76.css} +2 -2
- xinference/web/ui/build/static/css/main.337afe76.css.map +1 -0
- xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
- xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6798e126f3bc5f95a4c16a9c2ad52ffe77970c62406d83e20604dfda7ffd2247.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b617f7d21a95045fc57b26a9373551740f1978a826134cbf705c3a1bf8714a93.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c1506cb142151366074975f30fa1ff9cd6e5e978b62a4b074dfc16fe08d70d75.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +1 -0
- xinference/web/ui/src/locales/en.json +7 -4
- xinference/web/ui/src/locales/zh.json +7 -4
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/METADATA +56 -36
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/RECORD +120 -121
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/WHEEL +1 -1
- xinference/core/image_interface.py +0 -377
- xinference/model/llm/transformers/compression.py +0 -258
- xinference/model/llm/transformers/yi_vl.py +0 -239
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
- xinference/web/ui/build/static/css/main.0f6523be.css.map +0 -1
- xinference/web/ui/build/static/js/main.4b67a723.js +0 -3
- xinference/web/ui/build/static/js/main.4b67a723.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8f9af2979e45d4648f0cfae108363e58ee421c29a9d4e7329b6f06d9adfd4133.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9c8b1a86e7c65b2b2599a205e30920652d6c2105f926508ef5bcf29a3ef4ce76.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e4ba658c6b3b0490910acdae0c535a892257efb61539a24adf8038fc653bd22f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/efe7cd132c27a8f9fd5352a394c491fd5fb0da0348cf9fcbd923164a32365eab.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f199e8173f6409a5802ed44acb95f218388131136504b2e9132129e150c92f9a.json +0 -1
- /xinference/web/ui/build/static/js/{main.4b67a723.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import logging
|
|
15
|
+
import uuid
|
|
16
|
+
from typing import Dict, Iterator, List, Optional, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from PIL import Image
|
|
20
|
+
|
|
21
|
+
from ....types import (
|
|
22
|
+
ChatCompletion,
|
|
23
|
+
ChatCompletionChunk,
|
|
24
|
+
ChatCompletionMessage,
|
|
25
|
+
CompletionChunk,
|
|
26
|
+
)
|
|
27
|
+
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
28
|
+
from ..utils import generate_chat_completion, generate_completion_chunk
|
|
29
|
+
from .core import PytorchChatModel, PytorchGenerateConfig
|
|
30
|
+
from .utils import cache_clean
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Ovis2ChatModel(PytorchChatModel):
    """Transformers-backed chat model for the Ovis2 multimodal family.

    Chat messages (text, images, videos) are converted into the Ovis
    conversation format and run through the checkpoint's remote-code
    ``preprocess_inputs`` + ``generate`` (non-streaming) or
    ``merge_multimodal`` + ``llm.generate`` (streaming).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._tokenizer = None
        self._model = None
        self._device = None
        self._processor = None
        # Both are populated by ``load()`` from the remote-code model.
        self._text_tokenizer = None
        self._visual_tokenizer = None

    @classmethod
    def match_json(
        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        """Return True for pytorch/gptq/awq specs whose family (or name) contains "ovis2"."""
        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
            return False
        llm_family = model_family.model_family or model_family.model_name
        return "ovis2" in llm_family.lower()

    def load(self):
        """Load the model in bfloat16 on CUDA and fetch its text/visual tokenizers."""
        from transformers import AutoModelForCausalLM

        # get_text_tokenizer / get_visual_tokenizer are not standard transformers
        # APIs; they come from the checkpoint's remote code (trust_remote_code).
        # NOTE(review): device placement is hard-coded to ``.cuda()``.
        self._model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            multimodal_max_length=32768,
            trust_remote_code=True,
        ).cuda()
        self._text_tokenizer = self._model.get_text_tokenizer()
        self._visual_tokenizer = self._model.get_visual_tokenizer()

    @cache_clean
    def chat(
        self,
        messages: List[ChatCompletionMessage],  # type: ignore
        generate_config: Optional[PytorchGenerateConfig] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Run a chat turn.

        Returns a ``ChatCompletion``, or an iterator of chunks when
        ``generate_config["stream"]`` is truthy.
        """
        messages = self._transform_messages(messages)
        generate_config = generate_config or {}  # type: ignore

        if generate_config.get("stream", False):
            chunks = self._generate_stream(messages, generate_config)
            return self._to_chat_completion_chunks(chunks)
        return self._generate(messages, generate_config)

    def _generate(
        self, messages: List, config: Optional[PytorchGenerateConfig] = None
    ) -> ChatCompletion:
        """Generate the whole answer in one call and wrap it as a ChatCompletion."""
        input_ids, attention_mask, pixel_values, gen_kwargs = self._generate_chat_data(
            messages, config
        )

        with torch.inference_mode():
            gen_kwargs.update(pixel_values=pixel_values, attention_mask=attention_mask)
            output_ids = self._model.generate(input_ids, **gen_kwargs)[0]
        output = self._text_tokenizer.decode(output_ids, skip_special_tokens=True)
        return generate_chat_completion(self.model_uid, output)

    def _generate_stream(
        self, messages: List, config: Optional[PytorchGenerateConfig] = None
    ) -> Iterator[CompletionChunk]:
        """Stream the answer as completion chunks.

        Generation runs in a background thread feeding a TextIteratorStreamer.
        If that thread fails, its exception is re-raised here instead of being
        lost.
        """
        from threading import Thread

        from transformers import TextIteratorStreamer

        input_ids, attention_mask, pixel_values, gen_kwargs = self._generate_chat_data(
            messages, config
        )

        # Build the merged multimodal embeddings with autograd disabled.
        # Otherwise an autograd graph (and its saved activations) would stay
        # referenced by inputs_embeds for the entire generation.
        with torch.no_grad():
            _, inputs_embeds, _, attention_mask = self._model.merge_multimodal(
                text_input_ids=input_ids,
                text_attention_masks=attention_mask,
                text_labels=None,
                pixel_values=pixel_values,
                left_padding=True,
            )
        torch.cuda.empty_cache()

        streamer = TextIteratorStreamer(
            self._text_tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
        )
        gen_kwargs.update(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            streamer=streamer,
        )

        errors: List[BaseException] = []

        def _run_generation() -> None:
            # Record failures and unblock the consumer loop below, which would
            # otherwise wait for the streamer timeout.
            try:
                self._model.llm.generate(**gen_kwargs)
            except Exception as exc:
                errors.append(exc)
                streamer.end()

        thread = Thread(target=_run_generation)
        thread.start()

        completion_id = str(uuid.uuid1())

        for new_text in streamer:
            yield generate_completion_chunk(
                chunk_text=new_text,
                finish_reason=None,
                chunk_id=completion_id,
                model_uid=self.model_uid,
                prompt_tokens=-1,
                completion_tokens=-1,
                total_tokens=-1,
                has_choice=True,
                has_content=True,
            )

        thread.join()
        if errors:
            raise errors[0]

        yield generate_completion_chunk(
            chunk_text=None,
            finish_reason="stop",
            chunk_id=completion_id,
            model_uid=self.model_uid,
            prompt_tokens=-1,
            completion_tokens=-1,
            total_tokens=-1,
            has_choice=True,
            has_content=False,
        )

    def parse_messages_ovis(self, messages: List[Dict]) -> List[Dict]:
        """Convert transformed chat messages into Ovis conversation turns.

        Roles are mapped ``user`` -> ``human`` and ``assistant`` -> ``gpt``.
        Any other role (e.g. ``system``) is passed through unchanged. Every
        message yields exactly one turn whose value is the newline-joined text
        of its text parts, or "" when it has none. This way an image-only user
        message still owns the final turn that the image placeholders are
        prepended to.
        """
        role_map = {"user": "human", "assistant": "gpt"}
        ovis_msgs = []
        for mess in messages:
            contents = mess["content"]
            if isinstance(contents, str):
                text = contents
            else:
                text = "\n".join(
                    part["text"] for part in (contents or []) if part["type"] == "text"
                )
            role = mess["role"]
            ovis_msgs.append({"from": role_map.get(role, role), "value": text})
        return ovis_msgs

    def _generate_chat_data(
        self, messages: List[Dict], config: Optional[PytorchGenerateConfig] = None
    ):
        """Prepare model inputs for one chat request.

        Returns ``(input_ids, attention_mask, pixel_values, gen_kwargs)``:
        ``input_ids`` and ``attention_mask`` have a leading batch dim of 1 and
        live on the model device. ``pixel_values`` is a one-element list
        holding the processed image tensor, or None for text-only input.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        from qwen_vl_utils import process_vision_info

        config = config or {}  # type: ignore

        messages_ovis = self.parse_messages_ovis(messages)
        if not messages_ovis:
            raise ValueError("Ovis2 chat requires at least one message")
        prompt = messages_ovis[-1]["value"]

        image_inputs, video_inputs = process_vision_info(messages)

        if image_inputs:
            if len(image_inputs) == 1:
                max_partition = 9
                prompt = f"<image>\n{prompt}"
            else:
                max_partition = len(image_inputs) + 1
                placeholders = "\n".join(
                    f"Image {i + 1}: <image>" for i in range(len(image_inputs))
                )
                prompt = f"{placeholders}\n{prompt}"
        elif video_inputs:
            # Videos are fed to Ovis as a sequence of single-partition images.
            image_inputs = self._convert_video_tensors_to_pil(video_inputs)
            max_partition = 1
            prompt = "\n".join(["<image>"] * len(image_inputs)) + "\n" + prompt
        else:
            image_inputs = []
            max_partition = 0

        messages_ovis[-1]["value"] = prompt

        # preprocess_inputs also returns the formatted prompt string (unused here).
        _, input_ids, pixel_values = self._model.preprocess_inputs(
            messages_ovis, image_inputs, max_partition=max_partition
        )

        attention_mask = torch.ne(input_ids, self._text_tokenizer.pad_token_id)
        input_ids = input_ids.unsqueeze(0).to(device=self._model.device)
        attention_mask = attention_mask.unsqueeze(0).to(device=self._model.device)
        if pixel_values is not None:
            pixel_values = pixel_values.to(
                dtype=self._visual_tokenizer.dtype, device=self._visual_tokenizer.device
            )
        # The model consumes a per-sample list, so a text-only request is [None].
        pixel_values = [pixel_values]

        gen_kwargs = dict(
            # ``max_tokens`` may be present but None; fall back to 1024 then too.
            max_new_tokens=config.get("max_tokens") or 1024,
            # NOTE(review): with do_sample=False transformers decodes greedily,
            # so the forwarded temperature has no effect.
            do_sample=False,
            top_p=None,
            top_k=None,
            temperature=config.get("temperature", None),
            repetition_penalty=None,
            eos_token_id=self._model.generation_config.eos_token_id,
            pad_token_id=self._text_tokenizer.pad_token_id,
            use_cache=True,
        )

        return input_ids, attention_mask, pixel_values, gen_kwargs

    def _convert_video_tensors_to_pil(self, video_inputs: List) -> List[Image.Image]:
        """Flatten decoded video inputs into a single list of PIL frames.

        Each entry may be one of:
        - a 4D tensor (frames along dim 0, converted with torchvision's
          ToPILImage);
        - a single PIL image;
        - a list/tuple of PIL images.

        Anything else is logged and skipped.
        """
        from torchvision import transforms

        to_pil = transforms.ToPILImage()
        pil_images: List[Image.Image] = []

        for item in video_inputs:
            if isinstance(item, torch.Tensor):
                if item.ndim != 4:
                    logger.warning(
                        f"Expected 4D tensor in video_inputs, but got {item.ndim}D. Skipping this tensor."
                    )
                    continue
                for i in range(item.size(0)):
                    frame = item[i].detach().cpu()
                    if frame.is_floating_point():
                        # ToPILImage multiplies float tensors by 255, assuming a
                        # [0, 1] range. The frames here are assumed to already be
                        # on a 0-255 scale (qwen_vl_utils.fetch_video casts the
                        # resized uint8 frames with .float() -- confirm if the
                        # video source changes), so convert to uint8 explicitly.
                        frame = frame.round().clamp(0, 255).to(torch.uint8)
                    try:
                        pil_images.append(to_pil(frame))
                    except Exception as e:
                        # Best-effort: a frame that fails to convert is dropped.
                        logger.error(f"Error converting frame {i} to PIL Image: {e}")
            elif isinstance(item, Image.Image):
                pil_images.append(item)
            elif isinstance(item, (list, tuple)) and all(
                isinstance(frame, Image.Image) for frame in item
            ):
                # A video given as a list of frames arrives as a list of images.
                pil_images.extend(item)
            else:
                logger.warning(
                    f"Unexpected type in video_inputs: {type(item)}. Skipping."
                )

        return pil_images
|
|
@@ -56,7 +56,7 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
|
|
|
56
56
|
self._processor = None
|
|
57
57
|
|
|
58
58
|
@classmethod
|
|
59
|
-
def
|
|
59
|
+
def match_json(
|
|
60
60
|
cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
|
|
61
61
|
) -> bool:
|
|
62
62
|
if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
|
|
@@ -67,6 +67,12 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
|
|
|
67
67
|
return False
|
|
68
68
|
|
|
69
69
|
def load(self):
|
|
70
|
+
logger.debug(
|
|
71
|
+
"Try to load model, current python: %s, sys path: %s",
|
|
72
|
+
sys.executable,
|
|
73
|
+
sys.path,
|
|
74
|
+
)
|
|
75
|
+
|
|
70
76
|
from transformers import (
|
|
71
77
|
Qwen2_5OmniForConditionalGeneration,
|
|
72
78
|
Qwen2_5OmniProcessor,
|
|
@@ -83,6 +89,7 @@ class Qwen2_5OmniChatModel(PytorchChatModel):
|
|
|
83
89
|
if not flash_attn_installed
|
|
84
90
|
else {"attn_implementation": "flash_attention_2"}
|
|
85
91
|
)
|
|
92
|
+
kwargs = self.apply_bnb_quantization(kwargs)
|
|
86
93
|
logger.debug("Loading model with extra kwargs: %s", kwargs)
|
|
87
94
|
|
|
88
95
|
self._processor = Qwen2_5OmniProcessor.from_pretrained(
|
|
@@ -42,7 +42,7 @@ class Qwen2AudioChatModel(PytorchChatModel):
|
|
|
42
42
|
self._device = None
|
|
43
43
|
|
|
44
44
|
@classmethod
|
|
45
|
-
def
|
|
45
|
+
def match_json(
|
|
46
46
|
cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
|
|
47
47
|
) -> bool:
|
|
48
48
|
llm_family = model_family.model_family or model_family.model_name
|
|
@@ -58,6 +58,7 @@ class Qwen2AudioChatModel(PytorchChatModel):
|
|
|
58
58
|
# for multiple GPU, set back to auto to make multiple devices work
|
|
59
59
|
device = "auto" if device == "cuda" else device
|
|
60
60
|
self._device = device
|
|
61
|
+
kwargs = self.apply_bnb_quantization()
|
|
61
62
|
|
|
62
63
|
self._processor = AutoProcessor.from_pretrained(
|
|
63
64
|
self.model_path,
|
|
@@ -70,6 +71,7 @@ class Qwen2AudioChatModel(PytorchChatModel):
|
|
|
70
71
|
device_map=device,
|
|
71
72
|
# trust_remote_code=True,
|
|
72
73
|
revision=self.model_spec.model_revision,
|
|
74
|
+
**kwargs,
|
|
73
75
|
)
|
|
74
76
|
|
|
75
77
|
def _transform_messages(
|
|
@@ -54,7 +54,7 @@ class Qwen2VLChatModel(PytorchChatModel):
|
|
|
54
54
|
return pytorch_model_config
|
|
55
55
|
|
|
56
56
|
@classmethod
|
|
57
|
-
def
|
|
57
|
+
def match_json(
|
|
58
58
|
cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
|
|
59
59
|
) -> bool:
|
|
60
60
|
if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
|
|
@@ -81,6 +81,8 @@ class Qwen2VLChatModel(PytorchChatModel):
|
|
|
81
81
|
self._device = device
|
|
82
82
|
# for multiple GPU, set back to auto to make multiple devices work
|
|
83
83
|
device = "auto" if device == "cuda" else device
|
|
84
|
+
kwargs = self.apply_bnb_quantization()
|
|
85
|
+
|
|
84
86
|
min_pixels = self._pytorch_model_config.get("min_pixels")
|
|
85
87
|
max_pixels = self._pytorch_model_config.get("max_pixels")
|
|
86
88
|
self._processor = AutoProcessor.from_pretrained(
|
|
@@ -106,6 +108,7 @@ class Qwen2VLChatModel(PytorchChatModel):
|
|
|
106
108
|
device_map=device,
|
|
107
109
|
attn_implementation="flash_attention_2",
|
|
108
110
|
trust_remote_code=True,
|
|
111
|
+
**kwargs,
|
|
109
112
|
).eval()
|
|
110
113
|
elif is_npu_available():
|
|
111
114
|
# Ascend do not support bf16
|
|
@@ -114,6 +117,7 @@ class Qwen2VLChatModel(PytorchChatModel):
|
|
|
114
117
|
device_map="auto",
|
|
115
118
|
trust_remote_code=True,
|
|
116
119
|
torch_dtype="float16",
|
|
120
|
+
**kwargs,
|
|
117
121
|
).eval()
|
|
118
122
|
else:
|
|
119
123
|
self._model = model_cls.from_pretrained(
|
|
@@ -41,7 +41,7 @@ class QwenVLChatModel(PytorchChatModel):
|
|
|
41
41
|
self._device = None
|
|
42
42
|
|
|
43
43
|
@classmethod
|
|
44
|
-
def
|
|
44
|
+
def match_json(
|
|
45
45
|
cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
|
|
46
46
|
) -> bool:
|
|
47
47
|
llm_family = model_family.model_family or model_family.model_name
|
|
@@ -66,6 +66,8 @@ class QwenVLChatModel(PytorchChatModel):
|
|
|
66
66
|
# for multiple GPU, set back to auto to make multiple devices work
|
|
67
67
|
device = "auto" if device == "cuda" else device
|
|
68
68
|
|
|
69
|
+
kwargs = self.apply_bnb_quantization()
|
|
70
|
+
|
|
69
71
|
self._tokenizer = AutoTokenizer.from_pretrained(
|
|
70
72
|
self.model_path,
|
|
71
73
|
trust_remote_code=True,
|
|
@@ -76,6 +78,7 @@ class QwenVLChatModel(PytorchChatModel):
|
|
|
76
78
|
device_map=device,
|
|
77
79
|
trust_remote_code=True,
|
|
78
80
|
code_revision=self.model_spec.model_revision,
|
|
81
|
+
**kwargs,
|
|
79
82
|
).eval()
|
|
80
83
|
|
|
81
84
|
# Specify hyperparameters for generation
|
|
@@ -310,7 +313,7 @@ class QwenVLChatModel(PytorchChatModel):
|
|
|
310
313
|
|
|
311
314
|
return raw_text, context_tokens
|
|
312
315
|
|
|
313
|
-
def _get_full_prompt(self, messages: List[Dict], tools):
|
|
316
|
+
def _get_full_prompt(self, messages: List[Dict], tools, generate_config: dict): # type: ignore
|
|
314
317
|
prompt, qwen_history = self._get_prompt_and_chat_history(messages)
|
|
315
318
|
_, context_tokens = self.make_context(self._tokenizer, prompt, qwen_history)
|
|
316
319
|
return context_tokens
|