xinference 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/oauth2/auth_service.py +132 -0
- xinference/api/restful_api.py +282 -78
- xinference/client/handlers.py +3 -0
- xinference/client/restful/restful_client.py +108 -75
- xinference/constants.py +14 -4
- xinference/core/cache_tracker.py +102 -0
- xinference/core/chat_interface.py +10 -4
- xinference/core/event.py +56 -0
- xinference/core/model.py +44 -0
- xinference/core/resource.py +19 -12
- xinference/core/status_guard.py +4 -0
- xinference/core/supervisor.py +278 -87
- xinference/core/utils.py +68 -3
- xinference/core/worker.py +98 -8
- xinference/deploy/cmdline.py +6 -3
- xinference/deploy/local.py +2 -2
- xinference/deploy/supervisor.py +2 -2
- xinference/model/audio/__init__.py +27 -0
- xinference/model/audio/core.py +161 -0
- xinference/model/audio/model_spec.json +79 -0
- xinference/model/audio/utils.py +18 -0
- xinference/model/audio/whisper.py +132 -0
- xinference/model/core.py +18 -13
- xinference/model/embedding/__init__.py +27 -2
- xinference/model/embedding/core.py +43 -3
- xinference/model/embedding/model_spec.json +24 -0
- xinference/model/embedding/model_spec_modelscope.json +24 -0
- xinference/model/embedding/utils.py +18 -0
- xinference/model/image/__init__.py +12 -1
- xinference/model/image/core.py +63 -9
- xinference/model/image/utils.py +26 -0
- xinference/model/llm/__init__.py +20 -1
- xinference/model/llm/core.py +43 -2
- xinference/model/llm/ggml/chatglm.py +15 -6
- xinference/model/llm/llm_family.json +197 -6
- xinference/model/llm/llm_family.py +9 -7
- xinference/model/llm/llm_family_modelscope.json +189 -4
- xinference/model/llm/pytorch/chatglm.py +3 -3
- xinference/model/llm/pytorch/core.py +4 -2
- xinference/model/{multimodal → llm/pytorch}/qwen_vl.py +10 -8
- xinference/model/llm/pytorch/utils.py +21 -9
- xinference/model/llm/pytorch/yi_vl.py +246 -0
- xinference/model/llm/utils.py +57 -4
- xinference/model/llm/vllm/core.py +5 -4
- xinference/model/rerank/__init__.py +25 -2
- xinference/model/rerank/core.py +51 -9
- xinference/model/rerank/model_spec.json +6 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -0
- xinference/{api/oauth2/common.py → model/rerank/utils.py} +6 -2
- xinference/model/utils.py +5 -3
- xinference/thirdparty/__init__.py +0 -0
- xinference/thirdparty/llava/__init__.py +1 -0
- xinference/thirdparty/llava/conversation.py +205 -0
- xinference/thirdparty/llava/mm_utils.py +122 -0
- xinference/thirdparty/llava/model/__init__.py +1 -0
- xinference/thirdparty/llava/model/clip_encoder/__init__.py +0 -0
- xinference/thirdparty/llava/model/clip_encoder/builder.py +11 -0
- xinference/thirdparty/llava/model/clip_encoder/clip_encoder.py +86 -0
- xinference/thirdparty/llava/model/constants.py +6 -0
- xinference/thirdparty/llava/model/llava_arch.py +385 -0
- xinference/thirdparty/llava/model/llava_llama.py +163 -0
- xinference/thirdparty/llava/model/multimodal_projector/__init__.py +0 -0
- xinference/thirdparty/llava/model/multimodal_projector/builder.py +64 -0
- xinference/types.py +1 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.15822aeb.js +3 -0
- xinference/web/ui/build/static/js/main.15822aeb.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/139e5e4adf436923107d2b02994c7ff6dba2aac1989e9b6638984f0dfe782c4a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/64accc515dc6cd584a2873796cd7da6f93de57f7e465eb5423cca9a2f3fe3eff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/65ca3ba225b8c8dac907210545b51f2fcdb2591f0feeb7195f1c037f2bc956a0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b80db1012318b97c329c4e3e72454f7512fb107e57c444b437dbe4ba1a3faa5a.json +1 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/METADATA +33 -23
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/RECORD +81 -64
- xinference/api/oauth2/core.py +0 -93
- xinference/model/multimodal/__init__.py +0 -52
- xinference/model/multimodal/core.py +0 -467
- xinference/model/multimodal/model_spec.json +0 -43
- xinference/model/multimodal/model_spec_modelscope.json +0 -45
- xinference/web/ui/build/static/js/main.b83095c2.js +0 -3
- xinference/web/ui/build/static/js/main.b83095c2.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/101923c539819f26ad11fbcbd6f6e56436b285efbb090dcc7dd648c6e924c4a8.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4942da6bc03bf7373af068e22f916341aabc5b5df855d73c1d348c696724ce37.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/52a6136cb2dbbf9c51d461724d9b283ebe74a73fb19d5df7ba8e13c42bd7174d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/71493aadd34d568fbe605cacaba220aa69bd09273251ee4ba27930f8d01fccd8.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8b071db2a5a9ef68dc14d5f606540bd23d9785e365a11997c510656764d2dccf.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a4d72d3b806ba061919115f0c513738726872e3c79cf258f007519d3f91d1a16.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f037ffef5992af0892d6d991053c1dace364cd39a3f11f1a41f92776e8a59459.json +0 -1
- /xinference/web/ui/build/static/js/{main.b83095c2.js.LICENSE.txt → main.15822aeb.js.LICENSE.txt} +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/LICENSE +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/WHEEL +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/top_level.txt +0 -0
xinference/model/{multimodal → llm/pytorch}/qwen_vl.py
RENAMED
@@ -19,19 +19,21 @@ import time
 import uuid
 from typing import Dict, Iterator, List, Optional, Union
 
-from …
+from ....model.utils import select_device
+from ....types import (
     ChatCompletion,
     ChatCompletionChoice,
     ChatCompletionChunk,
+    ChatCompletionMessage,
     CompletionUsage,
 )
-from ..…
-from .core import …
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchChatModel, PytorchGenerateConfig
 
 logger = logging.getLogger(__name__)
 
 
-class …
+class QwenVLChatModel(PytorchChatModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._tokenizer = None
@@ -39,7 +41,7 @@ class QwenVLChat(LVLM):
 
     @classmethod
     def match(
-        cls, model_family: "…
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if "qwen" in model_family.model_name:
             return True
@@ -49,7 +51,7 @@ class QwenVLChat(LVLM):
         from transformers import AutoModelForCausalLM, AutoTokenizer
         from transformers.generation import GenerationConfig
 
-        device = self.…
+        device = self._pytorch_model_config.get("device", "auto")
         device = select_device(device)
 
         self._tokenizer = AutoTokenizer.from_pretrained(
@@ -106,8 +108,8 @@ class QwenVLChat(LVLM):
         self,
         prompt: Union[str, List[Dict]],
         system_prompt: Optional[str] = None,
-        chat_history: Optional[List[…
-        generate_config: Optional[…
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        if generate_config and generate_config.get("stream"):
            raise Exception(
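The rename above moves Qwen-VL out of the removed multimodal package and into the PyTorch LLM stack, so it is now selected through the same match(model_family, model_spec, quantization) hook as other PytorchChatModel subclasses and reads its device from the regular PyTorch model config. A minimal sketch of those two behaviours, using SimpleNamespace as a hypothetical stand-in for LLMFamilyV1 rather than the real xinference classes:

from types import SimpleNamespace

def qwen_vl_matches(model_family) -> bool:
    # Mirrors QwenVLChatModel.match() in the hunk above: any family whose
    # name contains "qwen" is handled by this model class.
    return "qwen" in model_family.model_name

def resolve_device(pytorch_model_config: dict) -> str:
    # Mirrors load(): default to "auto" when no device is configured; the
    # real code then passes the value through select_device().
    return pytorch_model_config.get("device", "auto")

family = SimpleNamespace(model_name="qwen-vl-chat")  # hypothetical example value
print(qwen_vl_matches(family), resolve_device({}))   # -> True auto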
xinference/model/llm/pytorch/utils.py
CHANGED
@@ -29,7 +29,12 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
 
-from ....types import …
+from ....types import (
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
+    max_tokens_field,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -54,16 +59,21 @@ def get_context_length(config):
         hasattr(config, "max_sequence_length")
         and config.max_sequence_length is not None
     ):
-…
-…
-…
-…
+        max_sequence_length = config.max_sequence_length
+    else:
+        max_sequence_length = 2048
+    if hasattr(config, "seq_length") and config.seq_length is not None:
+        seq_length = config.seq_length
+    else:
+        seq_length = 2048
+    if (
         hasattr(config, "max_position_embeddings")
         and config.max_position_embeddings is not None
     ):
-…
+        max_position_embeddings = config.max_position_embeddings
     else:
-…
+        max_position_embeddings = 2048
+    return max(max_sequence_length, seq_length, max_position_embeddings)
 
 
 def prepare_logits_processor(
@@ -102,7 +112,7 @@ def generate_stream(
     repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
     top_p = float(generate_config.get("top_p", 1.0))
     top_k = int(generate_config.get("top_k", -1))  # -1 means disable
-    max_new_tokens = int(generate_config.get("max_tokens", …
+    max_new_tokens = int(generate_config.get("max_tokens", max_tokens_field.default))
     echo = bool(generate_config.get("echo", False))
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
@@ -123,6 +133,8 @@ def generate_stream(
         max_src_len = context_len
     else:
         max_src_len = context_len - max_new_tokens - 8
+        if max_src_len < 0:
+            raise ValueError("Max tokens exceeds model's max length")
 
     input_ids = input_ids[-max_src_len:]
     input_echo_len = len(input_ids)
@@ -346,7 +358,7 @@ def generate_stream_falcon(
     repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
     top_p = float(generate_config.get("top_p", 1.0))
     top_k = int(generate_config.get("top_k", 50))  # -1 means disable
-    max_new_tokens = int(generate_config.get("max_tokens", …
+    max_new_tokens = int(generate_config.get("max_tokens", max_tokens_field.default))
     echo = bool(generate_config.get("echo", False))
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
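The get_context_length() rework above stops relying on a single config attribute: each of max_sequence_length, seq_length, and max_position_embeddings falls back to 2048 when absent, and the largest of the three wins; generate_stream() additionally rejects requests whose max_tokens leaves no room for the prompt. A standalone sketch of the new length logic, written against any config-like object (the _Cfg class below is only an illustrative stand-in):

def get_context_length(config) -> int:
    # Each candidate falls back to 2048 when the attribute is missing or None.
    max_sequence_length = (
        config.max_sequence_length
        if getattr(config, "max_sequence_length", None) is not None
        else 2048
    )
    seq_length = (
        config.seq_length if getattr(config, "seq_length", None) is not None else 2048
    )
    max_position_embeddings = (
        config.max_position_embeddings
        if getattr(config, "max_position_embeddings", None) is not None
        else 2048
    )
    # The new code takes the maximum of all three candidates.
    return max(max_sequence_length, seq_length, max_position_embeddings)

class _Cfg:  # illustrative config exposing only max_position_embeddings
    max_position_embeddings = 32768

print(get_context_length(_Cfg()))  # -> 32768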
xinference/model/llm/pytorch/yi_vl.py
ADDED
@@ -0,0 +1,246 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import logging
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from io import BytesIO
+from threading import Thread
+from typing import Dict, Iterator, List, Optional, Union
+
+import requests
+import torch
+from PIL import Image
+
+from ....model.utils import select_device
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChoice,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    CompletionUsage,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchChatModel, PytorchGenerateConfig
+
+logger = logging.getLogger(__name__)
+
+
+class YiVLChatModel(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._tokenizer = None
+        self._model = None
+        self._image_processor = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if "yi" in model_family.model_name:
+            return True
+        return False
+
+    def load(self):
+        from ....thirdparty.llava.mm_utils import load_pretrained_model
+        from ....thirdparty.llava.model.constants import key_info
+
+        device = self._pytorch_model_config.get("device", "auto")
+        device = select_device(device)
+
+        key_info["model_path"] = self.model_path
+        (
+            self._tokenizer,
+            self._model,
+            self._image_processor,
+            _,
+        ) = load_pretrained_model(self.model_path, device_map=device)
+
+    @staticmethod
+    def _message_content_to_yi(content) -> Union[str, tuple]:
+        def _load_image(_url):
+            if _url.startswith("data:"):
+                logging.info("Parse url by base64 decoder.")
+                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                # e.g. f"data:image/jpeg;base64,{base64_image}"
+                _type, data = _url.split(";")
+                _, ext = _type.split("/")
+                data = data[len("base64,") :]
+                data = base64.b64decode(data.encode("utf-8"))
+
+                return Image.open(BytesIO(data))
+            else:
+                try:
+                    response = requests.get(_url)
+                except requests.exceptions.MissingSchema:
+                    return Image.open(_url)
+                else:
+                    return Image.open(BytesIO(response.content))
+
+        if not isinstance(content, str):
+            from ....thirdparty.llava.model.constants import DEFAULT_IMAGE_TOKEN
+
+            texts = []
+            image_urls = []
+            for c in content:
+                c_type = c.get("type")
+                if c_type == "text":
+                    texts.append(c["text"])
+                elif c_type == "image_url":
+                    image_urls.append(c["image_url"]["url"])
+            image_futures = []
+            with ThreadPoolExecutor() as executor:
+                for image_url in image_urls:
+                    fut = executor.submit(_load_image, image_url)
+                    image_futures.append(fut)
+            images = [fut.result() for fut in image_futures]
+            text = " ".join(texts)
+            if DEFAULT_IMAGE_TOKEN not in text:
+                text = DEFAULT_IMAGE_TOKEN + "\n" + text
+            if len(images) == 0:
+                return text
+            elif len(images) == 1:
+                return text, images[0], "Pad"
+            else:
+                raise RuntimeError("Only one image per message is supported by Yi VL.")
+        return content
+
+    @staticmethod
+    def _parse_text(text):
+        lines = text.split("\n")
+        lines = [line for line in lines if line != ""]
+        count = 0
+        for i, line in enumerate(lines):
+            if "```" in line:
+                count += 1
+                items = line.split("`")
+                if count % 2 == 1:
+                    lines[i] = f'<pre><code class="language-{items[-1]}">'
+                else:
+                    lines[i] = f"<br></code></pre>"
+            else:
+                if i > 0:
+                    if count % 2 == 1:
+                        line = line.replace("`", r"\`")
+                        line = line.replace("<", "&lt;")
+                        line = line.replace(">", "&gt;")
+                        line = line.replace(" ", "&nbsp;")
+                        line = line.replace("*", "&ast;")
+                        line = line.replace("_", "&lowbar;")
+                        line = line.replace("-", "&#45;")
+                        line = line.replace(".", "&#46;")
+                        line = line.replace("!", "&#33;")
+                        line = line.replace("(", "&#40;")
+                        line = line.replace(")", "&#41;")
+                        line = line.replace("$", "&#36;")
+                    lines[i] = "<br>" + line
+        text = "".join(lines)
+        return text
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        from transformers import TextIteratorStreamer
+
+        # TODO(codingl2k1): implement stream mode.
+        if generate_config and generate_config.get("stream"):
+            raise Exception(
+                f"Chat with model {self.model_family.model_name} does not support stream."
+            )
+        if not generate_config:
+            generate_config = {}
+        from ....thirdparty.llava.conversation import conv_templates
+        from ....thirdparty.llava.mm_utils import (
+            KeywordsStoppingCriteria,
+            tokenizer_image_token,
+        )
+        from ....thirdparty.llava.model.constants import IMAGE_TOKEN_INDEX
+
+        # Convert chat history to llava state
+        state = conv_templates["mm_default"].copy()
+        for message in chat_history or []:
+            content = self._message_content_to_yi(message["content"])
+            state.append_message(message["role"], content)
+        state.append_message(state.roles[0], self._message_content_to_yi(prompt))
+        state.append_message(state.roles[1], None)
+
+        prompt = state.get_prompt()
+
+        input_ids = (
+            tokenizer_image_token(
+                prompt, self._tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+            )
+            .unsqueeze(0)
+            .cuda()
+        )
+
+        images = state.get_images(return_pil=True)
+        image = images[0]
+
+        image_tensor = self._image_processor.preprocess(image, return_tensors="pt")[
+            "pixel_values"
+        ][0]
+
+        stop_str = state.sep
+        keywords = [stop_str]
+        stopping_criteria = KeywordsStoppingCriteria(
+            keywords, self._tokenizer, input_ids
+        )
+        streamer = TextIteratorStreamer(
+            self._tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
+        )
+        top_p = generate_config.get("top_p", 0.7)
+        temperature = generate_config.get("temperature", 0.2)
+        max_new_tokens = generate_config.get("max_tokens", 512)
+        generate_kwargs = {
+            "input_ids": input_ids,
+            "images": image_tensor.unsqueeze(0).to(dtype=torch.bfloat16).cuda(),
+            "streamer": streamer,
+            "do_sample": True,
+            "top_p": float(top_p),
+            "temperature": float(temperature),
+            "stopping_criteria": [stopping_criteria],
+            "use_cache": True,
+            "max_new_tokens": min(int(max_new_tokens), 1536),
+        }
+        t = Thread(target=self._model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        generated_text = ""
+        for new_text in streamer:
+            generated_text += new_text
+        if generated_text.endswith(stop_str):
+            generated_text = generated_text[: -len(stop_str)]
+        r = self._parse_text(generated_text)
+        return ChatCompletion(
+            id="chat" + str(uuid.uuid1()),
+            object="chat.completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[
+                ChatCompletionChoice(
+                    index=0,
+                    message={"role": "assistant", "content": r},
+                    finish_reason="stop",
+                )
+            ],
+            usage=CompletionUsage(
+                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+            ),
+        )
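The new YiVLChatModel accepts OpenAI-vision-style message content: the prompt (and history messages) may be a list mixing "text" and "image_url" parts, image URLs may be http(s) links, local paths, or base64 data: URLs, and only one image per message is supported. A hedged usage sketch through the RESTful client; it assumes a running endpoint on localhost:9997 and that the model is registered under a name such as "yi-vl-chat", so adjust both for your deployment:

from xinference.client import Client

client = Client("http://localhost:9997")
# Model name and format here are assumptions for illustration.
model_uid = client.launch_model(model_name="yi-vl-chat", model_format="pytorch")
model = client.get_model(model_uid)

completion = model.chat(
    prompt=[
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
    ]
)
print(completion["choices"][0]["message"]["content"])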
xinference/model/llm/utils.py
CHANGED
@@ -14,11 +14,10 @@
 import functools
 import json
 import logging
+import os
 import time
 import uuid
-from typing import AsyncGenerator, Dict, Iterator, List, Optional, cast
-
-from xinference.model.llm.llm_family import PromptStyleV1
+from typing import AsyncGenerator, Dict, Iterator, List, Optional, Tuple, cast
 
 from ...types import (
     SPECIAL_TOOL_PROMPT,
@@ -28,6 +27,14 @@ from ...types import (
     Completion,
     CompletionChunk,
 )
+from .llm_family import (
+    GgmlLLMSpecV1,
+    LLMFamilyV1,
+    LLMSpecV1,
+    PromptStyleV1,
+    _get_cache_dir,
+    get_cache_status,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -303,7 +310,7 @@ Begin!"""
            ret = (
                "<s>"
                if prompt_style.system_prompt == ""
-                else "<s…
+                else "<s><|im_start|>system\n"
                + prompt_style.system_prompt
                + prompt_style.intra_message_sep
                + "\n"
@@ -373,6 +380,20 @@ Begin!"""
                return f"USER: <<question>> {prompt} <<function>> {tools_string}\nASSISTANT: "
            else:
                return f"USER: <<question>> {prompt}\nASSISTANT: "
+        elif prompt_style.style_name == "orion":
+            ret = "<s>"
+            for i, message in enumerate(chat_history):
+                content = message["content"]
+                role = message["role"]
+                if i % 2 == 0:  # Human
+                    assert content is not None
+                    ret += role + ": " + content + "\n\n"
+                else:  # Assistant
+                    if content:
+                        ret += role + ": </s>" + content + "</s>"
+                    else:
+                        ret += role + ": </s>"
+            return ret
        else:
            raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")
 
@@ -573,3 +594,35 @@ Begin!"""
                "total_tokens": -1,
            },
        }
+
+
+def get_file_location(
+    llm_family: LLMFamilyV1, spec: LLMSpecV1, quantization: str
+) -> Tuple[str, bool]:
+    cache_dir = _get_cache_dir(llm_family, spec, create_if_not_exist=False)
+    cache_status = get_cache_status(llm_family, spec)
+    if isinstance(cache_status, list):
+        is_cached = None
+        for q, cs in zip(spec.quantizations, cache_status):
+            if q == quantization:
+                is_cached = cs
+                break
+    else:
+        is_cached = cache_status
+    assert isinstance(is_cached, bool)
+
+    if spec.model_format in ["pytorch", "gptq", "awq"]:
+        return cache_dir, is_cached
+    elif spec.model_format in ["ggmlv3", "ggufv2"]:
+        assert isinstance(spec, GgmlLLMSpecV1)
+        filename = spec.model_file_name_template.format(quantization=quantization)
+        model_path = os.path.join(cache_dir, filename)
+        return model_path, is_cached
+    else:
+        raise ValueError(f"Not supported model format {spec.model_format}")
+
+
+def get_model_version(
+    llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str
+) -> str:
+    return f"{llm_family.model_name}--{llm_spec.model_size_in_billions}B--{llm_spec.model_format}--{quantization}"
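The new get_model_version() helper pins a launched LLM to a single version string of the form name--sizeB--format--quantization, and get_file_location() pairs the cache path (or the concrete GGML/GGUF weight file) with its cache status. A minimal worked sketch of the version string, using SimpleNamespace as hypothetical stand-ins for LLMFamilyV1/LLMSpecV1:

from types import SimpleNamespace

def get_model_version(llm_family, llm_spec, quantization: str) -> str:
    # Same f-string as the helper above, reproduced for a quick worked example.
    return (
        f"{llm_family.model_name}--{llm_spec.model_size_in_billions}B"
        f"--{llm_spec.model_format}--{quantization}"
    )

family = SimpleNamespace(model_name="llama-2-chat")  # hypothetical values
spec = SimpleNamespace(model_size_in_billions=7, model_format="ggmlv3")
print(get_model_version(family, spec, "q4_0"))
# -> llama-2-chat--7B--ggmlv3--q4_0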
xinference/model/llm/vllm/core.py
CHANGED
@@ -95,6 +95,7 @@ VLLM_SUPPORTED_CHAT_MODELS = [
     "code-llama-instruct",
     "mistral-instruct-v0.1",
     "mistral-instruct-v0.2",
+    "mixtral-instruct-v0.1",
     "chatglm3",
 ]
 
@@ -190,12 +191,12 @@ class VLLMModel(LLM):
             return False
         if not cls._is_linux():
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
                 return False
-        if llm_spec.model_format …
+        if llm_spec.model_format in ["gptq", "awq"]:
             # Currently, only 4-bit weight quantization is supported for GPTQ, but got 8 bits.
             if "4" not in quantization:
                 return False
@@ -336,12 +337,12 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
     ) -> bool:
         if XINFERENCE_DISABLE_VLLM:
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
                 return False
-        if llm_spec.model_format …
+        if llm_spec.model_format in ["gptq", "awq"]:
             # Currently, only 4-bit weight quantization is supported for GPTQ, but got 8 bits.
             if "4" not in quantization:
                 return False
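The vLLM changes above add AWQ alongside GPTQ as an accepted model format (and mixtral-instruct-v0.1 to the chat whitelist), while keeping both quantized formats restricted to 4-bit variants. An isolated sketch of just the format/quantization gate; the real VLLMModel.match() applies additional guards, such as the Linux-only check visible in the hunk:

def vllm_accepts(model_format: str, quantization: str) -> bool:
    if model_format not in ["pytorch", "gptq", "awq"]:
        return False
    if model_format == "pytorch":
        # Only unquantized pytorch weights are handed to vLLM.
        if quantization != "none" and quantization is not None:
            return False
    if model_format in ["gptq", "awq"]:
        # Only 4-bit GPTQ/AWQ variants pass, e.g. "Int4".
        if "4" not in quantization:
            return False
    return True

print(vllm_accepts("awq", "Int4"), vllm_accepts("gptq", "Int8"))  # -> True False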
xinference/model/rerank/__init__.py
CHANGED
@@ -17,8 +17,20 @@ import json
 import os
 
 from ...constants import XINFERENCE_MODEL_DIR
-from .core import …
-…
+from .core import (
+    MODEL_NAME_TO_REVISION,
+    RERANK_MODEL_DESCRIPTIONS,
+    RerankModelSpec,
+    generate_rerank_description,
+    get_cache_status,
+    get_rerank_model_descriptions,
+)
+from .custom import (
+    CustomRerankModelSpec,
+    get_user_defined_reranks,
+    register_rerank,
+    unregister_rerank,
+)
 
 _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
 _model_spec_modelscope_json = os.path.join(
@@ -30,6 +42,7 @@ BUILTIN_RERANK_MODELS = dict(
 )
 for model_name, model_spec in BUILTIN_RERANK_MODELS.items():
     MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
 MODELSCOPE_RERANK_MODELS = dict(
     (spec["model_name"], RerankModelSpec(**spec))
     for spec in json.load(
@@ -39,6 +52,12 @@ MODELSCOPE_RERANK_MODELS = dict(
 for model_name, model_spec in MODELSCOPE_RERANK_MODELS.items():
     MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
 
+# register model description after recording model revision
+for model_spec_info in [BUILTIN_RERANK_MODELS, MODELSCOPE_RERANK_MODELS]:
+    for model_name, model_spec in model_spec_info.items():
+        if model_spec.model_name not in RERANK_MODEL_DESCRIPTIONS:
+            RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(model_spec))
+
 # if persist=True, load them when init
 user_defined_rerank_dir = os.path.join(XINFERENCE_MODEL_DIR, "rerank")
 if os.path.isdir(user_defined_rerank_dir):
@@ -49,5 +68,9 @@ if os.path.isdir(user_defined_rerank_dir):
            user_defined_rerank_spec = CustomRerankModelSpec.parse_obj(json.load(fd))
            register_rerank(user_defined_rerank_spec, persist=False)
 
+# register model description
+for ud_rerank in get_user_defined_reranks():
+    RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(ud_rerank))
+
 del _model_spec_json
 del _model_spec_modelscope_json
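With the registration loops above, every built-in, ModelScope, and user-defined rerank spec now contributes a version-info entry keyed by model name. A hedged sketch of the shape RERANK_MODEL_DESCRIPTIONS ends up holding, based on to_version_info() in rerank/core.py below; the version string, path, and cache status are illustrative placeholders, not real values:

RERANK_MODEL_DESCRIPTIONS_EXAMPLE = {
    "bge-reranker-base": [
        {
            "model_version": "bge-reranker-base",  # produced by rerank/utils.py get_model_version()
            "model_file_location": "~/.xinference/cache/bge-reranker-base",
            "cache_status": False,
            "language": ["en", "zh"],
        }
    ],
}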
xinference/model/rerank/core.py
CHANGED
@@ -36,6 +36,15 @@ MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
 SUPPORTED_SCHEMES = ["s3"]
 
 
+RERANK_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
+
+
+def get_rerank_model_descriptions():
+    import copy
+
+    return copy.deepcopy(RERANK_MODEL_DESCRIPTIONS)
+
+
 class RerankModelSpec(BaseModel):
     model_name: str
     language: List[str]
@@ -50,8 +59,9 @@ class RerankModelDescription(ModelDescription):
         address: Optional[str],
         devices: Optional[List[str]],
         model_spec: RerankModelSpec,
+        model_path: Optional[str] = None,
     ):
-        super().__init__(address, devices)
+        super().__init__(address, devices, model_path=model_path)
         self._model_spec = model_spec
 
     def to_dict(self):
@@ -64,6 +74,31 @@ class RerankModelDescription(ModelDescription):
             "model_revision": self._model_spec.model_revision,
         }
 
+    def to_version_info(self):
+        from .utils import get_model_version
+
+        if self._model_path is None:
+            is_cached = get_cache_status(self._model_spec)
+            file_location = get_cache_dir(self._model_spec)
+        else:
+            is_cached = True
+            file_location = self._model_path
+
+        return {
+            "model_version": get_model_version(self._model_spec),
+            "model_file_location": file_location,
+            "cache_status": is_cached,
+            "language": self._model_spec.language,
+        }
+
+
+def generate_rerank_description(model_spec: RerankModelSpec) -> Dict[str, List[Dict]]:
+    res = defaultdict(list)
+    res[model_spec.model_name].append(
+        RerankModelDescription(None, None, model_spec).to_version_info()
+    )
+    return res
+
 
 class RerankModel:
     def __init__(
@@ -71,12 +106,14 @@ class RerankModel:
         model_uid: str,
         model_path: str,
         device: Optional[str] = None,
+        use_fp16: bool = False,
         model_config: Optional[Dict] = None,
     ):
         self._model_uid = model_uid
         self._model_path = model_path
         self._device = device
         self._model_config = model_config or dict()
+        self._use_fp16 = use_fp16
         self._model = None
 
     def load(self):
@@ -93,6 +130,8 @@ class RerankModel:
         self._model = CrossEncoder(
             self._model_path, device=self._device, **self._model_config
         )
+        if self._use_fp16:
+            self._model.model.half()
 
     def rerank(
         self,
@@ -131,6 +170,10 @@ class RerankModel:
         return Rerank(id=str(uuid.uuid1()), results=docs)
 
 
+def get_cache_dir(model_spec: RerankModelSpec):
+    return os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name))
+
+
 def get_cache_status(
     model_spec: RerankModelSpec,
 ) -> bool:
@@ -145,9 +188,7 @@ def cache_from_uri(
 
     from ..utils import copy_from_src_to_dst, parse_uri
 
-    cache_dir = …
-        os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
-    )
+    cache_dir = get_cache_dir(model_spec)
     if os.path.exists(cache_dir):
         logger.info(f"Rerank cache {cache_dir} exists")
         return cache_dir
@@ -227,9 +268,7 @@ def cache(model_spec: RerankModelSpec):
         logger.info(f"Rerank model caching from URI: {model_spec.model_uri}")
         return cache_from_uri(model_spec=model_spec)
 
-    cache_dir = …
-        os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
-    )
+    cache_dir = get_cache_dir(model_spec)
     if not os.path.exists(cache_dir):
         os.makedirs(cache_dir, exist_ok=True)
     meta_path = os.path.join(cache_dir, "__valid_download")
@@ -312,6 +351,9 @@ def create_rerank_model_instance(
     )
 
     model_path = cache(model_spec)
-…
-…
+    use_fp16 = kwargs.pop("use_fp16", False)
+    model = RerankModel(model_uid, model_path, use_fp16=use_fp16, model_config=kwargs)
+    model_description = RerankModelDescription(
+        subpool_addr, devices, model_spec, model_path=model_path
+    )
     return model, model_description
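RerankModel now accepts use_fp16, which halves the CrossEncoder weights after loading; create_rerank_model_instance() pops it from the extra launch kwargs and forwards the rest as CrossEncoder options. A hedged usage sketch, assuming a local endpoint and that launch-time kwargs reach create_rerank_model_instance() as shown above:

from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(
    model_name="bge-reranker-base",
    model_type="rerank",
    use_fp16=True,  # popped by create_rerank_model_instance(); halves the weights
)
model = client.get_model(model_uid)
result = model.rerank(
    documents=["Xinference serves LLMs.", "The weather is nice today."],
    query="What does xinference do?",
)
print(result["results"][0])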
xinference/model/rerank/model_spec.json
CHANGED
@@ -10,5 +10,11 @@
     "language": ["en", "zh"],
     "model_id": "BAAI/bge-reranker-base",
     "model_revision": "465b4b7ddf2be0a020c8ad6e525b9bb1dbb708ae"
+  },
+  {
+    "model_name": "bce-reranker-base_v1",
+    "language": ["en", "zh"],
+    "model_id": "maidalun1020/bce-reranker-base_v1",
+    "model_revision": "eaa31a577a0574e87a08959bd229ca14ce1b5496"
   }
 ]