xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Potentially problematic release: this version of xinference might be problematic; more details are available from the registry.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +15 -34
- xinference/client/restful/restful_client.py +2 -2
- xinference/core/chat_interface.py +45 -10
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +8 -5
- xinference/core/scheduler.py +1 -2
- xinference/core/worker.py +49 -42
- xinference/deploy/cmdline.py +2 -2
- xinference/deploy/test/test_cmdline.py +7 -7
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +26 -27
- xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
- xinference/model/llm/llm_family.json +606 -1266
- xinference/model/llm/llm_family.py +16 -139
- xinference/model/llm/llm_family_modelscope.json +276 -313
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/memory.py +9 -9
- xinference/model/llm/sglang/core.py +2 -2
- xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
- xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/{pytorch → transformers}/core.py +3 -10
- xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +540 -0
- xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
- xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
- xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
- xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
- xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
- xinference/model/llm/utils.py +85 -70
- xinference/model/llm/vllm/core.py +110 -11
- xinference/model/utils.py +1 -95
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/internvl/__init__.py +0 -0
- xinference/thirdparty/internvl/conversation.py +393 -0
- xinference/thirdparty/omnilmm/model/utils.py +16 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
- xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
- xinference/locale/utils.py +0 -39
- xinference/locale/zh_CN.json +0 -26
- xinference/model/llm/ggml/tools/__init__.py +0 -15
- xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
- xinference/model/llm/ggml/tools/gguf.py +0 -884
- xinference/model/llm/pytorch/__init__.py +0 -13
- xinference/model/llm/pytorch/baichuan.py +0 -81
- xinference/model/llm/pytorch/falcon.py +0 -138
- xinference/model/llm/pytorch/intern_vl.py +0 -352
- xinference/model/llm/pytorch/vicuna.py +0 -69
- xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
- xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
- /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
- /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
- /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
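Note on the renames in the list above: the internal LLM backend packages move from `xinference/model/llm/pytorch/` to `xinference/model/llm/transformers/`, and the old `ggml/llamacpp.py` module becomes `llama_cpp/core.py`, so any downstream code that reaches into these private modules needs the new import paths. A minimal compatibility sketch; it assumes `PytorchChatModel` still lives in `core.py` after the rename, which the chatglm diff below implies but this page does not state outright:

# Hedged import shim for code that imports xinference's private transformer
# backend directly; adjust the class name to whatever you actually use.
try:
    # layout in the 0.14.3 wheel
    from xinference.model.llm.transformers.core import PytorchChatModel
except ImportError:
    # layout in the 0.14.1.post1 wheel
    from xinference.model.llm.pytorch.core import PytorchChatModel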
xinference/model/llm/lmdeploy/core.py
ADDED
@@ -0,0 +1,557 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import time
+import uuid
+from typing import AsyncGenerator, Dict, Iterator, List, Optional, TypedDict, Union
+
+import torch
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionChunkChoice,
+    ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionUsage,
+    LoRA,
+)
+from ..core import LLM
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import ChatModelMixin
+
+logger = logging.getLogger(__name__)
+
+try:
+    import lmdeploy  # noqa: F401
+
+    LMDEPLOY_INSTALLED = True
+except ImportError:
+    LMDEPLOY_INSTALLED = False
+
+LMDEPLOY_SUPPORTED_CHAT_MODELS = ["internvl2"]
+LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME = {
+    "internvl2": "internvl-internlm2",
+}
+
+
+class LMDeployModelConfig(TypedDict, total=False):
+    model_format: Optional[str]
+    tp: Optional[int]
+    session_len: Optional[int]
+    max_batch_size: Optional[int]
+    cache_max_entry_count: Optional[float]
+    cache_block_seq_len: Optional[int]
+    enable_prefix_caching: Optional[bool]
+    quant_policy: Optional[int]
+    rope_scaling_factor: Optional[float]
+    use_logn_attn: Optional[bool]
+    download_dir: Optional[str]
+    revision: Optional[str]
+    max_prefill_token_num: Optional[int]
+    num_tokens_per_iter: Optional[int]
+    max_prefill_iters: Optional[int]
+
+
+class LMDeployGenerateConfig(TypedDict, total=False):
+    n: Optional[int]
+    max_new_tokens: Optional[int]
+    top_p: Optional[float]
+    top_k: Optional[int]
+    temperature: Optional[float]
+    repetition_penalty: Optional[float]
+    ignore_eos: Optional[bool]
+    random_seed: Optional[int]
+    stop_words: Optional[List[str]]
+    bad_words: Optional[List[str]]
+    min_new_tokens: Optional[int]
+    skip_special_tokens: Optional[bool]
+    logprobs: Optional[int]
+
+
+class LMDeployModel(LLM):
+    def __init__(
+        self,
+        model_uid: str,
+        model_family: "LLMFamilyV1",
+        model_spec: "LLMSpecV1",
+        quantization: str,
+        model_path: str,
+        model_config: Optional[LMDeployModelConfig] = None,
+        peft_model: Optional[List[LoRA]] = None,
+    ):
+        super().__init__(model_uid, model_family, model_spec, quantization, model_path)
+        self._model_config: LMDeployModelConfig = self._sanitize_model_config(
+            model_config
+        )
+        if peft_model is not None:
+            raise ValueError("LMDEPLOY engine has not supported lora yet.")
+
+    def _sanitize_model_config(
+        self, model_config: Optional[LMDeployModelConfig]
+    ) -> LMDeployModelConfig:
+        if model_config is None:
+            model_config = LMDeployModelConfig()
+        model_config.setdefault("session_len", 8192)
+        if self.model_spec.model_format == "awq":
+            model_config.setdefault("model_format", "awq")
+        return model_config
+
+    def load(self):
+        try:
+            import lmdeploy  # noqa: F401, F811
+        except ImportError:
+            error_message = "Failed to import module 'lmdeploy'"
+            installation_guide = [
+                "Please make sure 'lmdeploy' is installed. ",
+                "You can install it by `pip install lmdeploy`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+        raise ValueError("LMDEPLOY engine has not supported generate yet.")
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        return False
+
+    def generate(
+        self,
+        prompt: str,
+        generate_config: Optional[Dict] = None,
+    ) -> Union[Completion, Iterator[ChatCompletionChunk]]:
+        raise NotImplementedError("LMDeploy generate ablility does not support now.")
+
+
+class LMDeployChatModel(LMDeployModel, ChatModelMixin):
+    def load(self):
+        try:
+            from lmdeploy import (
+                ChatTemplateConfig,
+                TurbomindEngineConfig,
+                VisionConfig,
+                pipeline,
+            )
+        except ImportError:
+            error_message = "Failed to import module 'lmdeploy'"
+            installation_guide = [
+                "Please make sure 'lmdeploy' is installed. ",
+                "You can install it by `pip install lmdeploy`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        chat_temp_name = ""
+        family = self.model_family.model_family or self.model_family.model_name
+        for key in LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME.keys():
+            if family in key:
+                chat_temp_name = LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME[key]
+                break
+        if chat_temp_name == "":
+            raise ValueError(f"Can not find correct chat template.")
+
+        chat_template_config = ChatTemplateConfig(chat_temp_name)
+        chat_template_config.meta_instruction = (
+            self.model_family.prompt_style.system_prompt
+        )
+        count = torch.cuda.device_count()
+        if count > 1:
+            self._model_config.setdefault("tp", torch.cuda.device_count())
+
+        self._model = pipeline(
+            self.model_path,
+            chat_template_config=chat_template_config,
+            backend_config=TurbomindEngineConfig(**self._model_config),
+            vision_config=VisionConfig(thread_safe=True),
+        )
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format == "awq":
+            # Currently, only 4-bit weight quantization is supported for AWQ, but got 8 bits.
+            if "4" not in quantization:
+                return False
+        if llm_family.model_name not in LMDEPLOY_SUPPORTED_CHAT_MODELS:
+            return False
+        return LMDEPLOY_INSTALLED
+
+    async def async_chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[Dict] = None,
+    ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
+        stream = (
+            generate_config.get("stream", False)
+            if isinstance(generate_config, dict)
+            else False
+        )
+        stream_options = (
+            generate_config.get("stream_options", None)
+            if isinstance(generate_config, dict)
+            else False
+        )
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
+
+        chat_history = chat_history or []
+
+        if stream:
+            chunk = self._chat_stream(prompt, chat_history, include_usage)
+            return self._async_to_chat_completion_chunks(chunk)
+        else:
+            chunk = await self._chat(prompt, chat_history)
+            return self._to_chat_completion(chunk)
+
+    async def _chat_stream(self, prompt, chat_history, include_usage):
+        from lmdeploy.messages import Response
+
+        prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+        completion_id = str(uuid.uuid1())
+        async for output in self._generate(
+            prompt,
+            chat_history,
+            session_id=-1,
+            stream_response=True,
+        ):
+            new_text = output.text if isinstance(output, Response) else output.response
+
+            completion_choice = ChatCompletionChunkChoice(
+                text=new_text,
+                index=0,
+                logprobs=None,
+                finish_reason=output.finish_reason,
+            )
+            chunk = ChatCompletionChunk(
+                id=completion_id,
+                object="chat.completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[completion_choice],
+            )
+            prompt_tokens = output.input_token_len
+            completion_tokens = output.generate_token_len
+            total_tokens = prompt_tokens + completion_tokens
+            completion_usage = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            chunk["usage"] = completion_usage
+            print(chunk)
+            yield chunk
+        if include_usage:
+            chunk = ChatCompletionChunk(
+                id=completion_id,
+                object="chat.completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[],
+            )
+            chunk["usage"] = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            yield chunk
+
+    async def _chat(self, prompt, chat_history):
+        from lmdeploy.messages import Response
+
+        response, finish_reason = "", ""
+        prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+        async for output in self._generate(
+            prompt,
+            chat_history,
+            session_id=-1,
+            stream_response=False,
+        ):
+            response += output.text if isinstance(output, Response) else output.response
+            prompt_tokens = output.input_token_len
+            completion_tokens = output.generate_token_len
+            total_tokens = output.input_token_len + output.generate_token_len
+            finish_reason = output.finish_reason
+
+        chunk = ChatCompletion(
+            id=str(uuid.uuid1()),
+            object="chat.completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[
+                CompletionChoice(
+                    index=0, text=response, finish_reason=finish_reason, logprobs=None
+                )
+            ],
+            usage=CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            ),
+        )
+        return chunk
+
+    # copy from lmdeploy
+    # Reference: lmdeploy.serve.async_engine.py
+    async def _generate(
+        self,
+        prompt,
+        chat_history,
+        session_id: int,
+        generate_config: Optional[Dict] = None,
+        tools: Optional[List[object]] = None,
+        stream_response: bool = True,
+        sequence_start: bool = True,
+        sequence_end: bool = True,  # no interactive mode by default
+        step: int = 0,
+        do_preprocess: bool = False,
+        adapter_name: Optional[str] = None,
+        **kwargs,
+    ):
+        import random
+
+        from lmdeploy.messages import EngineGenerationConfig, GenerationConfig
+        from lmdeploy.serve.async_engine import GenOut
+        from lmdeploy.tokenizer import DetokenizeState
+
+        session_id = -1
+
+        if str(session_id) not in self._model.id2step:
+            self._model.id2step[str(session_id)] = 0
+        if generate_config is None:
+            generate_config = GenerationConfig()
+        if type(generate_config) is GenerationConfig:
+            generate_config = EngineGenerationConfig.From(
+                generate_config, self._model.tokenizer
+            )
+        if generate_config.stop_words is None:  # type: ignore
+            generate_config.stop_words = self._model.stop_words  # type: ignore
+        if generate_config.random_seed is None and sequence_start:  # type: ignore
+            generate_config.random_seed = random.getrandbits(64)  # type: ignore
+        if generate_config.n > 1:  # type: ignore
+            logger.warning(
+                f"n({generate_config.n}) > 1 hasn't been supported yet. "  # type: ignore
+                f"Fallback to 1"
+            )
+            generate_config.n = 1  # type: ignore
+
+        prompt_input = await self._get_prompt_input(prompt, chat_history)
+        prompt = prompt_input["prompt"]
+        input_ids = prompt_input["input_ids"]
+        finish_reason = None
+        logger.info(
+            f"prompt={prompt!r}, "
+            f"gen_config={generate_config}, "
+            f"prompt_token_id={input_ids}, "
+            f"adapter_name={adapter_name}."
+        )
+        logger.info(
+            f"session_id={session_id}, "  # type: ignore
+            f"history_tokens={self._model.id2step[str(session_id)]}, "
+            f"input_tokens={len(input_ids)}, "
+            f"max_new_tokens={generate_config.max_new_tokens}, "
+            f"seq_start={sequence_start}, seq_end={sequence_end}, "
+            f"step={step}, prep={do_preprocess}"
+        )
+
+        if generate_config.max_new_tokens is None:  # type: ignore
+            # for interactive endpoint, will try maximum possible token num
+            generate_config.max_new_tokens = max(  # type: ignore
+                128,
+                self._model.session_len
+                - self._model.id2step[str(session_id)]
+                - len(input_ids),
+            )
+        elif (
+            self._model.id2step[str(session_id)]
+            + len(input_ids)
+            + generate_config.max_new_tokens  # type: ignore
+            > self._model.session_len
+        ):
+            generate_config.max_new_tokens = max(  # type: ignore
+                self._model.session_len
+                - self._model.id2step[str(session_id)]
+                - len(input_ids),
+                128,
+            )
+            logger.error(f"Truncate max_new_tokens to {generate_config.max_new_tokens}")  # type: ignore
+
+        if (
+            self._model.id2step[str(session_id)]
+            + len(input_ids)
+            + generate_config.max_new_tokens  # type: ignore
+            > self._model.session_len
+        ):
+            logger.error(f"run out of tokens. session_id={session_id}.")
+            yield GenOut(
+                "", self._model.id2step[str(session_id)], len(input_ids), 0, "length"
+            )
+            if sequence_end is True and sequence_start is False:
+                await self._model.end_session(session_id)
+        else:
+            generator = await self._model.get_generator(False, session_id)
+            async with self._model.safe_run(session_id):
+                state = DetokenizeState(len(input_ids))
+                start_ids_offset = state.ids_offset
+                response = ""
+                async for outputs in generator.async_stream_infer(
+                    session_id=session_id,
+                    **prompt_input,
+                    gen_config=generate_config,
+                    adapter_name=adapter_name,
+                    stream_output=stream_response,
+                    sequence_start=sequence_start,
+                    sequence_end=sequence_end,
+                    step=self._model.id2step[str(session_id)],
+                ):
+                    # decode res
+                    res, tokens = (
+                        input_ids + outputs.token_ids,
+                        outputs.num_token,
+                    )  # noqa
+                    if len(res) <= state.ids_offset:
+                        continue
+
+                    ids_offset = state.ids_offset
+                    response, state = self._model.tokenizer.detokenize_incrementally(
+                        res,
+                        state,
+                        skip_special_tokens=generate_config.skip_special_tokens,  # type: ignore
+                    )
+
+                    res = res[ids_offset:]
+                    logprobs = None
+                    if outputs.logprobs:
+                        log_offset = ids_offset - start_ids_offset
+                        logprobs = outputs.logprobs[log_offset:]
+
+                    # response, history token len,
+                    # input token len, gen token len
+                    yield GenOut(
+                        response,
+                        self._model.id2step[str(session_id)],
+                        len(input_ids),
+                        tokens,
+                        finish_reason,
+                        res,
+                        logprobs,
+                    )
+
+                finish_reason = (
+                    "length" if tokens >= generate_config.max_new_tokens else "stop"  # type: ignore
+                )
+                # utf-8 char at the end means it's a potential unfinished
+                # byte sequence
+                if not response.endswith("�"):
+                    response = ""  # avaid returning the last response twice
+                yield GenOut(
+                    response,
+                    self._model.id2step[str(session_id)],
+                    len(input_ids),
+                    tokens,
+                    finish_reason,
+                )
+            # update step
+            self._model.id2step[str(session_id)] += len(input_ids) + tokens
+            if sequence_end:
+                self._model.id2step[str(session_id)] = 0
+            # manually end pytorch session
+            # TODO modify pytorch or turbomind api
+            if self._model.backend == "pytorch" and sequence_end:
+                await self._model.end_session(session_id)
+
+    # copy from lmdeploy
+    # Reference: lmdeploy.serve.vl_async_engine.py
+    async def _get_prompt_input(
+        self,
+        prompt: Union[str, List[Dict]],
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        sequence_start: bool = True,
+        tools: Optional[List[object]] = None,
+        **kwargs,
+    ):
+        """get input_ids, embeddings and offsets."""
+        IMAGE_TOKEN = "<IMAGE_TOKEN>"
+        IMAGE_DUMMY_TOKEN_INDEX = 0
+        import numpy as np
+
+        assert self.model_family.prompt_style is not None
+        prompt_style = self.model_family.prompt_style.copy()
+        chat_history = chat_history or []
+
+        decorated, _ = self.get_prompt(prompt, chat_history, prompt_style)  # type: ignore
+        chat_history.append(ChatCompletionMessage(role="user", content=prompt))  # type: ignore
+        prompt = chat_history  # type: ignore
+
+        decorated = decorated.replace("<image>", "<img><IMAGE_TOKEN></img>")
+
+        segs = decorated.split(IMAGE_TOKEN)
+
+        results = {}
+        input_ids = []  # type: ignore
+        if len(segs) > 1:
+            images = await self._model.vl_prompt_template.async_collect_pil_images(
+                prompt
+            )
+
+            features = await self._model.vl_encoder.async_infer(images)
+
+            from lmdeploy.vl.templates import MiniCPMVTempateWrapper
+
+            if isinstance(self._model.vl_prompt_template, MiniCPMVTempateWrapper):
+                (
+                    decorated,
+                    features,
+                ) = self._model.vl_prompt_template.update_image_token(  # noqa: E501
+                    decorated, features
+                )
+                segs = decorated.split(IMAGE_TOKEN)
+
+            features = [x.cpu().numpy() for x in features]
+            input_ids = []
+            begins = []
+            ends = []
+            if len(segs) != len(features) + 1:
+                logger.error(
+                    f"the number of {IMAGE_TOKEN} is not equal "
+                    f"to input images, {len(segs) - 1} vs {len(features)}"
+                )
+                features = features[: len(segs) - 1]
+            for i, seg in enumerate(segs):
+                if i > 0 and i <= len(features):
+                    image_dim = features[i - 1].shape[0]
+                    begins.append(len(input_ids))
+                    ends.append(begins[-1] + image_dim)
+                    input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim)
+                seg_ids = self._model.tokenizer.encode(
+                    seg, add_bos=((i == 0) and sequence_start)
+                )
+                input_ids.extend(seg_ids)
+            ranges = np.stack([begins, ends], axis=1).tolist()
+            results["input_embeddings"] = features
+            results["input_embedding_ranges"] = ranges
+        else:
+            input_ids = self._model.tokenizer.encode(decorated, add_bos=sequence_start)
+
+        results["input_ids"] = input_ids
+        results["prompt"] = decorated
+
+        return results
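For readers skimming the new lmdeploy backend above: `LMDeployModelConfig` is a plain `TypedDict`, and `_sanitize_model_config` fills only two defaults (`session_len=8192`, plus `model_format="awq"` for AWQ checkpoints) before the whole dict is splatted into `TurbomindEngineConfig`. A standalone sketch of that behaviour, with the TypedDict redeclared locally for illustration rather than imported from xinference:

from typing import Optional, TypedDict


class LMDeployModelConfig(TypedDict, total=False):
    # subset of the fields declared in the new core.py above
    model_format: Optional[str]
    tp: Optional[int]
    session_len: Optional[int]
    cache_max_entry_count: Optional[float]


def sanitize_model_config(
    config: Optional[LMDeployModelConfig], checkpoint_format: str
) -> LMDeployModelConfig:
    # mirrors _sanitize_model_config: default the session length, tag AWQ checkpoints
    config = config if config is not None else LMDeployModelConfig()
    config.setdefault("session_len", 8192)
    if checkpoint_format == "awq":
        config.setdefault("model_format", "awq")
    return config


print(sanitize_model_config(None, "awq"))
# {'session_len': 8192, 'model_format': 'awq'}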
xinference/model/llm/memory.py
CHANGED
@@ -61,7 +61,7 @@ class ModelMemInfo:
 
 QUANT_NORMALIZE = {"int4": "4-bit", "int8": "8-bit", "4-bit": "4-bit", "8-bit": "8-bit"}
 
-
+GGUF_MULTI_FACTOR_DICT = {
     "q4_0": 18,
     "q4_1": 20,
     "q5_0": 22,
@@ -70,14 +70,14 @@ GGML_MULTI_FACTOR_DICT {
     "q8_1": 40,
 }
 
-
+GGUF_MULTI_FACTOR_DICT_64 = {
     "q6_K": 54.0,
     "q3": 26.0,
     "q4": 38.0,
     "q5": 46.0,
 }
 
-
+GGUF_MULTI_FACTOR_DICT_COMBINE = {
     "q3_K_L": [38.0, 26.0],
     "q3_K_M": [46.0, 26.0],
     "q4_K_S": [46.0, 38.0],
@@ -136,9 +136,9 @@ def estimate_llm_gpu_memory_details(
     else:
         kv_dtype_size = 4
     overhead = 650.0
-    if model_format == "
+    if model_format == "ggufv2":
         assert quantization is not None and quantization != "none"
-        model_size_in_mb =
+        model_size_in_mb = _compute_model_size_gguf(info, quantization)
         inference_mem = float(
             context_length * kv_dtype_size * info.hidden_dim * info.num_layers
         )
@@ -291,7 +291,7 @@ def _compute_inference_only_activation_memory(
     return ret
 
 
-def
+def _compute_model_size_gguf(info: ModelLayersInfo, quantization: str) -> float:
     assert quantization is not None
     vocab_size = info.vocab_size
     num_layers = info.num_layers
@@ -310,13 +310,13 @@ def _compute_model_size_ggml(info: ModelLayersInfo, quantization: str) -> float:
     )
 
     total = 0.0
-    v1 =
+    v1 = GGUF_MULTI_FACTOR_DICT.get(quantization)
     if v1 is not None:
         total = (v1 * total_params) / (32 * 1024 * 1024)
-    v2 =
+    v2 = GGUF_MULTI_FACTOR_DICT_64.get(quantization)
     if v2 is not None:
         total = (v2 * total_params) / (64 * 1024 * 1024)
-    v3 =
+    v3 = GGUF_MULTI_FACTOR_DICT_COMBINE.get(quantization)
     if v3 is not None:
         factors = v3
         if quantization == "q2_K":
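The renamed GGUF factor tables above feed `_compute_model_size_gguf`: for a single-factor quantization such as `q4_0`, the estimated weight size is `factor * total_params / (32 * 1024 * 1024)` MiB. A quick sanity check with a hypothetical 7B-parameter model (the parameter count is an illustrative assumption, not something taken from this diff):

# Rough GGUF weight-size estimate using the q4_0 factor (18) from the table above.
GGUF_MULTI_FACTOR_DICT = {"q4_0": 18, "q4_1": 20, "q5_0": 22}

total_params = 7_000_000_000  # hypothetical 7B model
size_in_mb = GGUF_MULTI_FACTOR_DICT["q4_0"] * total_params / (32 * 1024 * 1024)
print(f"{size_in_mb:.0f} MiB (~{size_in_mb / 1024:.1f} GiB)")  # ~3755 MiB, ~3.7 GiB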
xinference/model/llm/sglang/core.py
CHANGED
@@ -189,7 +189,7 @@ class SGLANGModel(LLM):
             return False
         if not cls._is_linux():
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
@@ -378,7 +378,7 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
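The two sglang hunks above do one thing: add "fp8" to the model formats the SGLang engine will accept, so FP8 checkpoints can be matched and served through it. A hedged launch sketch via the RESTful client; the endpoint, model name, and size are placeholders, and the keyword arguments follow the public client API as documented upstream, so verify them against your installed version:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumes a running xinference supervisor
model_uid = client.launch_model(
    model_name="qwen2-instruct",  # placeholder; any family that ships fp8 weights
    model_engine="sglang",
    model_format="fp8",
    model_size_in_billions=7,
)
print(model_uid)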
xinference/model/llm/{pytorch → transformers}/chatglm.py
CHANGED
@@ -344,7 +344,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         return kwargs, tools
 
     @torch.inference_mode()
-    def
+    def _stream_chat(
         self,
         tokenizer,
         query: str,
@@ -399,7 +399,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                 yield new_response, new_history
 
     @torch.inference_mode()
-    def
+    def _non_stream_chat(
         self,
         tokenizer,
         query: str,
@@ -475,10 +475,6 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         if stream and (
             not tools or self.model_family.model_name in GLM4_TOOL_CALL_FAMILY
        ):
-            if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
-                stream_chat = self.stream_chat
-            else:
-                stream_chat = self._model.stream_chat
 
             def _stream_generator():
                 last_chunk_text_length = 0
@@ -487,7 +483,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                 inputs = self._tokenizer([prompt], return_tensors="pt")
                 inputs = inputs.to(self._model.device)
                 prompt_tokens = len(inputs["input_ids"][0])
-                for chunk_text, _ in
+                for chunk_text, _ in self._stream_chat(
                     self._tokenizer, prompt, chat_history, **kwargs
                 ):
                     if tools and isinstance(chunk_text, dict):
@@ -548,12 +544,9 @@ class ChatglmPytorchChatModel(PytorchChatModel):
 
             return self._to_chat_completion_chunks(_stream_generator())
         else:
-
-
-
-            chat = self._model.chat
-
-            response = chat(self._tokenizer, prompt, chat_history, **kwargs)
+            response = self._non_stream_chat(
+                self._tokenizer, prompt, chat_history, **kwargs
+            )
             if tools:
                 return self._tool_calls_completion(
                     self.model_family, self.model_uid, response, tools