xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/deepseek_v2.py (new file)
@@ -0,0 +1,340 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import uuid
+from typing import Dict, Iterator, List, Optional, Union
+
+import torch
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    Completion,
+    CompletionChunk,
+    PytorchGenerateConfig,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import (
+    generate_chat_completion,
+    generate_completion,
+    generate_completion_chunk,
+)
+from .core import PytorchChatModel, PytorchModel
+
+logger = logging.getLogger(__name__)
+
+
+class DeepSeekV2PytorchModel(PytorchModel):
+    def _load_model(self, **kwargs):
+        try:
+            from transformers import (
+                AutoModelForCausalLM,
+                AutoTokenizer,
+                GenerationConfig,
+            )
+        except ImportError:
+            error_message = "Failed to import module 'transformers'"
+            installation_guide = [
+                "Please make sure 'transformers' is installed. ",
+                "You can install it by `pip install transformers`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=kwargs["trust_remote_code"],
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            attn_implementation="eager",
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+            device_map="auto",
+        )
+        model.generation_config = GenerationConfig.from_pretrained(self.model_path)
+        model.generation_config.pad_token_id = model.generation_config.eos_token_id
+        return model, tokenizer
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format != "pytorch":
+            return False
+        model_family = llm_family.model_family or llm_family.model_name
+        if "deepseek-v2" not in model_family:
+            return False
+        if "generate" not in llm_family.model_ability:
+            return False
+        return True
+
+    def generate(
+        self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
+    ) -> Union[Completion, Iterator[CompletionChunk]]:
+        input_tensor = self._tokenizer(prompt, return_tensors="pt")
+        generate_config = self._sanitize_generate_config(generate_config)
+        default_generate_config = self._model.generation_config
+        generate_kwargs = {
+            "input_ids": input_tensor["input_ids"].cuda(),
+            "attention_mask": input_tensor["attention_mask"].cuda(),
+            "temperature": float(
+                generate_config.get("temperature", default_generate_config.temperature)
+            ),
+            "repetition_penalty": float(generate_config.get("repetition_penalty", 1.0)),
+            "top_p": float(generate_config.get("top_p", default_generate_config.top_p)),
+            "top_k": int(generate_config.get("top_k", -1)),
+            "max_new_tokens": generate_config.get("max_tokens", 512),
+            "bos_token_id": default_generate_config.bos_token_id,
+            "do_sample": default_generate_config.do_sample,
+            "eos_token_id": default_generate_config.eos_token_id,
+        }
+
+        stream = generate_config.get("stream", False)
+        if stream:
+            return self._generate_stream(generate_kwargs, input_tensor)
+        else:
+            return self._generate(generate_kwargs, input_tensor)
+
+    def _generate(self, generate_kwargs, input_ids) -> Completion:
+        prompt_tokens = len(input_ids[0])
+        logger.info(f"generate_kwargs:{generate_kwargs}")
+        generation_output = self._model.generate(**generate_kwargs)
+        completion_tokens = len(generation_output[0])
+        response = self._tokenizer.decode(
+            generation_output[0], skip_special_tokens=True
+        )
+        return generate_completion(
+            self.model_uid,
+            response,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+
+    def _generate_stream(self, generate_kwargs, input_ids):
+        from threading import Thread
+
+        from transformers import TextIteratorStreamer
+
+        # Initialize the streamer
+        streamer = TextIteratorStreamer(
+            self._tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
+        )
+        # Define the generation configuration
+        generate_kwargs["streamer"] = streamer
+        # Start the model chat in a separate thread
+        thread = Thread(
+            target=self._model.generate,
+            kwargs=generate_kwargs,
+        )
+        thread.start()
+
+        completion_id = str(uuid.uuid1())
+        prompt_tokens = len(input_ids[0])
+        total_tokens, completion_tokens = 0, 0
+        # Loop through the streamer to get the new text as it is generated
+        for i, new_text in enumerate(streamer):
+            completion_tokens = i
+            total_tokens = prompt_tokens + completion_tokens
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            has_choice=True,
+            has_content=False,
+        )
+
+
+class DeepSeekV2PytorchChatModel(PytorchChatModel):
+    def _load_model(self, **kwargs):
+        try:
+            from transformers import (
+                AutoModelForCausalLM,
+                AutoTokenizer,
+                GenerationConfig,
+            )
+        except ImportError:
+            error_message = "Failed to import module 'transformers'"
+            installation_guide = [
+                "Please make sure 'transformers' is installed. ",
+                "You can install it by `pip install transformers`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=kwargs["trust_remote_code"],
+        )
+        logger.info(f"kwargs:{kwargs}")
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            attn_implementation="eager",
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+            device_map="auto",
+        )
+        model.generation_config = GenerationConfig.from_pretrained(self.model_path)
+        model.generation_config.pad_token_id = model.generation_config.eos_token_id
+        return model, tokenizer
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format != "pytorch":
+            return False
+        model_family = llm_family.model_family or llm_family.model_name
+        if "deepseek-v2" not in model_family:
+            return False
+        if "chat" not in llm_family.model_ability:
+            return False
+        return True
+
+    def chat(
+        self,
+        messages: List[Dict],
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        assert self.model_family.chat_template is not None
+        full_prompt = self.get_full_context(
+            messages,
+            self.model_family.chat_template,
+            tokenizer=self._tokenizer,
+        )
+        input_tensor = self._tokenizer.encode(
+            full_prompt,
+            padding=False,
+            truncation=False,
+            max_length=None,
+            add_special_tokens=False,
+            return_tensors="pt",
+        )
+
+        generate_config = self._sanitize_generate_config(generate_config)
+        default_generate_config = self._model.generation_config
+        generate_kwargs = {
+            "input_ids": input_tensor.cuda(),
+            "temperature": float(
+                generate_config.get("temperature", default_generate_config.temperature)
+            ),
+            "repetition_penalty": float(generate_config.get("repetition_penalty", 1.0)),
+            "top_p": float(generate_config.get("top_p", default_generate_config.top_p)),
+            "top_k": int(generate_config.get("top_k", -1)),
+            "max_new_tokens": generate_config.get("max_tokens", 512),
+            "bos_token_id": default_generate_config.bos_token_id,
+            "do_sample": default_generate_config.do_sample,
+            "eos_token_id": default_generate_config.eos_token_id,
+        }
+
+        stream = generate_config.get("stream", False)
+        stream_options = generate_config.get("stream_options", None)
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
+        if stream:
+            chunk = self._generate_stream(generate_kwargs, input_tensor, include_usage)
+            return self._to_chat_completion_chunks(chunk)
+        else:
+            return self._generate(generate_kwargs, input_tensor)
+
+    def _generate(self, generate_kwargs, input_ids) -> ChatCompletion:
+        prompt_tokens = len(input_ids[0])
+        generation_output = self._model.generate(**generate_kwargs)
+        completion_tokens = len(generation_output[0])
+        response = self._tokenizer.decode(
+            generation_output[0][input_ids.shape[1] :], skip_special_tokens=True
+        )
+        return generate_chat_completion(
+            self.model_uid,
+            response,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+
+    def _generate_stream(self, generate_kwargs, input_ids, include_usage):
+        from threading import Thread
+
+        from transformers import TextIteratorStreamer
+
+        # Initialize the streamer
+        streamer = TextIteratorStreamer(
+            self._tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
+        )
+        # Define the generation configuration
+        generate_kwargs["streamer"] = streamer
+        # Start the model chat in a separate thread
+        thread = Thread(
+            target=self._model.generate,
+            kwargs=generate_kwargs,
+        )
+        thread.start()
+
+        completion_id = str(uuid.uuid1())
+        prompt_tokens = len(input_ids[0])
+        total_tokens, completion_tokens = 0, 0
+        # Loop through the streamer to get the new text as it is generated
+        for i, new_text in enumerate(streamer):
+            completion_tokens = max(completion_tokens, len(streamer.token_cache))
+            total_tokens = prompt_tokens + completion_tokens
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+            has_choice=True,
+            has_content=False,
+        )
+
+        if include_usage:
+            yield generate_completion_chunk(
+                chunk_text=None,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+                has_choice=False,
+                has_content=False,
+            )
xinference/model/llm/transformers/deepseek_vl.py
@@ -15,7 +15,6 @@ import base64
 import logging
 import os.path
 import tempfile
-import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
@@ -25,16 +24,9 @@ import requests
 import torch
 
 from ....model.utils import select_device
-from ....types import (
-    ChatCompletion,
-    ChatCompletionChunk,
-    ChatCompletionMessage,
-    Completion,
-    CompletionChoice,
-    CompletionChunk,
-    CompletionUsage,
-)
+from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import generate_chat_completion, generate_completion_chunk
 from .core import PytorchChatModel, PytorchGenerateConfig
 
 logger = logging.getLogger(__name__)
@@ -147,9 +139,7 @@ class DeepSeekVLChatModel(PytorchChatModel):
 
     def chat(
         self,
-
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         if not generate_config:
@@ -162,44 +152,40 @@ class DeepSeekVLChatModel(PytorchChatModel):
             if isinstance(stream_options, dict)
             else False
         )
-
-
-
-
-
-
-            {"role": "Assistant", "content": ""},
-        ]
-        if images:
-            prompt_messages[0]["images"] = images
-
-        # Convert openai history to qwen vl history
-        deepseek_history = []
-        for h in chat_history or []:
-            role = h["role"]
+
+        prompt = ""
+        deepseek_messages = []
+        for i, message in enumerate(messages):
+            role = message["role"]
+            content = message["content"]
             if role == "user":
