xinference 0.14.2__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +48 -41
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +2 -0
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +10 -1
- xinference/model/llm/vllm/core.py +1 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/METADATA +18 -6
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/RECORD +135 -37
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,524 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import logging
|
|
15
|
+
import time
|
|
16
|
+
import uuid
|
|
17
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
18
|
+
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
from ....core.scheduler import InferenceRequest
|
|
23
|
+
from ....model.utils import select_device
|
|
24
|
+
from ....types import (
|
|
25
|
+
ChatCompletion,
|
|
26
|
+
ChatCompletionChunk,
|
|
27
|
+
ChatCompletionMessage,
|
|
28
|
+
Completion,
|
|
29
|
+
CompletionChoice,
|
|
30
|
+
CompletionChunk,
|
|
31
|
+
CompletionUsage,
|
|
32
|
+
)
|
|
33
|
+
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
34
|
+
from ..utils import _decode_image
|
|
35
|
+
from .core import PytorchChatModel, PytorchGenerateConfig
|
|
36
|
+
from .utils import get_max_src_len
|
|
37
|
+
|
|
38
|
+
logger = logging.getLogger(__name__)

# Values stored in ``token_type_ids``: ``build_position_ids`` below treats
# positions equal to LANGUAGE_TOKEN_TYPE as text and positions equal to
# VISION_TOKEN_TYPE as visual (image/video) tokens.
LANGUAGE_TOKEN_TYPE, VISION_TOKEN_TYPE = 0, 1
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def recur_move_to(item, tgt, criterion_func):
    """
    This function is copied from https://github.com/THUDM/CogVLM2/blob/main/basic_demo/cli_demo_batch_inference.py

    Recursively walk ``item`` and call ``.to(tgt)`` on every element for which
    ``criterion_func`` returns True.

    Lists, tuples and dicts are rebuilt as new containers (list, tuple, dict);
    the input containers are not mutated. Any other value that does not
    satisfy ``criterion_func`` is returned unchanged.
    """
    if criterion_func(item):
        return item.to(tgt)

    if isinstance(item, dict):
        return {
            key: recur_move_to(value, tgt, criterion_func)
            for key, value in item.items()
        }

    if isinstance(item, (list, tuple)):
        converted = [recur_move_to(elem, tgt, criterion_func) for elem in item]
        # Lists come back as plain lists; tuples (incl. subclasses) as plain tuples.
        return converted if isinstance(item, list) else tuple(converted)

    return item
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class CogVLM2VideoModel(PytorchChatModel):
    """Chat model wrapper for CogVLM2 video families.

    The checkpoint is loaded through ``transformers`` with
    ``trust_remote_code=True``. The remote model's
    ``build_conversation_input_ids`` turns a query, a chat history and an
    optional image or video into model inputs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._torch_type = None
        self._device = None
        self._tokenizer = None
        self._model = None

    @classmethod
    def match(
        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        """Return True when the family (or model) name contains both "cogvlm2" and "video"."""
        family = model_family.model_family or model_family.model_name
        if "cogvlm2" in family.lower() and "video" in family.lower():
            return True
        return False

    def load(self, **kwargs):
        """Load tokenizer and model from ``self.model_path``.

        If the tensorizer cache passes its integrity check, both are restored
        from it. Otherwise the model is loaded with transformers, in bfloat16 on
        CUDA devices with compute capability >= 8 and float16 elsewhere.
        Quantization names containing "8-bit" / "4-bit" enable the matching
        ``load_in_*`` flag. A tensorizer cache is written afterwards.
        """
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from transformers.generation import GenerationConfig

        device = self._pytorch_model_config.get("device", "auto")
        self._device = select_device(device)
        self._torch_type = (
            torch.bfloat16
            if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
            else torch.float16
        )

        if self._check_tensorizer_integrity():
            self._model, self._tokenizer = self._load_tensorizer()
            return

        if "8-bit" in self.quantization.lower():
            kwargs["load_in_8bit"] = True
        elif "4-bit" in self.quantization.lower():
            kwargs["load_in_4bit"] = True

        self._tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=True,
        )

        self._model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype=self._torch_type,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            device_map="auto",
            **kwargs
        ).eval()

        # Specify hyperparameters for generation
        self._model.generation_config = GenerationConfig.from_pretrained(
            self.model_path,
            trust_remote_code=True,
        )
        self._save_tensorizer()

    def _load_video(self, video_path, num_frames: int = 24):
        """Decode up to ``num_frames`` frames from ``video_path``, about one per second.

        For each whole second ``s`` from 0 to the rounded last frame timestamp,
        the first frame whose start timestamp is nearest to ``s`` is selected.
        Selection stops after ``num_frames`` frames.

        Returns:
            The selected frames, reordered with ``permute(3, 0, 1, 2)``. With
            decord's torch bridge this is presumably (C, T, H, W); TODO confirm.

        Raises:
            ValueError: if the video has no frames.
        """
        import numpy as np
        from decord import VideoReader, bridge, cpu

        bridge.set_bridge("torch")

        decord_vr = VideoReader(video_path, ctx=cpu(0))
        total_frames = len(decord_vr)
        if total_frames == 0:
            raise ValueError(f"Video contains no frames: {video_path}")

        # Keep only the first column of each timestamp entry (the frame start time).
        # Convert to Python floats so ``round`` below always yields an int for
        # ``range``, whatever numpy scalar type decord returns.
        timestamps = [
            float(ts[0])
            for ts in decord_vr.get_frame_timestamp(np.arange(total_frames))
        ]
        max_second = int(round(max(timestamps))) + 1

        frame_id_list: List[int] = []
        for second in range(max_second):
            # Single pass: ``min`` returns the first index with the smallest distance.
            index = min(
                range(total_frames), key=lambda j, s=second: abs(timestamps[j] - s)
            )
            frame_id_list.append(index)
            if len(frame_id_list) >= num_frames:
                break

        video_data = decord_vr.get_batch(frame_id_list)
        return video_data.permute(3, 0, 1, 2)

    def _message_content_to_cogvlm2(self, content):
        """Split one message's content into ``(text, images, video)``.

        A plain string is returned as ``(content, [], None)``. A list of
        OpenAI-style parts is split as follows:
        - "text" parts are joined with spaces;
        - "image_url" parts are decoded concurrently with ``_decode_image``;
        - at most one "video_url" part is loaded with ``_load_video``.

        Raises:
            RuntimeError: if more than one video is present.
        """
        if not isinstance(content, str):
            texts = []
            image_urls = []
            video_urls = []
            for c in content:
                c_type = c.get("type")
                if c_type == "text":
                    texts.append(c["text"])
                elif c_type == "image_url":
                    image_urls.append(c["image_url"]["url"])
                elif c_type == "video_url":
                    video_urls.append(c["video_url"]["url"])
            if len(video_urls) > 1:
                raise RuntimeError("Only one video per message is supported")
            image_futures = []
            video = None
            with ThreadPoolExecutor() as executor:
                for image_url in image_urls:
                    fut = executor.submit(_decode_image, image_url)
                    image_futures.append(fut)
                images = [fut.result() for fut in image_futures]
                for v in video_urls:
                    video = self._load_video(v)
            text = " ".join(texts)
            return text, images, video
        return content, [], None

    def _history_content_to_cogvlm2(
        self, system_prompt: str, chat_history: List[ChatCompletionMessage]
    ):
        """Convert alternating user/assistant messages into CogVLM2 history.

        Returns ``(query, history, history_image, video)``:
        - ``history`` is a list of ``(user_text, assistant_text)`` tuples;
        - ``history_image`` is ``[first_image]`` when some user message carried
          an image, else ``None``;
        - ``video`` is the loaded history video, or ``None``.

        ``query`` is the last assistant message, or ``system_prompt`` when the
        history is empty.

        Raises:
            RuntimeError: if the history contains more than one video.
        """
        query = system_prompt
        history: List[Tuple] = []
        pixel_values = None
        video_urls: List[str] = []
        for i in range(0, len(chat_history), 2):
            user = chat_history[i]["content"]
            if isinstance(user, List):
                for content in user:
                    c_type = content.get("type")
                    if c_type == "text":
                        user = content["text"]
                    elif c_type == "image_url" and pixel_values is None:
                        # Only the first image in the history is kept.
                        pixel_values = _decode_image(content["image_url"]["url"])
                    elif c_type == "video_url":
                        video_urls.append(content["video_url"]["url"])
            assistant = chat_history[i + 1]["content"]
            history.append((user, assistant))
            query = assistant  # type: ignore
        if len(video_urls) > 1:
            raise RuntimeError("Only one video per message is supported")
        video = None
        for v in video_urls:
            video = self._load_video(v)
        # Return None (not ``[None]``) when no image was found. A ``[None]`` list
        # is truthy, so it would be forwarded as ``images=[None]`` to the model.
        history_image = [pixel_values] if pixel_values is not None else None
        return query, history, history_image, video

    def get_query_and_history(
        self,
        prompt: Union[str, List[Dict]],
        system_prompt: Optional[str] = None,
        chat_history: Optional[List[ChatCompletionMessage]] = None,
    ):
        """Resolve the query text, image, video and history for one request.

        The query is always the text of the current ``prompt``. An image or
        video in the current message takes precedence over one from the
        history. In that case the history is dropped, presumably because the
        model takes a single visual input. Otherwise any visual input found in
        the history is reused.

        Returns:
            ``(query, image, video, history)``.
        """
        content, image, video = self._message_content_to_cogvlm2(prompt)

        history: List[Tuple] = []
        history_image = None
        history_video = None
        if chat_history:
            _, history, history_image, history_video = self._history_content_to_cogvlm2(
                system_prompt, chat_history  # type: ignore
            )

        if image and history_image:
            history = []
        else:
            image = image if image else history_image

        if video is not None and history_video is not None:
            history = []
        else:
            video = video if video is not None else history_video

        return content, image, video, history

    def chat(
        self,
        prompt: Union[str, List[Dict]],
        system_prompt: Optional[str] = None,
        chat_history: Optional[List[ChatCompletionMessage]] = None,
        generate_config: Optional[PytorchGenerateConfig] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Run a single (non-batched) chat turn.

        Honors ``stream`` and ``max_tokens`` (default 512) from
        ``generate_config``. Returns a ``ChatCompletion``, or an iterator of
        chunks when streaming.
        """
        system_prompt = system_prompt if system_prompt else ""
        stream = generate_config.get("stream", False) if generate_config else False

        sanitized_config = {
            "pad_token_id": 128002,
            "max_new_tokens": generate_config.get("max_tokens", 512)
            if generate_config
            else 512,
        }

        query, image, video, history = self.get_query_and_history(
            prompt, system_prompt=system_prompt, chat_history=chat_history
        )

        if video is not None:
            # The video is passed to the model in place of any image.
            image = [video]

        input_by_model = self._model.build_conversation_input_ids(
            self._tokenizer,
            query=query,
            history=history,
            images=image,
            template_version="chat",
        )

        inputs = {
            "input_ids": input_by_model["input_ids"].unsqueeze(0).to(self._device),
            "token_type_ids": input_by_model["token_type_ids"]
            .unsqueeze(0)
            .to(self._device),
            "attention_mask": input_by_model["attention_mask"]
            .unsqueeze(0)
            .to(self._device),
            "images": [
                [input_by_model["images"][0].to(self._device).to(self._torch_type)]
            ]
            if image is not None
            else None,
        }

        if stream:
            it = self._streaming_chat_response(inputs, sanitized_config)
            return self._to_chat_completion_chunks(it)
        else:
            with torch.no_grad():
                outputs = self._model.generate(**inputs, **sanitized_config)
                # Keep only the newly generated tokens.
                outputs = outputs[:, inputs["input_ids"].shape[1] :]
                response = self._tokenizer.decode(outputs[0])
            response = response.split("<|end_of_text|>")[0]

            chunk = Completion(
                id=str(uuid.uuid1()),
                object="text_completion",
                created=int(time.time()),
                model=self.model_uid,
                choices=[
                    CompletionChoice(
                        index=0, text=response, finish_reason="stop", logprobs=None
                    )
                ],
                usage=CompletionUsage(
                    prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
                ),
            )
            return self._to_chat_completion(chunk)

    def _streaming_chat_response(
        self, inputs: Dict, config: Dict
    ) -> Iterator[CompletionChunk]:
        """Run ``generate`` in a background thread and yield text chunks as they arrive.

        A final empty chunk with ``finish_reason="stop"`` is yielded after the
        streamer is exhausted.
        """
        from threading import Thread

        from transformers import TextIteratorStreamer

        streamer = TextIteratorStreamer(
            self._tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "token_type_ids": inputs["token_type_ids"],
            "images": inputs["images"],
            "max_new_tokens": config["max_new_tokens"],
            "pad_token_id": config["pad_token_id"],
            "streamer": streamer,
        }

        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
        thread.start()

        completion_id = str(uuid.uuid1())
        for new_text in streamer:
            chunk = CompletionChunk(
                id=completion_id,
                object="text_completion",
                created=int(time.time()),
                model=self.model_uid,
                choices=[
                    CompletionChoice(
                        index=0, text=new_text, finish_reason=None, logprobs=None
                    )
                ],
                usage=CompletionUsage(
                    prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
                ),
            )
            yield chunk
        # The streamer is exhausted; make sure the generation thread has finished.
        thread.join()

        completion_choice = CompletionChoice(
            text="", index=0, logprobs=None, finish_reason="stop"
        )
        chunk = CompletionChunk(
            id=completion_id,
            object="text_completion",
            created=int(time.time()),
            model=self.model_uid,
            choices=[completion_choice],
            usage=CompletionUsage(
                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
            ),
        )
        yield chunk

    @staticmethod
    def build_position_ids(x, attention_mask=None):
        """
        Copied from https://huggingface.co/THUDM/cogvlm2-llama3-chinese-chat-19B-int4/blob/main/modeling_cogvlm.py

        Compute position ids from a 2-D ``token_type_ids`` tensor ``x``.

        Position ids are the cumulative sum of a per-token increment (index 0
        gets 0). The increment is 1 for language tokens and for the first token
        of each vision run, and 0 for the remaining vision tokens. Masked-out
        positions (``attention_mask == 0``) are set to -1 first.
        """
        # Fix: follows the official open-source implementation
        if attention_mask is not None:
            tmp = x.clone()
            tmp[~(attention_mask.bool())] = -1
        else:
            tmp = x.clone()
        # image boi eoi token as LANGUAGE_TOKEN_TYPE
        is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
        is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
            tmp[:, :-1] == LANGUAGE_TOKEN_TYPE
        )
        is_boi_eoi[:, 0] |= tmp[:, 0] == VISION_TOKEN_TYPE
        is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
            tmp[:, 1:] == LANGUAGE_TOKEN_TYPE
        )
        is_boi_eoi[:, -1] |= tmp[:, -1] == VISION_TOKEN_TYPE
        tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
        # final position ids
        y = torch.zeros_like(x, dtype=torch.long)
        y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
            (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
        )
        y = y.cumsum(dim=-1)
        return y

    def get_dtype(self):
        """Return the torch dtype chosen in ``load`` (None before loading)."""
        return self._torch_type

    def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
        """Build the un-batched model inputs for one request (batched-inference path).

        Returns a dict with 1-D ``input_ids``, ``token_type_ids`` and
        ``attention_mask``, plus the model-prepared ``images``. ``tools`` is
        accepted for interface compatibility and ignored.
        """
        query, image, video, history = self.get_query_and_history(
            prompt, system_prompt=system_prompt, chat_history=chat_history
        )

        # ``video`` is a tensor. Its truth value is ambiguous (RuntimeError), so
        # test for presence explicitly, as ``chat`` does.
        if video is not None:
            image = [video]

        input_by_model: dict = self._model.build_conversation_input_ids(  # type: ignore
            self._tokenizer,
            query=query,
            history=history,
            images=image,
            template_version="chat",
        )
        return {
            "input_ids": input_by_model["input_ids"],  # seq_len
            "token_type_ids": input_by_model["token_type_ids"],  # seq_len
            "attention_mask": input_by_model["attention_mask"],  # seq_len
            "images": input_by_model["images"],
        }

    def prepare_sanitize_generate_config(self, req: InferenceRequest):
        """
        See https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B/blob/main/generation_config.json

        Fill in ``temperature`` (0.6) and ``top_p`` (0.9) when they are missing
        or None. The request's ``raw_params`` dict is updated in place and
        returned.
        """
        raw_config = req.inference_kwargs.get("raw_params", {})
        temperature = raw_config.get("temperature", None)
        if temperature is None:
            raw_config["temperature"] = 0.6
        top_p = raw_config.get("top_p", None)
        if top_p is None:
            raw_config["top_p"] = 0.9
        return raw_config

    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
        """Build a left-padded batch for the prefill step of batched inference.

        Steps:
        1. Truncate each prompt dict (from ``_get_full_prompt``) to
           ``get_max_src_len``.
        2. Left-pad it with zeros to the longest prompt.
        3. Stack the prompts and add ``position_ids``.
        4. Move everything to ``self._device``; floating tensors are cast to
           ``get_dtype()``.

        Each request is updated with ``prompt_tokens``, ``padding_len`` and
        ``extra_kwargs`` keys ``attention_mask_seq_len`` / ``max_position_id``.
        """
        context_len = self.get_context_len()
        assert isinstance(prompts[0], dict)
        images = []
        max_length = float("-inf")
        for i, feature in enumerate(prompts):
            req = req_list[i]
            if "images" in feature:
                images.append(feature.pop("images", None))
            max_src_len = get_max_src_len(context_len, req)
            input_ids = feature["input_ids"][-max_src_len:]
            req.prompt_tokens = input_ids.tolist()
            feature["input_ids"] = input_ids
            feature["token_type_ids"] = feature["token_type_ids"][-max_src_len:]
            feature["attention_mask"] = feature["attention_mask"][-max_src_len:]
            req.extra_kwargs["attention_mask_seq_len"] = feature[
                "attention_mask"
            ].shape[0]
            max_length = max(len(input_ids), max_length)

        def pad_to_max_length_internal(feature, max_len, idx):
            # Left-pad ids / token types / mask with zeros up to ``max_len``.
            padding_length = max_len - len(feature["input_ids"])
            req_list[idx].padding_len = padding_length
            feature["input_ids"] = torch.cat(
                [torch.full((padding_length,), 0), feature["input_ids"]]
            )
            feature["token_type_ids"] = torch.cat(
                [
                    torch.zeros(padding_length, dtype=torch.long),
                    feature["token_type_ids"],
                ]
            )
            feature["attention_mask"] = torch.cat(
                [
                    torch.zeros(padding_length, dtype=torch.long),
                    feature["attention_mask"],
                ]
            )
            return feature

        features = [
            pad_to_max_length_internal(feature, max_length, i)
            for i, feature in enumerate(prompts)
        ]
        batch = {
            key: torch.stack([feature[key] for feature in features])
            for key in features[0].keys()
        }

        position_ids = self.build_position_ids(batch["token_type_ids"])
        batch["position_ids"] = position_ids

        for i in range(len(prompts)):
            req = req_list[i]
            # Remember the last prefill position id so decode steps can continue from it.
            req.extra_kwargs["max_position_id"] = position_ids[i : i + 1, -1].item()

        if images:
            batch["images"] = images

        batch = recur_move_to(
            batch, self._device, lambda x: isinstance(x, torch.Tensor)
        )
        dtype = self.get_dtype()
        if dtype:
            batch = recur_move_to(
                batch,
                dtype,
                lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x),
            )
        return batch

    def build_decode_token_type_ids(
        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
    ):
        """Return a ``(batch_size, 1)`` long tensor of token types for one decode step.

        The fill value is 1, which equals VISION_TOKEN_TYPE in this module.
        NOTE(review): the generated tokens are text; confirm whether this value
        matters for a single-token slice before changing it.
        """
        token_type_ids = torch.full(
            (batch_size, 1), fill_value=1, dtype=torch.long, device=self._device
        )
        return token_type_ids

    def build_decode_position_ids(
        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
    ):
        """Advance each request's ``max_position_id`` by one.

        Returns the new ids as a ``(len(reqs), 1)`` long tensor.
        """
        tmp = []
        for r in reqs:
            r.extra_kwargs["max_position_id"] += 1
            tmp.append(r.extra_kwargs["max_position_id"])
        position_ids = torch.as_tensor(
            tmp, device=self._device, dtype=torch.long
        ).unsqueeze(1)
        return position_ids
