xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flags this version of xinference as potentially problematic; see the registry page for details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +15 -34
- xinference/client/restful/restful_client.py +2 -2
- xinference/core/chat_interface.py +45 -10
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +8 -5
- xinference/core/scheduler.py +1 -2
- xinference/core/worker.py +49 -42
- xinference/deploy/cmdline.py +2 -2
- xinference/deploy/test/test_cmdline.py +7 -7
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +26 -27
- xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
- xinference/model/llm/llm_family.json +606 -1266
- xinference/model/llm/llm_family.py +16 -139
- xinference/model/llm/llm_family_modelscope.json +276 -313
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/memory.py +9 -9
- xinference/model/llm/sglang/core.py +2 -2
- xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
- xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/{pytorch → transformers}/core.py +3 -10
- xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +540 -0
- xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
- xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
- xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
- xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
- xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
- xinference/model/llm/utils.py +85 -70
- xinference/model/llm/vllm/core.py +110 -11
- xinference/model/utils.py +1 -95
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/internvl/__init__.py +0 -0
- xinference/thirdparty/internvl/conversation.py +393 -0
- xinference/thirdparty/omnilmm/model/utils.py +16 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
- xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
- xinference/locale/utils.py +0 -39
- xinference/locale/zh_CN.json +0 -26
- xinference/model/llm/ggml/tools/__init__.py +0 -15
- xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
- xinference/model/llm/ggml/tools/gguf.py +0 -884
- xinference/model/llm/pytorch/__init__.py +0 -13
- xinference/model/llm/pytorch/baichuan.py +0 -81
- xinference/model/llm/pytorch/falcon.py +0 -138
- xinference/model/llm/pytorch/intern_vl.py +0 -352
- xinference/model/llm/pytorch/vicuna.py +0 -69
- xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
- xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
- /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
- /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
- /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
@@ -11,16 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
-import json
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from typing import Dict, Iterator, List, Optional, Union
 
-import requests
 import torch
 from PIL import Image
 
@@ -35,6 +31,7 @@ from ....types import (
 )
 from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig
 
 logger = logging.getLogger(__name__)
@@ -106,47 +103,60 @@ class MiniCPMV26Model(PytorchChatModel):
         self._save_tensorizer()
 
     def _message_content_to_chat(self, content):
-
+        MAX_NUM_FRAMES = 64
+
+        def encode_video(video_path):
+            from decord import VideoReader, cpu
+
+            def uniform_sample(l, n):
+                gap = len(l) / n
+                idxs = [int(i * gap + gap / 2) for i in range(n)]
+                return [l[i] for i in idxs]
+
+            vr = VideoReader(video_path, ctx=cpu(0))
+            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+            frame_idx = [i for i in range(0, len(vr), sample_fps)]
+            if len(frame_idx) > MAX_NUM_FRAMES:
+                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+            frames = vr.get_batch(frame_idx).asnumpy()
+            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+            print("num frames:", len(frames))
+            return frames
+
+        def _load_video(_url):
+            frames = None
             if _url.startswith("data:"):
-
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
+                raise RuntimeError("Only video url format is supported")
             else:
-
-
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
+                frames = encode_video(_url)
+                return frames
 
         if not isinstance(content, str):
             texts = []
             image_urls = []
+            video_urls = []
             for c in content:
                 c_type = c.get("type")
                 if c_type == "text":
                     texts.append(c["text"])
                 elif c_type == "image_url":
                     image_urls.append(c["image_url"]["url"])
+                elif c_type == "video_url":
+                    video_urls.append(c["video_url"]["url"])
             image_futures = []
             with ThreadPoolExecutor() as executor:
                 for image_url in image_urls:
-                    fut = executor.submit(
+                    fut = executor.submit(_decode_image, image_url)
                     image_futures.append(fut)
             images = [fut.result() for fut in image_futures]
+            frames = []
+            if len(video_urls) > 1:
+                raise RuntimeError("Only one video per message is supported")
+            for v in video_urls:
+                frames = _load_video(v)
             text = " ".join(texts)
-
-
-            elif len(images) == 1:
-                return text, images
-            else:
-                raise RuntimeError("Only one image per message is supported")
-        return content, []
+            return text, images, frames
+        return content, [], []
 
     def chat(
         self,
@@ -156,36 +166,51 @@ class MiniCPMV26Model(PytorchChatModel):
         generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         stream = generate_config.get("stream", False) if generate_config else False
-
+        videoExisted = False
+
+        content, images_chat, video_frames = self._message_content_to_chat(prompt)
+        if len(video_frames) > 0:
+            videoExisted = True
+            images_chat = video_frames
 
         msgs = []
         query_to_response: List[Dict] = []
-        images_history = []
         for h in chat_history or []:
+            images_history = []
             role = h["role"]
-            content_h, images_tmp = self._message_content_to_chat(
+            content_h, images_tmp, video_frames_h = self._message_content_to_chat(
+                h["content"]
+            )
             if images_tmp != []:
                 images_history = images_tmp
+            if len(video_frames_h) > 0:
+                videoExisted = True
+                images_history = video_frames_h
             if len(query_to_response) == 0 and role == "user":
-                query_to_response.append(
+                query_to_response.append(
+                    {"role": "user", "content": images_history + [content_h]}
+                )
             if len(query_to_response) == 1 and role == "assistant":
-                query_to_response.append(
+                query_to_response.append(
+                    {"role": "assistant", "content": images_history + [content_h]}
+                )
             if len(query_to_response) == 2:
                 msgs.extend(query_to_response)
                 query_to_response = []
-
-
-
-
-
-
+        msgs.append({"role": "user", "content": images_chat + [content]})
+
+        # Set decode params for video
+        params = {}
+        if videoExisted:
+            params = {"use_image_id": False, "max_slice_nums": 1}
 
         chat = self._model.chat(
-            image=
-            msgs=
+            image=None,
+            msgs=msgs,
             tokenizer=self._tokenizer,
             sampling=True,
-            **generate_config
+            **generate_config,
+            **params,
         )
         if stream:
             it = self.chat_stream(chat)
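The hunks above extend MiniCPM-V 2.6 chat handling with a video_url content type: a single video per message is accepted, decoded with decord, and up to 64 uniformly sampled frames are fed to the model in place of images. As a rough illustration only (not taken from this diff), a client request could look like the sketch below; the base URL and model UID are placeholders for whatever your Xinference deployment uses.

    # Hypothetical client call against an OpenAI-compatible Xinference endpoint.
    # Base URL, API key and model UID are placeholders, not values from this diff.
    from openai import OpenAI

    client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-used")

    response = client.chat.completions.create(
        model="MiniCPM-V-2.6",  # placeholder model UID
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe what happens in this clip."},
                    # New content type handled by _message_content_to_chat above;
                    # only one video_url part per message is allowed.
                    {"type": "video_url", "video_url": {"url": "/data/clip.mp4"}},
                ],
            }
        ],
    )
    print(response.choices[0].message.content)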
xinference/model/llm/{pytorch → transformers}/yi_vl.py
CHANGED

@@ -11,18 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from threading import Thread
 from typing import Dict, Iterator, List, Optional, Union
 
-import requests
 import torch
-from PIL import Image
 
 from ....model.utils import select_device
 from ....types import (
@@ -35,6 +31,7 @@ from ....types import (
     CompletionUsage,
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig
 
 logger = logging.getLogger(__name__)
@@ -78,25 +75,6 @@ class YiVLChatModel(PytorchChatModel):
 
     @staticmethod
     def _message_content_to_yi(content) -> Union[str, tuple]:
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-
-                return Image.open(BytesIO(data))
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url)
-                else:
-                    return Image.open(BytesIO(response.content))
-
         if not isinstance(content, str):
             from ....thirdparty.llava.model.constants import DEFAULT_IMAGE_TOKEN
 
@@ -111,7 +89,7 @@ class YiVLChatModel(PytorchChatModel):
             image_futures = []
             with ThreadPoolExecutor() as executor:
                 for image_url in image_urls:
-                    fut = executor.submit(
+                    fut = executor.submit(_decode_image, image_url)
                     image_futures.append(fut)
             images = [fut.result() for fut in image_futures]
             text = " ".join(texts)
xinference/model/llm/utils.py
CHANGED
@@ -11,14 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import functools
 import json
 import logging
 import os
 import time
 import uuid
+from io import BytesIO
 from typing import AsyncGenerator, Dict, Iterator, List, Optional, Tuple, cast
 
+import requests
+from PIL import Image
+
 from ...types import (
     SPECIAL_TOOL_PROMPT,
     ChatCompletion,
@@ -28,7 +33,7 @@ from ...types import (
     CompletionChunk,
 )
 from .llm_family import (
-
+    LlamaCppLLMSpecV1,
     LLMFamilyV1,
     LLMSpecV1,
     PromptStyleV1,
@@ -60,7 +65,7 @@ class ChatModelMixin:
         chat_history: List[ChatCompletionMessage],
         prompt_style: PromptStyleV1,
         tools: Optional[List[Dict]] = None,
-    )
+    ):
         """
         Inspired by FastChat. Format chat history into a prompt according to the prompty style of
        different models.
@@ -92,17 +97,6 @@ class ChatModelMixin:
                 else:
                     ret += role + ":"
             return ret
-        elif prompt_style.style_name == "ADD_COLON_TWO":
-            seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
-            ret = prompt_style.system_prompt + seps[0]
-            for i, message in enumerate(chat_history):
-                role = get_role(message["role"])
-                content = message["content"]
-                if content:
-                    ret += role + ": " + content + seps[i % 2]
-                else:
-                    ret += role + ":"
-            return ret
         elif prompt_style.style_name == "NO_COLON_TWO":
             seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
             ret = prompt_style.system_prompt
@@ -144,21 +138,6 @@ class ChatModelMixin:
             else:
                 ret += f"<|start_header_id|>{role}<|end_header_id|>{prompt_style.intra_message_sep}"
             return ret
-        elif prompt_style.style_name == "FALCON":
-            ret = prompt_style.system_prompt
-            for message in chat_history:
-                role = get_role(message["role"])
-                content = message["content"]
-                if content:
-                    ret += (
-                        role
-                        + ": "
-                        + content.replace("\r\n", "\n").replace("\n\n", "\n")
-                    )
-                    ret += "\n\n"
-                else:
-                    ret += role + ":"
-            return ret
         elif prompt_style.style_name == "MIXTRAL_V01":
             ret = ""
             for i, message in enumerate(chat_history):
@@ -168,22 +147,6 @@ class ChatModelMixin:
             else:  # assistant
                 ret += f"{content} </s>"
             return ret
-        elif prompt_style.style_name == "CHATGLM":
-            round_add_n = 1 if prompt_style.intra_message_sep == "\n\n" else 0
-            if prompt_style.system_prompt:
-                ret = prompt_style.system_prompt + prompt_style.intra_message_sep
-            else:
-                ret = ""
-            for i, message in enumerate(chat_history):
-                role = get_role(message["role"])
-                content = message["content"]
-                if i % 2 == 0:
-                    ret += f"[Round {i // 2 + round_add_n}]{prompt_style.intra_message_sep}"
-                if content:
-                    ret += role + ":" + content + prompt_style.intra_message_sep
-                else:
-                    ret += role + ":"
-            return ret
         elif prompt_style.style_name == "CHATGLM3":
             prompts = (
                 [f"<|system|>\n {prompt_style.system_prompt}"]
@@ -323,25 +286,6 @@ Begin!"""
             else:
                 ret += role + "\n"
             return ret
-        elif prompt_style.style_name == "INTERNLM":
-            seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
-            ret = ""
-            for i, message in enumerate(chat_history[:-2]):
-                if i % 2 == 0:
-                    ret += "<s>"
-                role = get_role(message["role"])
-                content = message["content"]
-                ret += role + ":" + str(content) + seps[i % 2]
-            if len(ret) == 0:
-                ret += "<s>"
-            ret += (
-                chat_history[-2]["role"]
-                + ":"
-                + str(chat_history[-2]["content"])
-                + seps[0]
-            )
-            ret += chat_history[-1]["role"] + ":"
-            return ret
         elif prompt_style.style_name == "INTERNLM2":
             ret = (
                 "<s>"
@@ -370,9 +314,6 @@ Begin!"""
             else:
                 ret += role + ": Let's think step by step."
             return ret
-        elif prompt_style.style_name == "INSTRUCTION":
-            message = chat_history[-2]
-            return prompt_style.system_prompt.format(message["content"])
         elif prompt_style.style_name == "DEEPSEEK_CHAT":
             seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
             ret = prompt_style.system_prompt
@@ -504,6 +445,61 @@ Begin!"""
             else:
                 ret += role
             return ret
+        elif prompt_style.style_name == "INTERNVL":
+            ret = (
+                "<s>"
+                if prompt_style.system_prompt == ""
+                else "<s><|im_start|>system\n"
+                + prompt_style.system_prompt
+                + prompt_style.intra_message_sep
+                + "\n"
+            )
+            images = []  # type: ignore
+            for message in chat_history:
+                role = get_role(message["role"])
+                content = message["content"]
+                if isinstance(content, str):
+                    if content:
+                        ret += (
+                            role
+                            + "\n"
+                            + content
+                            + prompt_style.intra_message_sep
+                            + "\n"
+                        )
+                    else:
+                        ret += role + "\n"
+                elif isinstance(content, list):
+                    text = ""
+                    image_urls = []
+                    for c in content:
+                        c_type = c.get("type")
+                        if c_type == "text":
+                            text = c["text"]
+                        elif c_type == "image_url":
+                            image_urls.append(c["image_url"]["url"])
+                    image_futures = []
+                    from concurrent.futures import ThreadPoolExecutor
+
+                    with ThreadPoolExecutor() as executor:
+                        for image_url in image_urls:
+                            fut = executor.submit(_decode_image, image_url)
+                            image_futures.append(fut)
+                    images = [fut.result() for fut in image_futures]
+                    if len(image_futures) == 0:
+                        ret += (
+                            role + "\n" + text + prompt_style.intra_message_sep + "\n"
+                        )
+                    else:
+                        ret += (
+                            role
+                            + "\n"
+                            + f"<image>\n{text}"
+                            + prompt_style.intra_message_sep
+                            + "\n"
+                        )
+
+            return (ret, images)
         else:
             raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")
 
@@ -706,7 +702,7 @@ Begin!"""
         family = model_family.model_family or model_family.model_name
         if family in ["gorilla-openfunctions-v1", "gorilla-openfunctions-v2"]:
             content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
-        elif family in
+        elif family in GLM4_TOOL_CALL_FAMILY:
             content, func, args = cls._eval_glm_chat_arguments(c, tools)
         elif family in QWEN_TOOL_CALL_FAMILY:
             content, func, args = cls._eval_qwen_chat_arguments(c, tools)
@@ -870,10 +866,10 @@ def get_file_location(
     is_cached = cache_status
     assert isinstance(is_cached, bool)
 
-    if spec.model_format in ["pytorch", "gptq", "awq", "mlx"]:
+    if spec.model_format in ["pytorch", "gptq", "awq", "fp8", "mlx"]:
         return cache_dir, is_cached
-    elif spec.model_format in ["
-        assert isinstance(spec,
+    elif spec.model_format in ["ggufv2"]:
+        assert isinstance(spec, LlamaCppLLMSpecV1)
         filename = spec.model_file_name_template.format(quantization=quantization)
         model_path = os.path.join(cache_dir, filename)
         return model_path, is_cached
@@ -885,3 +881,22 @@ def get_model_version(
     llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str
 ) -> str:
     return f"{llm_family.model_name}--{llm_spec.model_size_in_billions}B--{llm_spec.model_format}--{quantization}"
+
+
+def _decode_image(_url):
+    if _url.startswith("data:"):
+        logging.info("Parse url by base64 decoder.")
+        # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+        # e.g. f"data:image/jpeg;base64,{base64_image}"
+        _type, data = _url.split(";")
+        _, ext = _type.split("/")
+        data = data[len("base64,") :]
+        data = base64.b64decode(data.encode("utf-8"))
+        return Image.open(BytesIO(data)).convert("RGB")
+    else:
+        try:
+            response = requests.get(_url)
+        except requests.exceptions.MissingSchema:
+            return Image.open(_url).convert("RGB")
+        else:
+            return Image.open(BytesIO(response.content)).convert("RGB")
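Since `_decode_image` is now the single shared image loader (used by minicpmv26.py, yi_vl.py and the new INTERNVL prompt style above), a quick check of its two input forms can be useful. A minimal sketch, assuming a local image file named cat.png (a placeholder):

    # Minimal sketch: exercise both branches of the new shared helper.
    import base64

    from xinference.model.llm.utils import _decode_image

    with open("cat.png", "rb") as f:  # placeholder local image
        b64 = base64.b64encode(f.read()).decode("utf-8")

    # Branch 1: OpenAI-style base64 data URL.
    img = _decode_image(f"data:image/png;base64,{b64}")
    print(img.size, img.mode)  # mode is always "RGB" after .convert("RGB")

    # Branch 2: plain file paths (and http(s) URLs) are opened or fetched directly.
    img2 = _decode_image("cat.png")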
xinference/model/llm/vllm/core.py
CHANGED

@@ -21,6 +21,7 @@ import time
 import uuid
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncGenerator,
     Dict,
     Iterable,
@@ -88,11 +89,12 @@ try:
 except ImportError:
     VLLM_INSTALLED = False
 
+VLLM_SUPPORTED_VISION_MODEL_LIST: List[str] = [
+    "internvl2",
+]
 VLLM_SUPPORTED_MODELS = [
     "llama-2",
     "llama-3",
-    "baichuan",
-    "internlm-16k",
     "mistral-v0.1",
     "codestral-v0.1",
     "Yi",
@@ -105,13 +107,7 @@ VLLM_SUPPORTED_MODELS = [
 VLLM_SUPPORTED_CHAT_MODELS = [
     "llama-2-chat",
     "llama-3-instruct",
-    "vicuna-v1.3",
-    "vicuna-v1.5",
-    "baichuan-chat",
     "baichuan-2-chat",
-    "internlm-chat-7b",
-    "internlm-chat-8k",
-    "internlm-chat-20b",
     "internlm2-chat",
     "internlm2.5-chat",
     "internlm2.5-chat-1m",
@@ -338,7 +334,7 @@ class VLLMModel(LLM):
             return False
         if not cls._is_linux():
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
@@ -421,7 +417,7 @@ class VLLMModel(LLM):
 
     async def async_generate(
         self,
-        prompt: str,
+        prompt: Union[str, Dict[str, Any]],
         generate_config: Optional[Dict] = None,
         tools: object = False,
     ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
@@ -558,7 +554,7 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
@@ -644,3 +640,106 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
                 self.model_family, self.model_uid, c, tools
             )
             return self._to_chat_completion(c)
+
+
+class VLLMVisionModel(VLLMModel, ChatModelMixin):
+    def load(self):
+        try:
+            import vllm
+            from vllm.engine.arg_utils import AsyncEngineArgs
+            from vllm.engine.async_llm_engine import AsyncLLMEngine
+        except ImportError:
+            error_message = "Failed to import module 'vllm'"
+            installation_guide = [
+                "Please make sure 'vllm' is installed. ",
+                "You can install it by `pip install vllm`\n",
+            ]
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        if vllm.__version__ >= "0.3.1":
+            # from vllm v0.3.1, it uses cupy as NCCL backend
+            # in which cupy will fork a process
+            # only for xoscar >= 0.3.0, new process is allowed in subpool
+            # besides, xinference set start method as forkserver for unix
+            # we need to set it to fork to make cupy NCCL work
+            multiprocessing.set_start_method("fork", force=True)
+
+        self._model_config = self._sanitize_model_config(self._model_config)
+
+        logger.info(
+            f"Loading {self.model_uid} with following model config: {self._model_config}"
+        )
+
+        engine_args = AsyncEngineArgs(
+            model=self.model_path,
+            **self._model_config,
+        )
+        self._engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format != "pytorch":
+            return False
+        if llm_spec.model_format == "pytorch":
+            if quantization != "none" and not (quantization is None):
+                return False
+        if isinstance(llm_family, CustomLLMFamilyV1):
+            if llm_family.model_family not in VLLM_SUPPORTED_VISION_MODEL_LIST:
+                return False
+        else:
+            if llm_family.model_name not in VLLM_SUPPORTED_VISION_MODEL_LIST:
+                return False
+        if "vision" not in llm_family.model_ability:
+            return False
+        return VLLM_INSTALLED
+
+    def _sanitize_chat_config(
+        self,
+        generate_config: Optional[Dict] = None,
+    ) -> Dict:
+        if not generate_config:
+            generate_config = {}
+        if self.model_family.prompt_style:
+            if self.model_family.prompt_style.stop_token_ids:
+                generate_config.setdefault(
+                    "stop_token_ids",
+                    self.model_family.prompt_style.stop_token_ids.copy(),
+                )
+        return generate_config
+
+    async def async_chat(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[Dict] = None,
+    ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
+        # only support single image, waiting vllm support multi images
+        assert self.model_family.prompt_style is not None
+        prompt_style = self.model_family.prompt_style.copy()
+        chat_history = chat_history or []
+        prompt, images = self.get_prompt(prompt, chat_history, prompt_style)
+
+        if len(images) == 0:
+            inputs = {
+                "prompt": prompt,
+            }
+        else:
+            inputs = {
+                "prompt": prompt,
+                "multi_modal_data": {"image": images[-1]},  # type: ignore
+            }
+        generate_config = self._sanitize_chat_config(generate_config)
+
+        stream = generate_config.get("stream", None)
+
+        if stream:
+            agen = await self.async_generate(inputs, generate_config)
+            assert isinstance(agen, AsyncGenerator)
+            return self._async_to_chat_completion_chunks(agen)
+        else:
+            c = await self.async_generate(inputs, generate_config)
+            assert not isinstance(c, AsyncGenerator)
+            return self._to_chat_completion(c)