xinference 1.6.0.post1__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/client/restful/restful_client.py +1 -1
- xinference/conftest.py +0 -7
- xinference/core/media_interface.py +9 -8
- xinference/core/model.py +13 -6
- xinference/core/scheduler.py +1 -10
- xinference/core/worker.py +0 -10
- xinference/model/audio/model_spec.json +53 -1
- xinference/model/audio/model_spec_modelscope.json +57 -1
- xinference/model/embedding/core.py +19 -11
- xinference/model/image/model_spec.json +10 -1
- xinference/model/image/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +6 -54
- xinference/model/llm/core.py +19 -5
- xinference/model/llm/llama_cpp/core.py +59 -3
- xinference/model/llm/llama_cpp/memory.py +455 -0
- xinference/model/llm/llm_family.json +185 -397
- xinference/model/llm/llm_family.py +88 -16
- xinference/model/llm/llm_family_modelscope.json +199 -421
- xinference/model/llm/llm_family_openmind_hub.json +0 -34
- xinference/model/llm/sglang/core.py +4 -0
- xinference/model/llm/transformers/__init__.py +27 -6
- xinference/model/llm/transformers/chatglm.py +4 -2
- xinference/model/llm/transformers/core.py +49 -28
- xinference/model/llm/transformers/deepseek_v2.py +6 -49
- xinference/model/llm/transformers/gemma3.py +119 -164
- xinference/{thirdparty/omnilmm/train → model/llm/transformers/multimodal}/__init__.py +1 -1
- xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
- xinference/model/llm/transformers/multimodal/core.py +205 -0
- xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
- xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
- xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
- xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
- xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
- xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
- xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
- xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
- xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
- xinference/model/llm/transformers/opt.py +4 -2
- xinference/model/llm/transformers/utils.py +6 -37
- xinference/model/llm/vllm/core.py +4 -0
- xinference/model/rerank/core.py +7 -1
- xinference/model/rerank/utils.py +17 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.ddf9eaee.js +3 -0
- xinference/web/ui/build/static/js/main.ddf9eaee.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/12e637ed5fa9ca6491b03892b6949c03afd4960fe36ac25744488e7e1982aa19.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77ac2665a784e99501ae95d32ef5937837a0439a47e965d291b38e99cb619f5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d4ed4e82bfe69915999ec83f5feaa4301c75ecc6bdf1c78f2d03e4671ecbefc8.json +1 -0
- xinference/web/ui/src/locales/en.json +3 -1
- xinference/web/ui/src/locales/zh.json +3 -1
- {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/METADATA +6 -4
- {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/RECORD +60 -76
- {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/WHEEL +1 -1
- xinference/model/llm/transformers/cogvlm2.py +0 -442
- xinference/model/llm/transformers/cogvlm2_video.py +0 -333
- xinference/model/llm/transformers/deepseek_vl.py +0 -280
- xinference/model/llm/transformers/glm_edge_v.py +0 -213
- xinference/model/llm/transformers/intern_vl.py +0 -526
- xinference/model/llm/transformers/internlm2.py +0 -94
- xinference/model/llm/transformers/minicpmv25.py +0 -193
- xinference/model/llm/transformers/omnilmm.py +0 -132
- xinference/model/llm/transformers/qwen2_audio.py +0 -179
- xinference/model/llm/transformers/qwen_vl.py +0 -360
- xinference/thirdparty/omnilmm/LICENSE +0 -201
- xinference/thirdparty/omnilmm/__init__.py +0 -0
- xinference/thirdparty/omnilmm/chat.py +0 -218
- xinference/thirdparty/omnilmm/constants.py +0 -4
- xinference/thirdparty/omnilmm/conversation.py +0 -332
- xinference/thirdparty/omnilmm/model/__init__.py +0 -1
- xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
- xinference/thirdparty/omnilmm/model/resampler.py +0 -166
- xinference/thirdparty/omnilmm/model/utils.py +0 -578
- xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
- xinference/thirdparty/omnilmm/utils.py +0 -134
- xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
- xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
- /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.ddf9eaee.js.LICENSE.txt} +0 -0
- {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.6.0.post1.dist-info → xinference-1.6.1.dist-info}/top_level.txt +0 -0
|
@@ -1,193 +0,0 @@
|
|
|
1
|
-
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
import json
|
|
15
|
-
import logging
|
|
16
|
-
import uuid
|
|
17
|
-
from concurrent.futures import ThreadPoolExecutor
|
|
18
|
-
from typing import Dict, Iterator, List, Optional, Union
|
|
19
|
-
|
|
20
|
-
import torch
|
|
21
|
-
|
|
22
|
-
from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
|
|
23
|
-
from ...utils import select_device
|
|
24
|
-
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
25
|
-
from ..utils import (
|
|
26
|
-
_decode_image,
|
|
27
|
-
generate_chat_completion,
|
|
28
|
-
generate_completion_chunk,
|
|
29
|
-
parse_messages,
|
|
30
|
-
)
|
|
31
|
-
from .core import PytorchChatModel, PytorchGenerateConfig
|
|
32
|
-
from .utils import cache_clean
|
|
33
|
-
|
|
34
|
-
# Module-level logger for the MiniCPM-Llama3-V-2_5 adapter.
logger = logging.getLogger(__name__)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
class MiniCPMV25Model(PytorchChatModel):
    """Chat adapter for the MiniCPM-Llama3-V-2_5 vision-language model.

    The checkpoint is loaded with ``trust_remote_code=True`` and every request
    is answered through the remote code's ``chat`` method with at most one
    image: the one attached to the current message, otherwise the most recent
    image found in the chat history.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._device = None
        self._tokenizer = None
        self._model = None

    @classmethod
    def match_json(
        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        """Return True when the family (or, if unset, the model name) contains
        ``MiniCPM-Llama3-V-2_5``, compared case-insensitively."""
        family = model_family.model_family or model_family.model_name
        return "MiniCPM-Llama3-V-2_5".lower() in family.lower()

    def _get_model_class(self):
        """Return the ``transformers`` auto class used to load this model."""
        from transformers import AutoModel

        return AutoModel

    def load(self):
        """Load the model, tokenizer and generation config from ``self.model_path``.

        Raises:
            RuntimeError: if an ``int4`` checkpoint would run on Apple MPS,
                which the bitsandbytes path does not support.
        """
        from transformers import AutoModel, AutoTokenizer
        from transformers.generation import GenerationConfig

        device = self._pytorch_model_config.get("device", "auto")
        self._device = select_device(device)
        # On CUDA let accelerate place the weights so several GPUs can be used.
        self._device = "auto" if self._device == "cuda" else self._device

        # Check both the requested and the resolved device so an "auto" request
        # that resolves to MPS is caught as well. Raise rather than calling
        # exit(), which would terminate the whole hosting process.
        if "int4" in self.model_path and "mps" in (device, self._device):
            raise RuntimeError(
                "Error: running int4 model with bitsandbytes on Mac is not supported right now."
            )

        if self._check_tensorizer_integrity():
            self._model, self._tokenizer = self._load_tensorizer()
            return

        if "int4" in self.model_path:
            # NOTE(review): presumably a pre-quantized checkpoint, hence no
            # dtype / device_map / bnb overrides here.
            model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
        else:
            kwargs = self.apply_bnb_quantization()
            model = AutoModel.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                device_map=self._device,
                **kwargs,
            )
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self._model = model.eval()
        self._tokenizer = tokenizer

        # Specify hyperparameters for generation
        self._model.generation_config = GenerationConfig.from_pretrained(
            self.model_path,
            trust_remote_code=True,
        )
        self._save_tensorizer()

    def _message_content_to_chat(self, content):
        """Split message content into ``(text, images)``.

        Args:
            content: a plain string (returned unchanged with no images) or a
                list of OpenAI-style parts; ``text`` parts are joined with a
                single space and ``image_url`` parts are decoded with
                ``_decode_image``.

        Returns:
            A ``(text, images)`` tuple where ``images`` has zero or one item.

        Raises:
            RuntimeError: if the message carries more than one image.
        """
        if isinstance(content, str):
            return content, []

        texts = []
        image_urls = []
        for c in content:
            c_type = c.get("type")
            if c_type == "text":
                texts.append(c["text"])
            elif c_type == "image_url":
                image_urls.append(c["image_url"]["url"])

        # Reject before decoding so no image is downloaded for a doomed request.
        if len(image_urls) > 1:
            raise RuntimeError("Only one image per message is supported")

        with ThreadPoolExecutor() as executor:
            images = list(executor.map(_decode_image, image_urls))
        return " ".join(texts), images

    @cache_clean
    def chat(
        self,
        messages: List[Dict],
        generate_config: Optional[PytorchGenerateConfig] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Answer ``messages`` with the remote-code ``chat`` method.

        The whole ``generate_config`` (including ``stream``) is forwarded as
        keyword arguments to the remote ``chat`` call.
        """
        # Normalise to a dict: the config is splatted with ** below, which
        # raised TypeError when the default None was used.
        generate_config = dict(generate_config or {})
        stream = generate_config.get("stream", False)
        prompt, _, chat_history = parse_messages(messages)
        content, images_chat = self._message_content_to_chat(prompt)

        msgs = []
        query_to_response: List[Dict] = []
        images_history = []
        for h in chat_history or []:
            role = h["role"]
            content_h, images_tmp = self._message_content_to_chat(h["content"])
            if images_tmp != []:
                # The most recent history image wins.
                images_history = images_tmp
            # Only complete user -> assistant pairs are kept as history.
            if len(query_to_response) == 0 and role == "user":
                query_to_response.append({"role": "user", "content": content_h})
            if len(query_to_response) == 1 and role == "assistant":
                query_to_response.append({"role": "assistant", "content": content_h})
            if len(query_to_response) == 2:
                msgs.extend(query_to_response)
                query_to_response = []

        # Prefer the current message's image; fall back to the history's.
        image = None
        if len(images_chat) > 0:
            image = images_chat[0]
        elif len(images_history) > 0:
            image = images_history[0]
        msgs.append({"role": "user", "content": content})

        chat = self._model.chat(
            image=image,
            msgs=json.dumps(msgs, ensure_ascii=True),
            tokenizer=self._tokenizer,
            sampling=True,
            **generate_config,
        )
        if stream:
            it = self.chat_stream(chat)
            return self._to_chat_completion_chunks(it)
        else:
            return generate_chat_completion(self.model_uid, chat)

    def chat_stream(self, chat) -> Iterator[CompletionChunk]:
        """Wrap each text piece yielded by ``chat`` into a completion chunk,
        then emit a final empty chunk with ``finish_reason="stop"``.
        Token counts are not tracked and are reported as -1."""
        completion_id = str(uuid.uuid1())
        for new_text in chat:
            yield generate_completion_chunk(
                chunk_text=new_text,
                finish_reason=None,
                chunk_id=completion_id,
                model_uid=self.model_uid,
                prompt_tokens=-1,
                completion_tokens=-1,
                total_tokens=-1,
            )

        yield generate_completion_chunk(
            chunk_text=None,
            finish_reason="stop",
            chunk_id=completion_id,
            model_uid=self.model_uid,
            prompt_tokens=-1,
            completion_tokens=-1,
            total_tokens=-1,
            has_choice=True,
            has_content=False,
        )
|
|
@@ -1,132 +0,0 @@
|
|
|
1
|
-
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
import base64
|
|
15
|
-
import json
|
|
16
|
-
import logging
|
|
17
|
-
import operator
|
|
18
|
-
import tempfile
|
|
19
|
-
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
|
20
|
-
|
|
21
|
-
from ....thirdparty.omnilmm.chat import OmniLMMChat, img2base64
|
|
22
|
-
from ....types import ChatCompletion, ChatCompletionChunk
|
|
23
|
-
from ...utils import select_device
|
|
24
|
-
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
25
|
-
from ..utils import generate_chat_completion, parse_messages
|
|
26
|
-
from .core import PytorchChatModel, PytorchGenerateConfig
|
|
27
|
-
from .utils import cache_clean
|
|
28
|
-
|
|
29
|
-
# Module-level logger for the OmniLMM adapter.
logger = logging.getLogger(__name__)
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class OmniLMMModel(PytorchChatModel):
    """Chat adapter for OmniLMM models via the vendored ``OmniLMMChat`` wrapper.

    Only non-streaming chat is supported, and every request needs an image:
    either on the current message or somewhere in the chat history.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._model = None

    @classmethod
    def match_json(
        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        """Return True when the family (or, if unset, the model name) contains
        ``OmniLMM`` (case-sensitive)."""
        llm_family = model_family.model_family or model_family.model_name
        return "OmniLMM" in llm_family

    def load(self):
        """Create the ``OmniLMMChat`` wrapper on the configured device."""
        device = self._pytorch_model_config.get("device", "auto")
        device = select_device(device)
        self._model = OmniLMMChat(self.model_path, device_map=device)

    def _message_content_to_OmniLMM(
        self, content, tmp_files: Optional[List[str]] = None
    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        """Split message content into ``(images, other_parts)``.

        Args:
            content: a plain string, or a list of OpenAI-style content parts.
            tmp_files: optional list. The path of every temporary file written
                for a base64 ``data:`` URL is appended to it, so the caller
                can delete those files once they have been read.

        Returns:
            ``images`` as ``{"image": <path or url>, "type": "image"}`` dicts,
            and the remaining (non-image) parts sorted by their ``type``.

        Raises:
            Exception: if a non-``data:`` image URL exceeds 2048 characters.
        """

        def _ensure_url(_url):
            if _url.startswith("data:"):
                logger.info("Parse url by base64 decoder.")
                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
                # e.g. f"data:image/jpeg;base64,{base64_image}"
                # Split on the first "," only, so extra ";param" entries in
                # the header do not break parsing.
                header, _, payload = _url.partition(",")
                mime = header[len("data:") :].split(";")[0]
                ext = mime.split("/")[-1]
                data = base64.b64decode(payload.encode("utf-8"))

                with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as f:
                    f.write(data)
                    logger.info("Dump base64 data to %s", f.name)
                if tmp_files is not None:
                    tmp_files.append(f.name)
                return f.name
            if len(_url) > 2048:
                raise Exception(f"Image url is too long, {len(_url)} > 2048.")
            return _url

        if isinstance(content, str):
            return [], [{"type": "text", "text": content}]

        images = []
        other_content = []
        for c in content:
            if c.get("type") == "image_url":
                images.append(
                    {"image": _ensure_url(c["image_url"]["url"]), "type": "image"}
                )
            else:
                other_content.append(c)

        other_content = sorted(other_content, key=operator.itemgetter("type"))
        return images, other_content

    @staticmethod
    def _content_text(parts: List[Dict[str, str]]) -> str:
        """Join the ``text`` fields of all ``text`` parts with single spaces.

        Returns "" when there is no text part, e.g. an image-only message.
        """
        return " ".join(p["text"] for p in parts if p.get("type") == "text")

    @cache_clean
    def chat(
        self,
        messages: List[Dict],
        generate_config: Optional[PytorchGenerateConfig] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Answer ``messages`` with a single non-streaming OmniLMM call.

        Raises:
            Exception: if streaming is requested.
            ValueError: if neither the current message nor the history has an image.
        """
        import contextlib
        import os

        if generate_config and generate_config.get("stream"):
            raise Exception(
                f"Chat with model {self.model_family.model_name} does not support stream."
            )

        # Temp files created for base64 images; removed once the call is done.
        tmp_files: List[str] = []
        try:
            prompt, _, chat_history = parse_messages(messages)
            image_first, prompt = self._message_content_to_OmniLMM(prompt, tmp_files)

            msgs = []
            query_to_response: List[Dict] = []
            image_another = []
            for h in chat_history or []:
                role = h["role"]
                image_tmp, content = self._message_content_to_OmniLMM(
                    h["content"], tmp_files
                )
                if image_tmp:
                    # The most recent history image wins.
                    image_another = image_tmp
                text = self._content_text(content)
                # Only complete user -> assistant pairs are kept as history.
                if len(query_to_response) == 0 and role == "user":
                    query_to_response.append({"role": "user", "content": text})
                if len(query_to_response) == 1 and role == "assistant":
                    query_to_response.append({"role": "assistant", "content": text})
                if len(query_to_response) == 2:
                    msgs.extend(query_to_response)
                    query_to_response = []

            # The current message's image takes priority over history images
            # (previously a history image silently overrode it, and a request
            # with no image at all crashed with UnboundLocalError).
            image = image_first or image_another
            if not image:
                raise ValueError(
                    f"Chat with model {self.model_family.model_name} requires an image."
                )
            im_64 = img2base64(image[0]["image"])
            msgs.append({"role": "user", "content": self._content_text(prompt)})
            chat_input = {
                "image": im_64,
                "question": json.dumps(msgs, ensure_ascii=True),
            }
            answer = self._model.chat(input=chat_input)
        finally:
            for path in tmp_files:
                with contextlib.suppress(OSError):
                    os.remove(path)

        return generate_chat_completion(self.model_uid, answer)
|
|
@@ -1,179 +0,0 @@
|
|
|
1
|
-
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
import logging
|
|
15
|
-
import uuid
|
|
16
|
-
from io import BytesIO
|
|
17
|
-
from typing import Iterator, List, Optional, Union
|
|
18
|
-
from urllib.request import urlopen
|
|
19
|
-
|
|
20
|
-
import numpy as np
|
|
21
|
-
|
|
22
|
-
from ....model.utils import select_device
|
|
23
|
-
from ....types import (
|
|
24
|
-
ChatCompletion,
|
|
25
|
-
ChatCompletionChunk,
|
|
26
|
-
ChatCompletionMessage,
|
|
27
|
-
CompletionChunk,
|
|
28
|
-
)
|
|
29
|
-
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
30
|
-
from ..utils import generate_chat_completion, generate_completion_chunk
|
|
31
|
-
from .core import PytorchChatModel, PytorchGenerateConfig
|
|
32
|
-
from .utils import cache_clean
|
|
33
|
-
|
|
34
|
-
# Module-level logger for the Qwen2-Audio adapter.
logger = logging.getLogger(__name__)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
class Qwen2AudioChatModel(PytorchChatModel):
    """Chat adapter for Qwen2-Audio models served through ``transformers``.

    Audio parts of the messages (``{"type": "audio", "audio_url": ...}``) are
    fetched and resampled with librosa, then passed to the processor together
    with the chat-templated prompt text.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._processor = None
        self._model = None
        self._device = None

    @classmethod
    def match_json(
        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        """Return True when the family (or, if unset, the model name) contains
        ``qwen2-audio``, compared case-insensitively."""
        llm_family = model_family.model_family or model_family.model_name
        return "qwen2-audio".lower() in llm_family.lower()

    def load(self):
        """Load the processor and model from ``self.model_path``."""
        from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

        device = self._pytorch_model_config.get("device", "auto")
        device = select_device(device)
        # for multiple GPU, set back to auto to make multiple devices work
        device = "auto" if device == "cuda" else device
        self._device = device
        kwargs = self.apply_bnb_quantization()

        self._processor = AutoProcessor.from_pretrained(
            self.model_path,
            device_map=device,
            code_revision=self.model_spec.model_revision,
        )
        self._model = Qwen2AudioForConditionalGeneration.from_pretrained(
            self.model_path,
            device_map=device,
            revision=self.model_spec.model_revision,
            **kwargs,
        )

    def _transform_messages(
        self,
        messages: Union[List[ChatCompletionMessage], List[dict]],
    ):
        """Render the prompt text and collect the audio clips of ``messages``.

        Returns:
            ``(text, audios)``: the chat-templated prompt with a generation
            prompt appended, and the audio arrays as returned by
            ``librosa.load`` at the feature extractor's sampling rate, in
            message order.
        """
        import librosa

        text = self._processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        sampling_rate = self._processor.feature_extractor.sampling_rate
        audios: List[np.ndarray] = []
        for msg in messages:
            content = msg["content"]
            if not isinstance(content, list):
                continue
            for item in content:  # type: ignore
                if item.get("type") == "audio" and "audio_url" in item:
                    # NOTE(review): urlopen accepts any scheme (incl. file://)
                    # and audio_url comes from the request; consider
                    # restricting schemes if requests are untrusted.
                    with urlopen(item["audio_url"]) as resp:
                        raw = resp.read()
                    audios.append(librosa.load(BytesIO(raw), sr=sampling_rate)[0])

        return text, audios

    @cache_clean
    def chat(
        self,
        messages: List[ChatCompletionMessage],
        generate_config: Optional[PytorchGenerateConfig] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Answer ``messages``; stream chunks when ``generate_config["stream"]`` is truthy."""
        text, audios = self._transform_messages(messages)
        inputs = self._processor(
            text=text, audios=audios, return_tensors="pt", padding=True
        )
        # Move inputs to where the model's weights live. self._device cannot
        # be used: it is "auto" on CUDA (for multi-GPU placement), which is
        # not a valid target for Tensor.to().
        device = self._model.device
        inputs.data = {k: v.to(device) for k, v in inputs.data.items()}

        generate_config = generate_config or {}
        if generate_config.get("stream", False):
            it = self._generate_stream(inputs, generate_config)
            return self._to_chat_completion_chunks(it)
        return self._generate(inputs, generate_config)

    @staticmethod
    def _max_new_tokens(config: PytorchGenerateConfig) -> int:
        """Return ``config["max_tokens"]``, or 512 when it is missing or None."""
        max_tokens = config.get("max_tokens")
        return 512 if max_tokens is None else max_tokens

    def _generate(
        self, inputs, config: Optional[PytorchGenerateConfig] = None
    ) -> ChatCompletion:
        """Generate the full response in one call and wrap it as a ChatCompletion."""
        config = config or {}
        # max_new_tokens (not max_length, which also counts the prompt tokens
        # and could leave no room for the answer), matching the stream path.
        generate_ids = self._model.generate(
            **inputs,
            max_new_tokens=self._max_new_tokens(config),
        )
        # Drop the prompt tokens; only the continuation is decoded.
        generate_ids = generate_ids[:, inputs.input_ids.size(1) :]
        response = self._processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return generate_chat_completion(self.model_uid, response)

    def _generate_stream(
        self, inputs, config: Optional[PytorchGenerateConfig] = None
    ) -> Iterator[CompletionChunk]:
        """Run generation on a background thread and yield chunks as text arrives,
        followed by a final empty chunk with ``finish_reason="stop"``.
        Token counts are not tracked and are reported as -1."""
        from threading import Thread

        from transformers import TextIteratorStreamer

        config = config or {}
        tokenizer = self._processor.tokenizer
        streamer = TextIteratorStreamer(
            tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
        )

        gen_kwargs = {
            "max_new_tokens": self._max_new_tokens(config),
            "streamer": streamer,
            **inputs,
        }

        thread = Thread(target=self._model.generate, kwargs=gen_kwargs)
        thread.start()

        completion_id = str(uuid.uuid1())
        for new_text in streamer:
            yield generate_completion_chunk(
                chunk_text=new_text,
                finish_reason=None,
                chunk_id=completion_id,
                model_uid=self.model_uid,
                prompt_tokens=-1,
                completion_tokens=-1,
                total_tokens=-1,
                has_choice=True,
                has_content=True,
            )

        yield generate_completion_chunk(
            chunk_text=None,
            finish_reason="stop",
            chunk_id=completion_id,
            model_uid=self.model_uid,
            prompt_tokens=-1,
            completion_tokens=-1,
            total_tokens=-1,
            has_choice=True,
            has_content=False,
        )
|