|
|
@@ -11,19 +11,15 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
import base64
|
|
15
14
|
import logging
|
|
16
15
|
import time
|
|
17
16
|
import typing
|
|
18
17
|
import uuid
|
|
19
18
|
from concurrent.futures import ThreadPoolExecutor
|
|
20
|
-
from io import BytesIO
|
|
21
19
|
from threading import Thread
|
|
22
20
|
from typing import Dict, Iterator, List, Optional, Union
|
|
23
21
|
|
|
24
|
-
import requests
|
|
25
22
|
import torch
|
|
26
|
-
from PIL import Image
|
|
27
23
|
|
|
28
24
|
from ....core.scheduler import InferenceRequest
|
|
29
25
|
from ....types import (
|
|
@@ -37,6 +33,7 @@ from ....types import (
|
|
|
37
33
|
)
|
|
38
34
|
from ...utils import select_device
|
|
39
35
|
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
36
|
+
from ..utils import _decode_image
|
|
40
37
|
from .core import PytorchChatModel, PytorchGenerateConfig
|
|
41
38
|
from .utils import get_max_src_len
|
|
42
39
|
|
|
@@ -106,24 +103,6 @@ class Glm4VModel(PytorchChatModel):
|
|
|
106
103
|
self._save_tensorizer()
|
|
107
104
|
|
|
108
105
|
def _message_content_to_chat(self, content):
|
|
109
|
-
def _load_image(_url):
|
|
110
|
-
if _url.startswith("data:"):
|
|
111
|
-
logging.info("Parse url by base64 decoder.")
|
|
112
|
-
# https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
|
|
113
|
-
# e.g. f"data:image/jpeg;base64,{base64_image}"
|
|
114
|
-
_type, data = _url.split(";")
|
|
115
|
-
_, ext = _type.split("/")
|
|
116
|
-
data = data[len("base64,") :]
|
|
117
|
-
data = base64.b64decode(data.encode("utf-8"))
|
|
118
|
-
return Image.open(BytesIO(data)).convert("RGB")
|
|
119
|
-
else:
|
|
120
|
-
try:
|
|
121
|
-
response = requests.get(_url)
|
|
122
|
-
except requests.exceptions.MissingSchema:
|
|
123
|
-
return Image.open(_url).convert("RGB")
|
|
124
|
-
else:
|
|
125
|
-
return Image.open(BytesIO(response.content)).convert("RGB")
|
|
126
|
-
|
|
127
106
|
if not isinstance(content, str):
|
|
128
107
|
texts = []
|
|
129
108
|
image_urls = []
|
|
@@ -136,7 +115,7 @@ class Glm4VModel(PytorchChatModel):
|
|
|
136
115
|
image_futures = []
|
|
137
116
|
with ThreadPoolExecutor() as executor:
|
|
138
117
|
for image_url in image_urls:
|
|
139
|
-
fut = executor.submit(
|
|
118
|
+
fut = executor.submit(_decode_image, image_url)
|
|
140
119
|
image_futures.append(fut)
|
|
141
120
|
images = [fut.result() for fut in image_futures]
|
|
142
121
|
text = " ".join(texts)
|