xinference-0.14.2-py3-none-any.whl → xinference-0.14.3-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
Potentially problematic release.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +48 -41
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +2 -0
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +10 -1
- xinference/model/llm/vllm/core.py +1 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/METADATA +18 -6
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/RECORD +135 -37
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
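The headline additions in this release, per the manifest above, are a vendored fish_speech package with a new FishSpeech audio backend (xinference/model/audio/fish_speech.py), a new LMDeploy chat engine (xinference/model/llm/lmdeploy/core.py, shown in the first hunk below), and a CogVLM2-video transformers model. As a rough orientation, a minimal client-side sketch of how the new text-to-speech backend might be exercised once this wheel is installed; the model name "FishSpeech-1.2-SFT" and the speech() handle are assumptions based on xinference's existing audio client API, not taken from this diff:

# Hypothetical usage sketch for the new FishSpeech audio backend (names are assumptions).
from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # a running xinference 0.14.3 endpoint
model_uid = client.launch_model(
    model_name="FishSpeech-1.2-SFT",  # assumed name of the entry added to audio/model_spec.json
    model_type="audio",
)
model = client.get_model(model_uid)
audio_bytes = model.speech("Hello from xinference 0.14.3")  # same TTS call used by other audio backends
with open("hello.mp3", "wb") as f:
    f.write(audio_bytes)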
xinference/model/llm/lmdeploy/core.py (new file)
@@ -0,0 +1,557 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import time
+import uuid
+from typing import AsyncGenerator, Dict, Iterator, List, Optional, TypedDict, Union
+
+import torch
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionChunkChoice,
+    ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionUsage,
+    LoRA,
+)
+from ..core import LLM
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import ChatModelMixin
+
+logger = logging.getLogger(__name__)
+
+try:
+    import lmdeploy  # noqa: F401
+
+    LMDEPLOY_INSTALLED = True
+except ImportError:
+    LMDEPLOY_INSTALLED = False
+
+LMDEPLOY_SUPPORTED_CHAT_MODELS = ["internvl2"]
+LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME = {
+    "internvl2": "internvl-internlm2",
+}
+
+
+class LMDeployModelConfig(TypedDict, total=False):
+    model_format: Optional[str]
+    tp: Optional[int]
+    session_len: Optional[int]
+    max_batch_size: Optional[int]
+    cache_max_entry_count: Optional[float]
+    cache_block_seq_len: Optional[int]
+    enable_prefix_caching: Optional[bool]
+    quant_policy: Optional[int]
+    rope_scaling_factor: Optional[float]
+    use_logn_attn: Optional[bool]
+    download_dir: Optional[str]
+    revision: Optional[str]
+    max_prefill_token_num: Optional[int]
+    num_tokens_per_iter: Optional[int]
+    max_prefill_iters: Optional[int]
+
+
+class LMDeployGenerateConfig(TypedDict, total=False):
+    n: Optional[int]
+    max_new_tokens: Optional[int]
+    top_p: Optional[float]
+    top_k: Optional[int]
+    temperature: Optional[float]
+    repetition_penalty: Optional[float]
+    ignore_eos: Optional[bool]
+    random_seed: Optional[int]
+    stop_words: Optional[List[str]]
+    bad_words: Optional[List[str]]
+    min_new_tokens: Optional[int]
+    skip_special_tokens: Optional[bool]
+    logprobs: Optional[int]
+
+
+class LMDeployModel(LLM):
+    def __init__(
+        self,
+        model_uid: str,
+        model_family: "LLMFamilyV1",
+        model_spec: "LLMSpecV1",
+        quantization: str,
+        model_path: str,
+        model_config: Optional[LMDeployModelConfig] = None,
+        peft_model: Optional[List[LoRA]] = None,
+    ):
+        super().__init__(model_uid, model_family, model_spec, quantization, model_path)
+        self._model_config: LMDeployModelConfig = self._sanitize_model_config(
+            model_config
+        )
+        if peft_model is not None:
+            raise ValueError("LMDEPLOY engine has not supported lora yet.")
+
+    def _sanitize_model_config(
+        self, model_config: Optional[LMDeployModelConfig]
+    ) -> LMDeployModelConfig:
+        if model_config is None:
+            model_config = LMDeployModelConfig()
+        model_config.setdefault("session_len", 8192)
+        if self.model_spec.model_format == "awq":
+            model_config.setdefault("model_format", "awq")
+        return model_config
+
+    def load(self):
+        try:
+            import lmdeploy  # noqa: F401, F811
+        except ImportError:
+            error_message = "Failed to import module 'lmdeploy'"
+            installation_guide = [
+                "Please make sure 'lmdeploy' is installed. ",
+                "You can install it by `pip install lmdeploy`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+        raise ValueError("LMDEPLOY engine has not supported generate yet.")
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        return False
+
+    def generate(
+        self,
+        prompt: str,
+        generate_config: Optional[Dict] = None,
+    ) -> Union[Completion, Iterator[ChatCompletionChunk]]:
+        raise NotImplementedError("LMDeploy generate ablility does not support now.")
+
+
+class LMDeployChatModel(LMDeployModel, ChatModelMixin):
+    def load(self):
+        try:
+            from lmdeploy import (
+                ChatTemplateConfig,
+                TurbomindEngineConfig,
+                VisionConfig,
+                pipeline,
+            )
+        except ImportError:
+            error_message = "Failed to import module 'lmdeploy'"
+            installation_guide = [
+                "Please make sure 'lmdeploy' is installed. ",
+                "You can install it by `pip install lmdeploy`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        chat_temp_name = ""
+        family = self.model_family.model_family or self.model_family.model_name
+        for key in LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME.keys():
+            if family in key:
+                chat_temp_name = LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME[key]
+                break
+        if chat_temp_name == "":
+            raise ValueError(f"Can not find correct chat template.")
+
+        chat_template_config = ChatTemplateConfig(chat_temp_name)
+        chat_template_config.meta_instruction = (
+            self.model_family.prompt_style.system_prompt
+        )
+        count = torch.cuda.device_count()
+        if count > 1:
+            self._model_config.setdefault("tp", torch.cuda.device_count())
+
+        self._model = pipeline(
+            self.model_path,
+            chat_template_config=chat_template_config,
+            backend_config=TurbomindEngineConfig(**self._model_config),
+            vision_config=VisionConfig(thread_safe=True),
+        )
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format == "awq":
+            # Currently, only 4-bit weight quantization is supported for AWQ, but got 8 bits.
+            if "4" not in quantization:
+                return False
+        if llm_family.model_name not in LMDEPLOY_SUPPORTED_CHAT_MODELS:
+            return False
+        return LMDEPLOY_INSTALLED
+
+    async def async_chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[Dict] = None,
+    ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
+        stream = (
+            generate_config.get("stream", False)
+            if isinstance(generate_config, dict)
+            else False
+        )
+        stream_options = (
+            generate_config.get("stream_options", None)
+            if isinstance(generate_config, dict)
+            else False
+        )
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
+
+        chat_history = chat_history or []
+
+        if stream:
+            chunk = self._chat_stream(prompt, chat_history, include_usage)
+            return self._async_to_chat_completion_chunks(chunk)
+        else:
+            chunk = await self._chat(prompt, chat_history)
+            return self._to_chat_completion(chunk)
+
+    async def _chat_stream(self, prompt, chat_history, include_usage):
+        from lmdeploy.messages import Response
+
+        prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+        completion_id = str(uuid.uuid1())
+        async for output in self._generate(
+            prompt,
+            chat_history,
+            session_id=-1,
+            stream_response=True,
+        ):
+            new_text = output.text if isinstance(output, Response) else output.response
+
+            completion_choice = ChatCompletionChunkChoice(
+                text=new_text,
+                index=0,
+                logprobs=None,
+                finish_reason=output.finish_reason,
+            )
+            chunk = ChatCompletionChunk(
+                id=completion_id,
+                object="chat.completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[completion_choice],
+            )
+            prompt_tokens = output.input_token_len
+            completion_tokens = output.generate_token_len
+            total_tokens = prompt_tokens + completion_tokens
+            completion_usage = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            chunk["usage"] = completion_usage
+            print(chunk)
+            yield chunk
+        if include_usage:
+            chunk = ChatCompletionChunk(
+                id=completion_id,
+                object="chat.completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[],
+            )
+            chunk["usage"] = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            yield chunk
+
+    async def _chat(self, prompt, chat_history):
+        from lmdeploy.messages import Response
+
+        response, finish_reason = "", ""
+        prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+        async for output in self._generate(
+            prompt,
+            chat_history,
+            session_id=-1,
+            stream_response=False,
+        ):
+            response += output.text if isinstance(output, Response) else output.response
+            prompt_tokens = output.input_token_len
+            completion_tokens = output.generate_token_len
+            total_tokens = output.input_token_len + output.generate_token_len
+            finish_reason = output.finish_reason
+
+        chunk = ChatCompletion(
+            id=str(uuid.uuid1()),
+            object="chat.completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[
+                CompletionChoice(
+                    index=0, text=response, finish_reason=finish_reason, logprobs=None
+                )
+            ],
+            usage=CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            ),
+        )
+        return chunk
+
+    # copy from lmdeploy
+    # Reference: lmdeploy.serve.async_engine.py
+    async def _generate(
+        self,
+        prompt,
+        chat_history,
+        session_id: int,
+        generate_config: Optional[Dict] = None,
+        tools: Optional[List[object]] = None,
+        stream_response: bool = True,
+        sequence_start: bool = True,
+        sequence_end: bool = True,  # no interactive mode by default
+        step: int = 0,
+        do_preprocess: bool = False,
+        adapter_name: Optional[str] = None,
+        **kwargs,
+    ):
+        import random
+
+        from lmdeploy.messages import EngineGenerationConfig, GenerationConfig
+        from lmdeploy.serve.async_engine import GenOut
+        from lmdeploy.tokenizer import DetokenizeState
+
+        session_id = -1
+
+        if str(session_id) not in self._model.id2step:
+            self._model.id2step[str(session_id)] = 0
+        if generate_config is None:
+            generate_config = GenerationConfig()
+        if type(generate_config) is GenerationConfig:
+            generate_config = EngineGenerationConfig.From(
+                generate_config, self._model.tokenizer
+            )
+        if generate_config.stop_words is None:  # type: ignore
+            generate_config.stop_words = self._model.stop_words  # type: ignore
+        if generate_config.random_seed is None and sequence_start:  # type: ignore
+            generate_config.random_seed = random.getrandbits(64)  # type: ignore
+        if generate_config.n > 1:  # type: ignore
+            logger.warning(
+                f"n({generate_config.n}) > 1 hasn't been supported yet. "  # type: ignore
+                f"Fallback to 1"
+            )
+            generate_config.n = 1  # type: ignore
+
+        prompt_input = await self._get_prompt_input(prompt, chat_history)
+        prompt = prompt_input["prompt"]
+        input_ids = prompt_input["input_ids"]
+        finish_reason = None
+        logger.info(
+            f"prompt={prompt!r}, "
+            f"gen_config={generate_config}, "
+            f"prompt_token_id={input_ids}, "
+            f"adapter_name={adapter_name}."
+        )
+        logger.info(
+            f"session_id={session_id}, "  # type: ignore
+            f"history_tokens={self._model.id2step[str(session_id)]}, "
+            f"input_tokens={len(input_ids)}, "
+            f"max_new_tokens={generate_config.max_new_tokens}, "
+            f"seq_start={sequence_start}, seq_end={sequence_end}, "
+            f"step={step}, prep={do_preprocess}"
+        )
+
+        if generate_config.max_new_tokens is None:  # type: ignore
+            # for interactive endpoint, will try maximum possible token num
+            generate_config.max_new_tokens = max(  # type: ignore
+                128,
+                self._model.session_len
+                - self._model.id2step[str(session_id)]
+                - len(input_ids),
+            )
+        elif (
+            self._model.id2step[str(session_id)]
+            + len(input_ids)
+            + generate_config.max_new_tokens  # type: ignore
+            > self._model.session_len
+        ):
+            generate_config.max_new_tokens = max(  # type: ignore
+                self._model.session_len
+                - self._model.id2step[str(session_id)]
+                - len(input_ids),
+                128,
+            )
+            logger.error(f"Truncate max_new_tokens to {generate_config.max_new_tokens}")  # type: ignore
+
+        if (
+            self._model.id2step[str(session_id)]
+            + len(input_ids)
+            + generate_config.max_new_tokens  # type: ignore
+            > self._model.session_len
+        ):
+            logger.error(f"run out of tokens. session_id={session_id}.")
+            yield GenOut(
+                "", self._model.id2step[str(session_id)], len(input_ids), 0, "length"
+            )
+            if sequence_end is True and sequence_start is False:
+                await self._model.end_session(session_id)
+        else:
+            generator = await self._model.get_generator(False, session_id)
+            async with self._model.safe_run(session_id):
+                state = DetokenizeState(len(input_ids))
+                start_ids_offset = state.ids_offset
+                response = ""
+                async for outputs in generator.async_stream_infer(
+                    session_id=session_id,
+                    **prompt_input,
+                    gen_config=generate_config,
+                    adapter_name=adapter_name,
+                    stream_output=stream_response,
+                    sequence_start=sequence_start,
+                    sequence_end=sequence_end,
+                    step=self._model.id2step[str(session_id)],
+                ):
+                    # decode res
+                    res, tokens = (
+                        input_ids + outputs.token_ids,
+                        outputs.num_token,
+                    )  # noqa
+                    if len(res) <= state.ids_offset:
+                        continue
+
+                    ids_offset = state.ids_offset
+                    response, state = self._model.tokenizer.detokenize_incrementally(
+                        res,
+                        state,
+                        skip_special_tokens=generate_config.skip_special_tokens,  # type: ignore
+                    )
+
+                    res = res[ids_offset:]
+                    logprobs = None
+                    if outputs.logprobs:
+                        log_offset = ids_offset - start_ids_offset
+                        logprobs = outputs.logprobs[log_offset:]
+
+                    # response, history token len,
+                    # input token len, gen token len
+                    yield GenOut(
+                        response,
+                        self._model.id2step[str(session_id)],
+                        len(input_ids),
+                        tokens,
+                        finish_reason,
+                        res,
+                        logprobs,
+                    )
+
+                finish_reason = (
+                    "length" if tokens >= generate_config.max_new_tokens else "stop"  # type: ignore
+                )
+                # utf-8 char at the end means it's a potential unfinished
+                # byte sequence
+                if not response.endswith("�"):
+                    response = ""  # avaid returning the last response twice
+                yield GenOut(
+                    response,
+                    self._model.id2step[str(session_id)],
+                    len(input_ids),
+                    tokens,
+                    finish_reason,
+                )
+            # update step
+            self._model.id2step[str(session_id)] += len(input_ids) + tokens
+            if sequence_end:
+                self._model.id2step[str(session_id)] = 0
+            # manually end pytorch session
+            # TODO modify pytorch or turbomind api
+            if self._model.backend == "pytorch" and sequence_end:
+                await self._model.end_session(session_id)
+
+    # copy from lmdeploy
+    # Reference: lmdeploy.serve.vl_async_engine.py
+    async def _get_prompt_input(
+        self,
+        prompt: Union[str, List[Dict]],
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        sequence_start: bool = True,
+        tools: Optional[List[object]] = None,
+        **kwargs,
+    ):
+        """get input_ids, embeddings and offsets."""
+        IMAGE_TOKEN = "<IMAGE_TOKEN>"
+        IMAGE_DUMMY_TOKEN_INDEX = 0
+        import numpy as np
+
+        assert self.model_family.prompt_style is not None
+        prompt_style = self.model_family.prompt_style.copy()
+        chat_history = chat_history or []
+
+        decorated, _ = self.get_prompt(prompt, chat_history, prompt_style)  # type: ignore
+        chat_history.append(ChatCompletionMessage(role="user", content=prompt))  # type: ignore
+        prompt = chat_history  # type: ignore
+
+        decorated = decorated.replace("<image>", "<img><IMAGE_TOKEN></img>")
+
+        segs = decorated.split(IMAGE_TOKEN)
+
+        results = {}
+        input_ids = []  # type: ignore
+        if len(segs) > 1:
+            images = await self._model.vl_prompt_template.async_collect_pil_images(
+                prompt
+            )
+
+            features = await self._model.vl_encoder.async_infer(images)
+
+            from lmdeploy.vl.templates import MiniCPMVTempateWrapper
+
+            if isinstance(self._model.vl_prompt_template, MiniCPMVTempateWrapper):
+                (
+                    decorated,
+                    features,
+                ) = self._model.vl_prompt_template.update_image_token(  # noqa: E501
+                    decorated, features
+                )
+                segs = decorated.split(IMAGE_TOKEN)
+
+            features = [x.cpu().numpy() for x in features]
+            input_ids = []
+            begins = []
+            ends = []
+            if len(segs) != len(features) + 1:
+                logger.error(
+                    f"the number of {IMAGE_TOKEN} is not equal "
+                    f"to input images, {len(segs) - 1} vs {len(features)}"
+                )
+                features = features[: len(segs) - 1]
+            for i, seg in enumerate(segs):
+                if i > 0 and i <= len(features):
+                    image_dim = features[i - 1].shape[0]
+                    begins.append(len(input_ids))
+                    ends.append(begins[-1] + image_dim)
+                    input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim)
+                seg_ids = self._model.tokenizer.encode(
+                    seg, add_bos=((i == 0) and sequence_start)
+                )
+                input_ids.extend(seg_ids)
+            ranges = np.stack([begins, ends], axis=1).tolist()
+            results["input_embeddings"] = features
+            results["input_embedding_ranges"] = ranges
+        else:
+            input_ids = self._model.tokenizer.encode(decorated, add_bos=sequence_start)
+
+        results["input_ids"] = input_ids
+        results["prompt"] = decorated
+
+        return results
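The class above is registered as a new "lmdeploy" chat engine, limited by match() to the internvl2 family in 4-bit AWQ format. A minimal sketch of how a client might select this engine; the model_engine/model_format/quantization/size arguments follow xinference's usual engine-selection API and are assumptions, not confirmed by this diff:

# Hypothetical usage sketch for the new LMDeploy engine (argument values are assumptions).
from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model_uid = client.launch_model(
    model_name="internvl2",            # the only entry in LMDEPLOY_SUPPORTED_CHAT_MODELS above
    model_engine="lmdeploy",           # route to LMDeployChatModel instead of transformers/vllm
    model_format="awq",                # match() only accepts 4-bit AWQ checkpoints
    quantization="Int4",               # assumed quantization label for the AWQ spec
    model_size_in_billions=8,          # assumed size; any AWQ internvl2 spec should qualify
)
model = client.get_model(model_uid)
completion = model.chat("Describe the weather in this picture.")
print(completion["choices"][0]["message"]["content"])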
xinference/model/llm/transformers/cogvlm2.py
@@ -11,17 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from typing import Dict, Iterator, List, Optional, Tuple, Union
 
-import requests
 import torch
-from PIL import Image
 
 from ....core.scheduler import InferenceRequest
 from ....model.utils import select_device
@@ -35,6 +31,7 @@ from ....types import (
     CompletionUsage,
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig
 from .utils import get_max_src_len
 
@@ -75,7 +72,7 @@ class CogVLM2Model(PytorchChatModel):
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
-        if "cogvlm2" in family.lower():
+        if "cogvlm2" in family.lower() and "video" not in family.lower():
             return True
         return False
 
@@ -116,24 +113,6 @@ class CogVLM2Model(PytorchChatModel):
         self._save_tensorizer()
 
     def _message_content_to_cogvlm2(self, content):
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
         if not isinstance(content, str):
             texts = []
             image_urls = []
@@ -146,7 +125,7 @@ class CogVLM2Model(PytorchChatModel):
             image_futures = []
             with ThreadPoolExecutor() as executor:
                 for image_url in image_urls:
-                    fut = executor.submit(_load_image, image_url)
+                    fut = executor.submit(_decode_image, image_url)
                     image_futures.append(fut)
             images = [fut.result() for fut in image_futures]
             text = " ".join(texts)
@@ -163,24 +142,6 @@ class CogVLM2Model(PytorchChatModel):
     def _history_content_to_cogvlm2(
         self, system_prompt: str, chat_history: List[ChatCompletionMessage]
     ):
-        def _image_to_piexl_values(image):
-            if image.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = image.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(image)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(image).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
         query = system_prompt
         history: List[Tuple] = []
         pixel_values = None
@@ -192,9 +153,7 @@ class CogVLM2Model(PytorchChatModel):
                 if c_type == "text":
                     user = content["text"]
                 elif c_type == "image_url" and not pixel_values:
-                    pixel_values = _image_to_piexl_values(
-                        content["image_url"]["url"]
-                    )
+                    pixel_values = _decode_image(content["image_url"]["url"])
             assistant = chat_history[i + 1]["content"]
             history.append((user, assistant))
             query = assistant  # type: ignore
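The two inline image loaders removed above are replaced by a shared _decode_image helper imported from ..utils (the +10 -1 change to xinference/model/llm/utils.py in the file list). A sketch of what that helper presumably looks like, reconstructed from the inline code deleted here rather than from the actual utils.py hunk:

# Sketch of the shared helper, reconstructed from the removed inline loaders (not verbatim from utils.py).
import base64
import logging
from io import BytesIO

import requests
from PIL import Image


def _decode_image(_url):
    if _url.startswith("data:"):
        # OpenAI-style base64 data URL, e.g. f"data:image/jpeg;base64,{base64_image}"
        logging.info("Parse url by base64 decoder.")
        _type, data = _url.split(";")
        _, ext = _type.split("/")
        data = data[len("base64,") :]
        data = base64.b64decode(data.encode("utf-8"))
        return Image.open(BytesIO(data)).convert("RGB")
    else:
        try:
            response = requests.get(_url)
        except requests.exceptions.MissingSchema:
            # not an http(s) URL: treat it as a local file path
            return Image.open(_url).convert("RGB")
        else:
            return Image.open(BytesIO(response.content)).convert("RGB")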