xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry, and is provided for informational purposes only.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +15 -34
- xinference/client/restful/restful_client.py +2 -2
- xinference/core/chat_interface.py +45 -10
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +8 -5
- xinference/core/scheduler.py +1 -2
- xinference/core/worker.py +49 -42
- xinference/deploy/cmdline.py +2 -2
- xinference/deploy/test/test_cmdline.py +7 -7
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +26 -27
- xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
- xinference/model/llm/llm_family.json +606 -1266
- xinference/model/llm/llm_family.py +16 -139
- xinference/model/llm/llm_family_modelscope.json +276 -313
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/memory.py +9 -9
- xinference/model/llm/sglang/core.py +2 -2
- xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
- xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/{pytorch → transformers}/core.py +3 -10
- xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +540 -0
- xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
- xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
- xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
- xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
- xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
- xinference/model/llm/utils.py +85 -70
- xinference/model/llm/vllm/core.py +110 -11
- xinference/model/utils.py +1 -95
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/internvl/__init__.py +0 -0
- xinference/thirdparty/internvl/conversation.py +393 -0
- xinference/thirdparty/omnilmm/model/utils.py +16 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
- xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
- xinference/locale/utils.py +0 -39
- xinference/locale/zh_CN.json +0 -26
- xinference/model/llm/ggml/tools/__init__.py +0 -15
- xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
- xinference/model/llm/ggml/tools/gguf.py +0 -884
- xinference/model/llm/pytorch/__init__.py +0 -13
- xinference/model/llm/pytorch/baichuan.py +0 -81
- xinference/model/llm/pytorch/falcon.py +0 -138
- xinference/model/llm/pytorch/intern_vl.py +0 -352
- xinference/model/llm/pytorch/vicuna.py +0 -69
- xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
- xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
- /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
- /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
- /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
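
Beyond the new fish_speech, InternVL and CogVLM2-video additions, the most sweeping structural change above is the backend rename: xinference/model/llm/pytorch/ becomes xinference/model/llm/transformers/, and ggml/llamacpp.py becomes llama_cpp/core.py. Downstream code that imports these internal modules directly has to follow the move; a hypothetical before/after, based only on the renamed paths and the PytorchChatModel import visible in the hunks below:

# Hypothetical downstream import update implied by the renames above;
# the class name comes from "from .core import PytorchChatModel, ..."
# seen later in this diff.
# xinference 0.14.1.post1:
#   from xinference.model.llm.pytorch.core import PytorchChatModel
# xinference 0.14.3:
from xinference.model.llm.transformers.core import PytorchChatModel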
xinference/model/llm/{pytorch → transformers}/glm4v.py

@@ -11,19 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import logging
 import time
 import typing
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from threading import Thread
 from typing import Dict, Iterator, List, Optional, Union

-import requests
 import torch
-from PIL import Image

 from ....core.scheduler import InferenceRequest
 from ....types import (
@@ -37,6 +33,7 @@ from ....types import (
 )
 from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig
 from .utils import get_max_src_len

@@ -106,24 +103,6 @@ class Glm4VModel(PytorchChatModel):
         self._save_tensorizer()

     def _message_content_to_chat(self, content):
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
         if not isinstance(content, str):
             texts = []
             image_urls = []
@@ -136,7 +115,7 @@ class Glm4VModel(PytorchChatModel):
             image_futures = []
             with ThreadPoolExecutor() as executor:
                 for image_url in image_urls:
-                    fut = executor.submit(
+                    fut = executor.submit(_decode_image, image_url)
                     image_futures.append(fut)
             images = [fut.result() for fut in image_futures]
             text = " ".join(texts)
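
Both glm4v.py above and minicpmv25.py further down drop their duplicated inline _load_image helpers in favor of a shared _decode_image imported from ..utils (xinference/model/llm/utils.py, which changes by +85/-70 in this release). The shared helper itself is not shown in this diff; assuming it mirrors the removed inline logic, a minimal sketch would be:

# Minimal sketch of the shared helper, assuming it mirrors the removed
# inline _load_image shown above; the real _decode_image lives in
# xinference/model/llm/utils.py and is not part of the hunks here.
import base64
from io import BytesIO

import requests
from PIL import Image


def _decode_image(url):
    if url.startswith("data:"):
        # e.g. "data:image/jpeg;base64,<base64 payload>"
        _, data = url.split(";")
        data = data[len("base64,"):]
        return Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
    try:
        response = requests.get(url)
    except requests.exceptions.MissingSchema:
        # No URL scheme: treat the string as a local file path.
        return Image.open(url).convert("RGB")
    return Image.open(BytesIO(response.content)).convert("RGB")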
xinference/model/llm/transformers/intern_vl.py (new file)

@@ -0,0 +1,540 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, Iterator, List, Optional, Union
+
+import torch
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
+from .core import PytorchChatModel, PytorchGenerateConfig
+
+logger = logging.getLogger(__name__)
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+def _message_content_to_intern(content, image_cnt):
+    if not isinstance(content, str):
+        texts = []
+        image_urls = []
+        video_urls = []
+        for c in content:
+            c_type = c.get("type")
+            if c_type == "text":
+                texts.append(c["text"])
+            elif c_type == "image_url":
+                image_urls.append(c["image_url"]["url"])
+            elif c_type == "video_url":
+                video_urls.append(c["video_url"]["url"])
+        if len(video_urls) > 1:
+            raise RuntimeError("Only one video per message is supported")
+        image_futures = []
+        with ThreadPoolExecutor() as executor:
+            for image_url in image_urls:
+                fut = executor.submit(_decode_image, image_url)
+                image_futures.append(fut)
+        images = [fut.result() for fut in image_futures]
+        videos = []
+        for vid_url in video_urls:
+            videos.append(_load_video(vid_url, num_segments=8, max_num=1))
+        prefix = ""
+        for i, _ in enumerate(images):
+            prefix += f"Image-{image_cnt + i + 1}: <image>\n\n"
+
+        if len(videos) > 0:
+            prefix = "".join(
+                [f"Frame{i+1}: <image>\n" for i in range(len(videos[0][1]))]
+            )
+
+        text = prefix + " ".join(texts)
+        return text, images, videos
+    return content, [], []
+
+
+def _get_prompt_and_chat_history(
+    prompt: Union[str, List[Dict]],
+    chat_history: Optional[List[ChatCompletionMessage]] = None,
+):
+    # Convert openai history to intern vl history
+    images = []
+    videos = []
+    history = []
+    image_cnt = 0
+    for h1, h2 in zip(*[iter(chat_history or [])] * 2):
+        content1, img, vid = _message_content_to_intern(h1["content"], image_cnt)
+        content2, _, _ = _message_content_to_intern(h2["content"], image_cnt)
+        history.append([content1, content2])
+        images.extend(img)
+        image_cnt += len(img)
+        videos.extend(vid)
+
+    question, img, vid = _message_content_to_intern(prompt, image_cnt)
+    images.extend(img)
+    videos.extend(vid)
+    return question, history, images, videos
+
+
+def _build_transform(input_size=448):
+    import torchvision.transforms as T
+    from torchvision.transforms.functional import InterpolationMode
+
+    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+    transform = T.Compose(
+        [
+            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+            T.ToTensor(),
+            T.Normalize(mean=MEAN, std=STD),
+        ]
+    )
+    return transform
+
+
+def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    best_ratio_diff = float("inf")
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+
+def _dynamic_preprocess(
+    image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
+):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = set(
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if i * j <= max_num and i * j >= min_num
+    )
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = _find_closest_aspect_ratio(
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = (
+            (i % (target_width // image_size)) * image_size,
+            (i // (target_width // image_size)) * image_size,
+            ((i % (target_width // image_size)) + 1) * image_size,
+            ((i // (target_width // image_size)) + 1) * image_size,
+        )
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images
+
+
+def _load_image(image_file, input_size=448, max_num=12):
+    image = image_file.convert("RGB")
+    transform = _build_transform(input_size=input_size)
+    images = _dynamic_preprocess(
+        image, image_size=input_size, use_thumbnail=True, max_num=max_num
+    )
+    pixel_values = [transform(image) for image in images]
+    pixel_values = torch.stack(pixel_values)
+    return pixel_values
+
+
+# video multi-round conversation
+def _get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+    import numpy as np
+
+    if bound:
+        start, end = bound[0], bound[1]
+    else:
+        start, end = -100000, 100000
+    start_idx = max(first_idx, round(start * fps))
+    end_idx = min(round(end * fps), max_frame)
+    seg_size = float(end_idx - start_idx) / num_segments
+    frame_indices = np.array(
+        [
+            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+            for idx in range(num_segments)
+        ]
+    )
+    return frame_indices
+
+
+def _load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+    from decord import VideoReader, cpu
+    from PIL import Image
+
+    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+    max_frame = len(vr) - 1
+    fps = float(vr.get_avg_fps())
+
+    pixel_values_list, num_patches_list = [], []
+    transform = _build_transform(input_size=input_size)
+    frame_indices = _get_index(
+        bound, fps, max_frame, first_idx=0, num_segments=num_segments
+    )
+    for frame_index in frame_indices:
+        img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
+        img = _dynamic_preprocess(
+            img, image_size=input_size, use_thumbnail=True, max_num=max_num
+        )
+        pixel_values = [transform(tile) for tile in img]
+        pixel_values = torch.stack(pixel_values)
+        pixel_values = pixel_values.to(torch.bfloat16).cuda()
+        num_patches_list.append(pixel_values.shape[0])
+        pixel_values_list.append(pixel_values)
+    pixel_values = torch.cat(pixel_values_list)
+    return pixel_values, num_patches_list
+
+
+class InternVLChatModel(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._tokenizer = None
+        self._model = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        family = model_family.model_family or model_family.model_name
+        if "internvl" not in family.lower():
+            return False
+        if "pytorch" not in model_spec.model_format:
+            return False
+        return True
+
+    def _get_model_class(self):
+        from transformers import AutoModel
+
+        return AutoModel
+
+    # Copy from InternVL page
+    # reference: https://huggingface.co/OpenGVLab/InternVL2-8B
+    def _split_model(self):
+        import math
+
+        device_map = {}
+        world_size = torch.cuda.device_count()
+        # single gpu
+        if world_size == 1:
+            return None
+        model_size = f"{self.model_spec.model_size_in_billions}B"
+        num_layers = {
+            "1B": 24,
+            "2B": 24,
+            "4B": 32,
+            "8B": 32,
+            "26B": 48,
+            "40B": 60,
+            "76B": 80,
+        }[model_size]
+        # Since the first GPU will be used for ViT, treat it as half a GPU.
+        num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
+        num_layers_per_gpu = [num_layers_per_gpu] * world_size
+        num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
+        layer_cnt = 0
+        for i, num_layer in enumerate(num_layers_per_gpu):
+            for j in range(num_layer):
+                device_map[f"language_model.model.layers.{layer_cnt}"] = i
+                layer_cnt += 1
+        device_map["vision_model"] = 0
+        device_map["mlp1"] = 0
+        device_map["language_model.model.tok_embeddings"] = 0
+        device_map["language_model.model.embed_tokens"] = 0
+        device_map["language_model.output"] = 0
+        device_map["language_model.model.norm"] = 0
+        device_map["language_model.lm_head"] = 0
+        device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
+        return device_map
+
+    def load(self, **kwargs):
+        from transformers import AutoModel, AutoTokenizer
+
+        if self._check_tensorizer_integrity():
+            self._model, self._tokenizer = self._load_tensorizer()
+            return
+
+        device = self._split_model()
+
+        kwargs = {
+            "torch_dtype": torch.bfloat16,
+            "low_cpu_mem_usage": True,
+            "trust_remote_code": True,
+        }
+
+        if device is not None:
+            kwargs["device_map"] = device
+
+        if "8-bit" in self.quantization.lower():
+            kwargs["load_in_8bit"] = True
+        elif "4-bit" in self.quantization.lower():
+            kwargs["load_in_4bit"] = True
+
+        self._model = AutoModel.from_pretrained(self.model_path, **kwargs).eval()
+
+        if device is None and "none" in self.quantization.lower():
+            self._model.cuda()
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+            use_fast=False,
+        )
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        from ....thirdparty.internvl.conversation import get_conv_template
+
+        IMG_START_TOKEN = "<img>"
+        IMG_END_TOKEN = "</img>"
+        IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
+
+        generation_config = {
+            "max_new_tokens": generate_config.get("max_tokens", 1024)
+            if generate_config
+            else 1024,
+            "do_sample": False,
+        }
+
+        stream = (
+            generate_config.get("stream", False)
+            if isinstance(generate_config, dict)
+            else False
+        )
+        stream_options = (
+            generate_config.get("stream_options", None)
+            if isinstance(generate_config, dict)
+            else False
+        )
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
+
+        content, history, images, videos = _get_prompt_and_chat_history(
+            prompt, chat_history
+        )
+
+        num_patches_list = []
+        if len(images) == 1:
+            content = content.replace("Image-1: <image>\n\n", "<image>\n")
+            history = [
+                [item[0].replace("Image-1: <image>\n\n", "<image>\n"), item[1]]
+                for item in history
+            ]
+            pixel_values = _load_image(images[-1], max_num=12).to(torch.bfloat16).cuda()
+            num_patches_list = (
+                [pixel_values.shape[0]] if pixel_values is not None else []
+            )
+        elif len(images) > 1:
+            pixel_values = [
+                _load_image(img, max_num=12).to(torch.bfloat16).cuda() for img in images
+            ]
+            num_patches_list = [values.size(0) for values in pixel_values]
+            pixel_values = torch.cat(pixel_values, dim=0)
+        else:
+            pixel_values = None
+
+        if len(videos) > 0:
+            pixel_values = videos[0][0]
+            num_patches_list = videos[0][1]
+
+        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+
+        img_context_token_id = self._tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+        self._model.img_context_token_id = img_context_token_id
+
+        template = get_conv_template(self._model.template)
+        template.system_message = self._model.system_message
+        eos_token_id = self._tokenizer.convert_tokens_to_ids(template.sep)
+
+        history = [] if history is None else history
+        for old_question, old_answer in history:
+            template.append_message(template.roles[0], old_question)
+            template.append_message(template.roles[1], old_answer)
+        template.append_message(template.roles[0], content)
+        template.append_message(template.roles[1], None)
+        query = template.get_prompt()
+
+        for num_patches in num_patches_list:
+            image_tokens = (
+                IMG_START_TOKEN
+                + IMG_CONTEXT_TOKEN * self._model.num_image_token * num_patches
+                + IMG_END_TOKEN
+            )
+            query = query.replace("<image>", image_tokens, 1)
+
+        model_inputs = self._tokenizer(query, return_tensors="pt")
+        input_ids = model_inputs["input_ids"].cuda()
+        attention_mask = model_inputs["attention_mask"].cuda()
+        generation_config["eos_token_id"] = eos_token_id
+        generate_kwargs = {
+            "pixel_values": pixel_values,
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        generate_kwargs.update(generation_config)
+
+        if stream:
+            chunk = self._generate_stream(generate_kwargs, input_ids, include_usage)
+            return self._to_chat_completion_chunks(chunk)
+        else:
+            chunk = self._generate(generate_kwargs, input_ids, template)
+            return self._to_chat_completion(chunk)
+
+    def _generate(self, generate_kwargs, input_ids, template):
+        prompt_tokens = len(input_ids[0])
+        generation_output = self._model.generate(**generate_kwargs)
+        completion_tokens = len(generation_output[0])
+        response = self._tokenizer.batch_decode(
+            generation_output, skip_special_tokens=True
+        )[0]
+        response = response.split(template.sep)[0].strip()
+        chunk = Completion(
+            id=str(uuid.uuid1()),
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[
+                CompletionChoice(
+                    index=0, text=response, finish_reason="stop", logprobs=None
+                )
+            ],
+            usage=CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
+            ),
+        )
+        return chunk
+
+    def _generate_stream(self, generate_kwargs, input_ids, include_usage):
+        from threading import Thread
+
+        from transformers import TextIteratorStreamer
+
+        # Initialize the streamer
+        streamer = TextIteratorStreamer(
+            self._tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
+        )
+        # Define the generation configuration
+        generate_kwargs["streamer"] = streamer
+        # Start the model chat in a separate thread
+        thread = Thread(
+            target=self._model.generate,
+            kwargs=generate_kwargs,
+        )
+        thread.start()
+
+        completion_id = str(uuid.uuid1())
+        prompt_tokens = len(input_ids[0])
+        completion_tokens = 0
+        # Loop through the streamer to get the new text as it is generated
+        for i, new_text in enumerate(streamer):
+            if new_text == self._model.conv_template.sep:
+                break
+            completion_choice = CompletionChoice(
+                text=new_text, index=0, logprobs=None, finish_reason=None
+            )
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[completion_choice],
+            )
+            completion_tokens = max(completion_tokens, len(streamer.token_cache))
+            total_tokens = prompt_tokens + completion_tokens
+            completion_usage = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            chunk["usage"] = completion_usage
+            yield chunk
+        completion_choice = CompletionChoice(
+            text="", index=0, logprobs=None, finish_reason="stop"
+        )
+        chunk = CompletionChunk(
+            id=completion_id,
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[completion_choice],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+        )
+        chunk["usage"] = completion_usage
+        yield chunk
+        if include_usage:
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[],
+            )
+            chunk["usage"] = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            yield chunk
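
The new InternVLChatModel above (like the cogvlm2_video.py handler added in the same release) parses OpenAI-style multimodal content parts. The part keys come directly from _message_content_to_intern; the payload below is only an illustrative example with placeholder URLs:

# Illustrative chat content for the new InternVL handler; the keys
# ("type", "text", "image_url", "video_url") are those checked by
# _message_content_to_intern, the URLs are placeholders.
prompt = [
    {"type": "text", "text": "What is happening in this clip?"},
    {"type": "image_url", "image_url": {"url": "https://example.com/still.jpg"}},
    # At most one video_url part per message; a second one raises RuntimeError.
    {"type": "video_url", "video_url": {"url": "/data/clip.mp4"}},
]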
xinference/model/llm/{pytorch → transformers}/internlm2.py

@@ -85,14 +85,10 @@ class Internlm2PytorchChatModel(PytorchChatModel):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if llm_spec.model_format != "pytorch":
-            return False
         model_family = llm_family.model_family or llm_family.model_name
-        if model_family
-            return
-
-        return False
-        return True
+        if model_family in ["internlm2-chat", "internlm2.5-chat"]:
+            return True
+        return False

     def prepare_sanitize_generate_config(self, req: InferenceRequest):
         """
@@ -153,7 +149,7 @@ class Internlm2PytorchChatModel(PytorchChatModel):
         inputs = inputs.to(self._model.device)
         prompt_tokens = len(inputs["input_ids"][0])
         for chunk_text, _ in self._model.stream_chat(
-            self._tokenizer, prompt,
+            self._tokenizer, prompt, input_history, **kwargs
         ):
             completion_tokens = completion_tokens + 1
             total_tokens = prompt_tokens + completion_tokens
xinference/model/llm/{pytorch → transformers}/minicpmv25.py

@@ -11,18 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import json
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from typing import Dict, Iterator, List, Optional, Union

-import requests
 import torch
-from PIL import Image

 from ....types import (
     ChatCompletion,
@@ -35,6 +31,7 @@ from ....types import (
 )
 from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig

 logger = logging.getLogger(__name__)
@@ -102,24 +99,6 @@ class MiniCPMV25Model(PytorchChatModel):
         self._save_tensorizer()

     def _message_content_to_chat(self, content):
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
         if not isinstance(content, str):
             texts = []
             image_urls = []
@@ -132,7 +111,7 @@ class MiniCPMV25Model(PytorchChatModel):
             image_futures = []
             with ThreadPoolExecutor() as executor:
                 for image_url in image_urls:
-                    fut = executor.submit(
+                    fut = executor.submit(_decode_image, image_url)
                     image_futures.append(fut)
             images = [fut.result() for fut in image_futures]
             text = " ".join(texts)