xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +15 -34
- xinference/client/restful/restful_client.py +2 -2
- xinference/core/chat_interface.py +45 -10
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +8 -5
- xinference/core/scheduler.py +1 -2
- xinference/core/worker.py +49 -42
- xinference/deploy/cmdline.py +2 -2
- xinference/deploy/test/test_cmdline.py +7 -7
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +26 -27
- xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
- xinference/model/llm/llm_family.json +606 -1266
- xinference/model/llm/llm_family.py +16 -139
- xinference/model/llm/llm_family_modelscope.json +276 -313
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/memory.py +9 -9
- xinference/model/llm/sglang/core.py +2 -2
- xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
- xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/{pytorch → transformers}/core.py +3 -10
- xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +540 -0
- xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
- xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
- xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
- xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
- xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
- xinference/model/llm/utils.py +85 -70
- xinference/model/llm/vllm/core.py +110 -11
- xinference/model/utils.py +1 -95
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/internvl/__init__.py +0 -0
- xinference/thirdparty/internvl/conversation.py +393 -0
- xinference/thirdparty/omnilmm/model/utils.py +16 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
- xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
- xinference/locale/utils.py +0 -39
- xinference/locale/zh_CN.json +0 -26
- xinference/model/llm/ggml/tools/__init__.py +0 -15
- xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
- xinference/model/llm/ggml/tools/gguf.py +0 -884
- xinference/model/llm/pytorch/__init__.py +0 -13
- xinference/model/llm/pytorch/baichuan.py +0 -81
- xinference/model/llm/pytorch/falcon.py +0 -138
- xinference/model/llm/pytorch/intern_vl.py +0 -352
- xinference/model/llm/pytorch/vicuna.py +0 -69
- xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
- xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
- /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
- /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
- /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/__init__.py
@@ -1,13 +0,0 @@
-# Copyright 2022-2023 XProbe Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
xinference/model/llm/pytorch/baichuan.py
@@ -1,81 +0,0 @@
-# Copyright 2022-2023 XProbe Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import List, Optional
-
-from ....types import LoRA
-from ..llm_family import LLMFamilyV1, LLMSpecV1
-from .core import PytorchChatModel, PytorchModelConfig
-
-
-class BaichuanPytorchChatModel(PytorchChatModel):
-    def __init__(
-        self,
-        model_uid: str,
-        model_family: "LLMFamilyV1",
-        model_spec: "LLMSpecV1",
-        quantization: str,
-        model_path: str,
-        pytorch_model_config: Optional[PytorchModelConfig] = None,
-        peft_model: Optional[List[LoRA]] = None,
-    ):
-        super().__init__(
-            model_uid,
-            model_family,
-            model_spec,
-            quantization,
-            model_path,
-            pytorch_model_config=pytorch_model_config,
-            peft_model=peft_model,
-        )
-        self._use_fast_tokenizer = False
-
-    def _load_model(self, **kwargs):
-        try:
-            from transformers import AutoModelForCausalLM, AutoTokenizer
-            from transformers.generation.utils import GenerationConfig
-        except ImportError:
-            error_message = "Failed to import module 'transformers'"
-            installation_guide = [
-                "Please make sure 'transformers' is installed. ",
-                "You can install it by `pip install transformers`\n",
-            ]
-
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            self.model_path,
-            use_fast=self._use_fast_tokenizer,
-            trust_remote_code=kwargs["trust_remote_code"],
-            revision=kwargs["revision"],
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            self.model_path,
-            **kwargs,
-        )
-        model.generation_config = GenerationConfig.from_pretrained(self.model_path)
-        return model, tokenizer
-
-    @classmethod
-    def match(
-        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
-    ) -> bool:
-        if llm_spec.model_format != "pytorch":
-            return False
-        model_family = llm_family.model_family or llm_family.model_name
-        if model_family not in ["baichuan-chat", "baichuan-2-chat"]:
-            return False
-        if "chat" not in llm_family.model_ability:
-            return False
-        return True
xinference/model/llm/pytorch/falcon.py
@@ -1,138 +0,0 @@
-# Copyright 2022-2023 XProbe Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import List, Optional
-
-from ....types import LoRA
-from ..llm_family import LLMFamilyV1, LLMSpecV1
-from .core import PytorchChatModel, PytorchModel, PytorchModelConfig
-
-
-class FalconPytorchModel(PytorchModel):
-    def __init__(
-        self,
-        model_uid: str,
-        model_family: "LLMFamilyV1",
-        model_spec: "LLMSpecV1",
-        quantization: str,
-        model_path: str,
-        pytorch_model_config: Optional[PytorchModelConfig] = None,
-        peft_model: Optional[List[LoRA]] = None,
-    ):
-        super().__init__(
-            model_uid,
-            model_family,
-            model_spec,
-            quantization,
-            model_path,
-            pytorch_model_config=pytorch_model_config,
-            peft_model=peft_model,
-        )
-
-    def _load_model(self, **kwargs):
-        try:
-            from transformers import AutoModelForCausalLM, AutoTokenizer
-        except ImportError:
-            error_message = "Failed to import module 'transformers'"
-            installation_guide = [
-                "Please make sure 'transformers' is installed. ",
-                "You can install it by `pip install transformers`\n",
-            ]
-
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            self.model_path,
-            trust_remote_code=kwargs["trust_remote_code"],
-            revision=kwargs["revision"],
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            self.model_path,
-            low_cpu_mem_usage=True,
-            **kwargs,
-        )
-        tokenizer.pad_token_id = 9
-        return model, tokenizer
-
-    @classmethod
-    def match(
-        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
-    ) -> bool:
-        if llm_spec.model_format != "pytorch":
-            return False
-        model_family = llm_family.model_family or llm_family.model_name
-        if "falcon" not in model_family:
-            return False
-        if "generate" not in llm_family.model_ability:
-            return False
-        return True
-
-
-class FalconPytorchChatModel(PytorchChatModel):
-    def __init__(
-        self,
-        model_uid: str,
-        model_family: "LLMFamilyV1",
-        model_spec: "LLMSpecV1",
-        quantization: str,
-        model_path: str,
-        pytorch_model_config: Optional[PytorchModelConfig] = None,
-        peft_model: Optional[List[LoRA]] = None,
-    ):
-        super().__init__(
-            model_uid,
-            model_family,
-            model_spec,
-            quantization,
-            model_path,
-            pytorch_model_config=pytorch_model_config,
-            peft_model=peft_model,
-        )
-
-    def _load_model(self, **kwargs):
-        try:
-            from transformers import AutoModelForCausalLM, AutoTokenizer
-        except ImportError:
-            error_message = "Failed to import module 'transformers'"
-            installation_guide = [
-                "Please make sure 'transformers' is installed. ",
-                "You can install it by `pip install transformers`\n",
-            ]
-
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            self.model_path,
-            trust_remote_code=kwargs["trust_remote_code"],
-            revision=kwargs["revision"],
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            self.model_path,
-            low_cpu_mem_usage=True,
-            **kwargs,
-        )
-        tokenizer.pad_token_id = 9
-        return model, tokenizer
-
-    @classmethod
-    def match(
-        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
-    ) -> bool:
-        if llm_spec.model_format != "pytorch":
-            return False
-        if "falcon" not in llm_family.model_name:
-            return False
-        if "chat" not in llm_family.model_ability:
-            return False
-        return True
xinference/model/llm/pytorch/intern_vl.py
@@ -1,352 +0,0 @@
-# Copyright 2022-2023 XProbe Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import base64
-import logging
-import time
-import uuid
-from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
-from typing import Dict, Iterator, List, Optional, Tuple, Union
-
-import requests
-import torch
-from PIL import Image
-
-from ....model.utils import select_device
-from ....types import (
-    ChatCompletion,
-    ChatCompletionChunk,
-    ChatCompletionMessage,
-    Completion,
-    CompletionChoice,
-    CompletionUsage,
-)
-from ..llm_family import LLMFamilyV1, LLMSpecV1
-from .core import PytorchChatModel, PytorchGenerateConfig
-
-logger = logging.getLogger(__name__)
-
-IMAGENET_MEAN = (0.485, 0.456, 0.406)
-IMAGENET_STD = (0.229, 0.224, 0.225)
-
-
-class InternVLChatModel(PytorchChatModel):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._tokenizer = None
-        self._model = None
-
-    @classmethod
-    def match(
-        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
-    ) -> bool:
-        family = model_family.model_family or model_family.model_name
-        if "internvl" in family.lower():
-            return True
-        return False
-
-    def _get_model_class(self):
-        from transformers import AutoModel
-
-        return AutoModel
-
-    def load(self, **kwargs):
-        from transformers import AutoModel, AutoTokenizer
-        from transformers.generation import GenerationConfig
-
-        if self._check_tensorizer_integrity():
-            self._model, self._tokenizer = self._load_tensorizer()
-            return
-
-        device = self._pytorch_model_config.get("device", "auto")
-        device = select_device(device)
-        # for multiple GPU, set back to auto to make multiple devices work
-        device = "auto" if device == "cuda" else device
-
-        self._tokenizer = AutoTokenizer.from_pretrained(
-            self.model_path,
-            trust_remote_code=True,
-        )
-
-        kwargs = {
-            "torch_dtype": torch.bfloat16,
-            "low_cpu_mem_usage": True,
-            "trust_remote_code": True,
-            "device_map": device,
-        }
-
-        if "int8" in self.quantization.lower():
-            kwargs["load_in_8bit"] = True
-        elif 2 == self.model_spec.model_size_in_billions:
-            kwargs.pop("device_map")
-
-        self._model = AutoModel.from_pretrained(self.model_path, **kwargs).eval()
-
-        if "int8" not in self.quantization.lower():
-            self._model.cuda()
-
-        # Specify hyperparameters for generation
-        self._model.generation_config = GenerationConfig.from_pretrained(
-            self.model_path,
-            trust_remote_code=True,
-        )
-        self._save_tensorizer()
-
-    def _message_content_to_intern(self, content):
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
-        if not isinstance(content, str):
-            texts = []
-            image_urls = []
-            for c in content:
-                c_type = c.get("type")
-                if c_type == "text":
-                    texts.append(c["text"])
-                elif c_type == "image_url":
-                    image_urls.append(c["image_url"]["url"])
-            image_futures = []
-            with ThreadPoolExecutor() as executor:
-                for image_url in image_urls:
-                    fut = executor.submit(_load_image, image_url)
-                    image_futures.append(fut)
-            images = [fut.result() for fut in image_futures]
-            text = " ".join(texts)
-            if len(images) == 0:
-                return text, None
-            else:
-                return text, images
-        return content, None
-
-    def _history_content_to_intern(
-        self,
-        chat_history: List[ChatCompletionMessage],
-        IMG_START_TOKEN="<img>",
-        IMG_END_TOKEN="</img>",
-        IMG_CONTEXT_TOKEN="<IMG_CONTEXT>",
-    ):
-        def _image_to_piexl_values(images):
-            load_images = []
-            for image in images:
-                if image.startswith("data:"):
-                    logging.info("Parse url by base64 decoder.")
-                    # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                    # e.g. f"data:image/jpeg;base64,{base64_image}"
-                    _type, data = image.split(";")
-                    _, ext = _type.split("/")
-                    data = data[len("base64,") :]
-                    data = base64.b64decode(data.encode("utf-8"))
-                    img = Image.open(BytesIO(data)).convert("RGB")
-                    pixel_value = (
-                        self._load_image(img, max_num=6).to(torch.bfloat16).cuda()
-                    )
-                    load_images.append(pixel_value)
-                else:
-                    try:
-                        response = requests.get(image)
-                    except requests.exceptions.MissingSchema:
-                        img = Image.open(image).convert("RGB")
-                    else:
-                        img = Image.open(BytesIO(response.content)).convert("RGB")
-                    pixel_value = (
-                        self._load_image(img, max_num=6).to(torch.bfloat16).cuda()
-                    )
-                    load_images.append(pixel_value)
-            return torch.cat(tuple(load_images), dim=0)
-
-        history: List[Tuple] = []
-        pixel_values = None
-        for i in range(0, len(chat_history), 2):
-            tmp = []
-            images: List[str] = []
-            user = chat_history[i]["content"]
-            if isinstance(user, List):
-                for content in user:
-                    c_type = content.get("type")
-                    if c_type == "text":
-                        tmp.append(content["text"])
-                    elif c_type == "image_url" and not history:
-                        images.append(content["image_url"]["url"])
-                if not history:
-                    pixel_values = _image_to_piexl_values(images)
-                    image_bs = pixel_values.shape[0]
-                    image_tokens = (
-                        IMG_START_TOKEN
-                        + IMG_CONTEXT_TOKEN * self._model.num_image_token * image_bs
-                        + IMG_END_TOKEN
-                    )
-                    tmp[0] = image_tokens + "\n" + tmp[0]
-            else:
-                tmp.append(user)
-            tmp.append(chat_history[i + 1]["content"])
-            history.append(tuple(tmp))
-        return history, pixel_values
-
-    def _find_closest_aspect_ratio(
-        self, aspect_ratio, target_ratios, width, height, image_size
-    ):
-        best_ratio_diff = float("inf")
-        best_ratio = (1, 1)
-        area = width * height
-        for ratio in target_ratios:
-            target_aspect_ratio = ratio[0] / ratio[1]
-            ratio_diff = abs(aspect_ratio - target_aspect_ratio)
-            if ratio_diff < best_ratio_diff:
-                best_ratio_diff = ratio_diff
-                best_ratio = ratio
-            elif ratio_diff == best_ratio_diff:
-                if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
-                    best_ratio = ratio
-        return best_ratio
-
-    def _dynamic_preprocess(
-        self, image, min_num=1, max_num=6, image_size=448, use_thumbnail=False
-    ):
-        orig_width, orig_height = image.size
-        aspect_ratio = orig_width / orig_height
-
-        # calculate the existing image aspect ratio
-        target_ratios = set(
-            (i, j)
-            for n in range(min_num, max_num + 1)
-            for i in range(1, n + 1)
-            for j in range(1, n + 1)
-            if i * j <= max_num and i * j >= min_num
-        )
-        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
-        # find the closest aspect ratio to the target
-        target_aspect_ratio = self._find_closest_aspect_ratio(
-            aspect_ratio, target_ratios, orig_width, orig_height, image_size
-        )
-
-        # calculate the target width and height
-        target_width = image_size * target_aspect_ratio[0]
-        target_height = image_size * target_aspect_ratio[1]
-        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
-        # resize the image
-        resized_img = image.resize((target_width, target_height))
-        processed_images = []
-        for i in range(blocks):
-            box = (
-                (i % (target_width // image_size)) * image_size,
-                (i // (target_width // image_size)) * image_size,
-                ((i % (target_width // image_size)) + 1) * image_size,
-                ((i // (target_width // image_size)) + 1) * image_size,
-            )
-            # split the image
-            split_img = resized_img.crop(box)
-            processed_images.append(split_img)
-        assert len(processed_images) == blocks
-        if use_thumbnail and len(processed_images) != 1:
-            thumbnail_img = image.resize((image_size, image_size))
-            processed_images.append(thumbnail_img)
-        return processed_images
-
-    def _build_transform(self, input_size):
-        import torchvision.transforms as T
-        from torchvision.transforms.functional import InterpolationMode
-
-        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
-        transform = T.Compose(
-            [
-                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
-                T.Resize(
-                    (input_size, input_size), interpolation=InterpolationMode.BICUBIC
-                ),
-                T.ToTensor(),
-                T.Normalize(mean=MEAN, std=STD),
-            ]
-        )
-        return transform
-
-    def _load_image(self, image_file, input_size=448, max_num=6):
-        transform = self._build_transform(input_size=input_size)
-        images = self._dynamic_preprocess(
-            image_file, image_size=input_size, use_thumbnail=True, max_num=max_num
-        )
-        pixel_values = [transform(image) for image in images]
-        pixel_values = torch.stack(pixel_values)
-        return pixel_values
-
-    def chat(
-        self,
-        prompt: Union[str, List[Dict]],
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
-        generate_config: Optional[PytorchGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        if generate_config and generate_config.get("stream"):
-            raise Exception(
-                f"Chat with model {self.model_family.model_name} does not support stream."
-            )
-        sanitized_config = {
-            "num_beams": 1,
-            "max_new_tokens": generate_config.get("max_tokens", 512)
-            if generate_config
-            else 512,
-            "do_sample": False,
-        }
-
-        content, image = self._message_content_to_intern(prompt)
-
-        history = None
-        if chat_history:
-            history, pixel_values = self._history_content_to_intern(chat_history)
-        else:
-            load_images = []
-            for img in image:
-                pixel_value = self._load_image(img, max_num=6).to(torch.bfloat16).cuda()
-                load_images.append(pixel_value)
-            pixel_values = torch.cat(tuple(load_images), dim=0)
-
-        response, history = self._model.chat(
-            self._tokenizer,
-            pixel_values,
-            content,
-            sanitized_config,
-            history=history,
-            return_history=True,
-        )
-        chunk = Completion(
-            id=str(uuid.uuid1()),
-            object="text_completion",
-            created=int(time.time()),
-            model=self.model_uid,
-            choices=[
-                CompletionChoice(
-                    index=0, text=response, finish_reason="stop", logprobs=None
-                )
-            ],
-            usage=CompletionUsage(
-                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-            ),
-        )
-        return self._to_chat_completion(chunk)
xinference/model/llm/pytorch/vicuna.py
@@ -1,69 +0,0 @@
-# Copyright 2022-2023 XProbe Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Copyright 2022-2023 XProbe Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import List, Optional
-
-from ....types import LoRA
-from .. import LLMFamilyV1, LLMSpecV1
-from .core import PytorchChatModel, PytorchModelConfig
-
-
-class VicunaPytorchChatModel(PytorchChatModel):
-    def __init__(
-        self,
-        model_uid: str,
-        model_family: "LLMFamilyV1",
-        model_spec: "LLMSpecV1",
-        quantization: str,
-        model_path: str,
-        pytorch_model_config: Optional["PytorchModelConfig"] = None,
-        peft_model: Optional[List[LoRA]] = None,
-    ):
-        super().__init__(
-            model_uid,
-            model_family,
-            model_spec,
-            quantization,
-            model_path,
-            pytorch_model_config=pytorch_model_config,
-            peft_model=peft_model,
-        )
-        self._use_fast_tokenizer = False
-
-    @classmethod
-    def match(
-        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
-    ) -> bool:
-        if llm_spec.model_format != "pytorch":
-            return False
-        model_family = llm_family.model_family or llm_family.model_name
-        if "vicuna" not in model_family:
-            return False
-        if "chat" not in llm_family.model_ability:
-            return False
-        return True