-                content,
-
-
-
-
-
-
-
+                if isinstance(content, str):
+                    deepseek_messages.append({"role": "User", "content": content})
+                else:
+                    content, images = self._message_content_to_deepseek(content)
+                    msg: Dict[str, Any] = {
+                        "role": "User",
+                        "content": content,
+                    }
+                    if images:
+                        msg["images"] = images
+                    deepseek_messages.append(msg)
+                    if i == len(messages) - 1:
+                        prompt = content
             elif role == "assistant":
-
+                deepseek_messages.append({"role": "Assistant", "content": content})
             else:
-                logger.error(
-
-
+                logger.error(
+                    f"Unexpected message in messages: role: {role}, message: {message}"
+                )
 
         from ....thirdparty.deepseek_vl.serve.inference import generate
         from ....thirdparty.deepseek_vl.utils.io import load_pil_images
 
         # load images and prepare for inputs
-        pil_images = load_pil_images(
+        pil_images = load_pil_images(deepseek_messages)
         prepare_inputs = self._vl_chat_processor(
-            conversations=
+            conversations=deepseek_messages, images=pil_images, force_batchify=True
         ).to(self._model.device, self._model.dtype)
 
         temperature = generate_config.get("temperature", 0.2)
@@ -226,31 +212,16 @@ class DeepSeekVLChatModel(PytorchChatModel):
             it = self._generate_stream(streamer, stop_str, include_usage, prompt)
             return self._to_chat_completion_chunks(it)
         else:
-
-            return self._to_chat_completion(c)
+            return self._generate(streamer, stop_str)
 
-    def _generate(self, streamer, stop_str) ->
+    def _generate(self, streamer, stop_str) -> ChatCompletion:
         generated_text = ""
         for new_text in streamer:
             if new_text.endswith(stop_str):
                 new_text = new_text[: -len(stop_str)]
             generated_text += new_text
 
-
-            id=str(uuid.uuid1()),
-            object="text_completion",
-            created=int(time.time()),
-            model=self.model_uid,
-            choices=[
-                CompletionChoice(
-                    index=0, text=generated_text, finish_reason="stop", logprobs=None
-                )
-            ],
-            usage=CompletionUsage(
-                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-            ),
-        )
-        return c
+        return generate_chat_completion(self.model_uid, generated_text)
 
     def _generate_stream(
         self, streamer, stop_str, include_usage, prompt
@@ -262,54 +233,40 @@ class DeepSeekVLChatModel(PytorchChatModel):
         for i, new_text in enumerate(streamer):
             if new_text.endswith(stop_str):
                 new_text = new_text[: -len(stop_str)]
-            completion_choice = CompletionChoice(
-                text=new_text, index=0, logprobs=None, finish_reason=None
-            )
-            chunk = CompletionChunk(
-                id=completion_id,
-                object="text_completion",
-                created=int(time.time()),
-                model=self.model_uid,
-                choices=[completion_choice],
-            )
             completion_tokens = i
             total_tokens = prompt_tokens + completion_tokens
-
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=total_tokens,
+                has_choice=True,
+                has_content=True,
             )
-
-
-
-
-
-            )
-        chunk = CompletionChunk(
-            id=completion_id,
-            object="text_completion",
-            created=int(time.time()),
-            model=self.model_uid,
-            choices=[completion_choice],
-        )
-        completion_usage = CompletionUsage(
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
+            has_choice=True,
+            has_content=False,
         )
-
-        yield chunk
+
         if include_usage:
-
-
-
-
-
-                choices=[],
-            )
-            chunk["usage"] = CompletionUsage(
+            yield generate_completion_chunk(
+                chunk_text=None,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
+                has_choice=False,
+                has_content=False,
            )
-            yield chunk