xinference 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/oauth2/auth_service.py +132 -0
- xinference/api/restful_api.py +282 -78
- xinference/client/handlers.py +3 -0
- xinference/client/restful/restful_client.py +108 -75
- xinference/constants.py +14 -4
- xinference/core/cache_tracker.py +102 -0
- xinference/core/chat_interface.py +10 -4
- xinference/core/event.py +56 -0
- xinference/core/model.py +44 -0
- xinference/core/resource.py +19 -12
- xinference/core/status_guard.py +4 -0
- xinference/core/supervisor.py +278 -87
- xinference/core/utils.py +68 -3
- xinference/core/worker.py +98 -8
- xinference/deploy/cmdline.py +6 -3
- xinference/deploy/local.py +2 -2
- xinference/deploy/supervisor.py +2 -2
- xinference/model/audio/__init__.py +27 -0
- xinference/model/audio/core.py +161 -0
- xinference/model/audio/model_spec.json +79 -0
- xinference/model/audio/utils.py +18 -0
- xinference/model/audio/whisper.py +132 -0
- xinference/model/core.py +18 -13
- xinference/model/embedding/__init__.py +27 -2
- xinference/model/embedding/core.py +43 -3
- xinference/model/embedding/model_spec.json +24 -0
- xinference/model/embedding/model_spec_modelscope.json +24 -0
- xinference/model/embedding/utils.py +18 -0
- xinference/model/image/__init__.py +12 -1
- xinference/model/image/core.py +63 -9
- xinference/model/image/utils.py +26 -0
- xinference/model/llm/__init__.py +20 -1
- xinference/model/llm/core.py +43 -2
- xinference/model/llm/ggml/chatglm.py +15 -6
- xinference/model/llm/llm_family.json +197 -6
- xinference/model/llm/llm_family.py +9 -7
- xinference/model/llm/llm_family_modelscope.json +189 -4
- xinference/model/llm/pytorch/chatglm.py +3 -3
- xinference/model/llm/pytorch/core.py +4 -2
- xinference/model/{multimodal → llm/pytorch}/qwen_vl.py +10 -8
- xinference/model/llm/pytorch/utils.py +21 -9
- xinference/model/llm/pytorch/yi_vl.py +246 -0
- xinference/model/llm/utils.py +57 -4
- xinference/model/llm/vllm/core.py +5 -4
- xinference/model/rerank/__init__.py +25 -2
- xinference/model/rerank/core.py +51 -9
- xinference/model/rerank/model_spec.json +6 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -0
- xinference/{api/oauth2/common.py → model/rerank/utils.py} +6 -2
- xinference/model/utils.py +5 -3
- xinference/thirdparty/__init__.py +0 -0
- xinference/thirdparty/llava/__init__.py +1 -0
- xinference/thirdparty/llava/conversation.py +205 -0
- xinference/thirdparty/llava/mm_utils.py +122 -0
- xinference/thirdparty/llava/model/__init__.py +1 -0
- xinference/thirdparty/llava/model/clip_encoder/__init__.py +0 -0
- xinference/thirdparty/llava/model/clip_encoder/builder.py +11 -0
- xinference/thirdparty/llava/model/clip_encoder/clip_encoder.py +86 -0
- xinference/thirdparty/llava/model/constants.py +6 -0
- xinference/thirdparty/llava/model/llava_arch.py +385 -0
- xinference/thirdparty/llava/model/llava_llama.py +163 -0
- xinference/thirdparty/llava/model/multimodal_projector/__init__.py +0 -0
- xinference/thirdparty/llava/model/multimodal_projector/builder.py +64 -0
- xinference/types.py +1 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.15822aeb.js +3 -0
- xinference/web/ui/build/static/js/main.15822aeb.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/139e5e4adf436923107d2b02994c7ff6dba2aac1989e9b6638984f0dfe782c4a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/64accc515dc6cd584a2873796cd7da6f93de57f7e465eb5423cca9a2f3fe3eff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/65ca3ba225b8c8dac907210545b51f2fcdb2591f0feeb7195f1c037f2bc956a0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b80db1012318b97c329c4e3e72454f7512fb107e57c444b437dbe4ba1a3faa5a.json +1 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/METADATA +33 -23
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/RECORD +81 -64
- xinference/api/oauth2/core.py +0 -93
- xinference/model/multimodal/__init__.py +0 -52
- xinference/model/multimodal/core.py +0 -467
- xinference/model/multimodal/model_spec.json +0 -43
- xinference/model/multimodal/model_spec_modelscope.json +0 -45
- xinference/web/ui/build/static/js/main.b83095c2.js +0 -3
- xinference/web/ui/build/static/js/main.b83095c2.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/101923c539819f26ad11fbcbd6f6e56436b285efbb090dcc7dd648c6e924c4a8.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4942da6bc03bf7373af068e22f916341aabc5b5df855d73c1d348c696724ce37.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/52a6136cb2dbbf9c51d461724d9b283ebe74a73fb19d5df7ba8e13c42bd7174d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/71493aadd34d568fbe605cacaba220aa69bd09273251ee4ba27930f8d01fccd8.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8b071db2a5a9ef68dc14d5f606540bd23d9785e365a11997c510656764d2dccf.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a4d72d3b806ba061919115f0c513738726872e3c79cf258f007519d3f91d1a16.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f037ffef5992af0892d6d991053c1dace364cd39a3f11f1a41f92776e8a59459.json +0 -1
- /xinference/web/ui/build/static/js/{main.b83095c2.js.LICENSE.txt → main.15822aeb.js.LICENSE.txt} +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/LICENSE +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/WHEEL +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/top_level.txt +0 -0
xinference/model/rerank/model_spec_modelscope.json
CHANGED
@@ -12,5 +12,12 @@
         "model_id": "Xorbits/bge-reranker-large",
         "model_revision": "v0.0.1",
         "model_hub": "modelscope"
+    },
+    {
+        "model_name": "bce-reranker-base_v1",
+        "language": ["en", "zh"],
+        "model_id": "maidalun/bce-reranker-base_v1",
+        "model_revision": "v0.0.1",
+        "model_hub": "modelscope"
     }
 ]
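The new entry makes bce-reranker-base_v1 launchable from the ModelScope hub alongside bge-reranker-large. A rough usage sketch, assuming a local endpoint at http://localhost:9997 and the RESTful client API of this release (the documents and query are made up):

from xinference.client import Client

client = Client("http://localhost:9997")  # hypothetical local deployment

# Launch the rerank model added in this release (fetched from ModelScope
# when the model source is configured accordingly).
model_uid = client.launch_model(
    model_name="bce-reranker-base_v1",
    model_type="rerank",
)
model = client.get_model(model_uid)

# Score a few candidate documents against a query.
documents = [
    "Xinference serves LLM, embedding and rerank models.",
    "A reranker orders candidate passages by relevance to a query.",
]
print(model.rerank(documents, "what does a reranker do?"))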
xinference/{api/oauth2/common.py → model/rerank/utils.py}
RENAMED
@@ -1,4 +1,4 @@
-# Copyright 2022-
+# Copyright 2022-2024 XProbe Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,4 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from .core import RerankModelSpec
+
+
+def get_model_version(rerank_model: RerankModelSpec) -> str:
+    return rerank_model.model_name
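get_model_version gives rerank models the same version-string hook that the embedding and LLM packages gain in this release (see the new utils.py files in the list above); for rerank models the version is simply the model name. A minimal sketch, assuming RerankModelSpec accepts the fields shown in the JSON specs above:

from xinference.model.rerank.core import RerankModelSpec
from xinference.model.rerank.utils import get_model_version

spec = RerankModelSpec(
    model_name="bce-reranker-base_v1",
    language=["en", "zh"],
    model_id="maidalun/bce-reranker-base_v1",
    model_revision="v0.0.1",
)
# For rerank models the "version" is just the model name.
assert get_model_version(spec) == "bce-reranker-base_v1"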
xinference/model/utils.py
CHANGED
@@ -141,10 +141,12 @@ def valid_model_revision(
     return real_revision == expected_model_revision


+def get_cache_dir(model_spec: Any) -> str:
+    return os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name))
+
+
 def is_model_cached(model_spec: Any, name_to_revisions_mapping: Dict):
-    cache_dir = os.path.realpath(
-        os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
-    )
+    cache_dir = get_cache_dir(model_spec)
     meta_path = os.path.join(cache_dir, "__valid_download")
     revisions = name_to_revisions_mapping[model_spec.model_name]
     if model_spec.model_revision not in revisions:  # Usually for UT
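The refactor pulls the cache-path computation out of is_model_cached into a reusable get_cache_dir helper, so other callers can resolve the same directory. A small sketch of the resulting behaviour, using a stand-in spec object (real callers pass a model spec):

import os
from types import SimpleNamespace

from xinference.constants import XINFERENCE_CACHE_DIR
from xinference.model.utils import get_cache_dir

# Stand-in for a model spec; only model_name is needed by the helper.
spec = SimpleNamespace(model_name="bce-reranker-base_v1")

cache_dir = get_cache_dir(spec)
# Resolves to <XINFERENCE_CACHE_DIR>/<model_name>, realpath-ed.
assert cache_dir == os.path.realpath(
    os.path.join(XINFERENCE_CACHE_DIR, spec.model_name)
)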
xinference/thirdparty/__init__.py
File without changes

xinference/thirdparty/llava/__init__.py
ADDED
@@ -0,0 +1 @@
+from .model import LlavaLlamaForCausalLM
xinference/thirdparty/llava/conversation.py
ADDED
@@ -0,0 +1,205 @@
+import dataclasses
+from enum import Enum, auto
+from typing import List
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+
+    SINGLE = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+        if len(messages) > 0 and type(messages[0][1]) is tuple:
+            messages = self.messages.copy()
+            init_role, init_msg = messages[0].copy()
+            init_msg = init_msg[0].replace("<image_placeholder>", "").strip()
+            if "mmtag" in self.version:
+                messages[0] = (init_role, init_msg)
+                messages.insert(
+                    0, (self.roles[0], "<Image><image_placeholder></Image>")
+                )
+                messages.insert(1, (self.roles[1], "Received."))
+            else:
+                messages[0] = (init_role, "<image_placeholder>\n" + init_msg)
+
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system + "\n\n" + self.sep + " "
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + "\n" + self.sep + " "
+                else:
+                    ret += role + ":"
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def get_images(self, return_pil=False):
+        images = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+
+                    from PIL import Image
+
+                    msg, image, image_process_mode = msg
+                    if image_process_mode == "Pad":
+
+                        def expand2square(pil_img, background_color=(122, 116, 104)):
+                            width, height = pil_img.size
+                            if width == height:
+                                return pil_img
+                            elif width > height:
+                                result = Image.new(
+                                    pil_img.mode, (width, width), background_color
+                                )
+                                result.paste(pil_img, (0, (width - height) // 2))
+                                return result
+                            else:
+                                result = Image.new(
+                                    pil_img.mode, (height, height), background_color
+                                )
+                                result.paste(pil_img, ((height - width) // 2, 0))
+                                return result
+
+                        image = expand2square(image)
+                    elif image_process_mode == "Crop":
+                        pass
+                    elif image_process_mode == "Resize":
+                        image = image.resize((336, 336))
+                    else:
+                        raise ValueError(
+                            f"Invalid image_process_mode: {image_process_mode}"
+                        )
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if H > W:
+                        H, W = longest_edge, shortest_edge
+                    else:
+                        H, W = shortest_edge, longest_edge
+                    image = image.resize((W, H))
+                    if return_pil:
+                        images.append(image)
+                    else:
+                        buffered = BytesIO()
+                        image.save(buffered, format="PNG")
+                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                        images.append(img_b64_str)
+        return images
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+
+                    msg, image, image_process_mode = msg
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if H > W:
+                        H, W = longest_edge, shortest_edge
+                    else:
+                        H, W = shortest_edge, longest_edge
+                    image = image.resize((W, H))
+                    buffered = BytesIO()
+                    image.save(buffered, format="JPEG")
+                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                    msg = img_str + msg.replace("<image_placeholder>", "").strip()
+                    ret.append([msg, None])
+                else:
+                    ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            version=self.version,
+        )
+
+    def dict(self):
+        if len(self.get_images()) > 0:
+            return {
+                "system": self.system,
+                "roles": self.roles,
+                "messages": [
+                    [x, y[0] if type(y) is tuple else y] for x, y in self.messages
+                ],
+                "offset": self.offset,
+                "sep": self.sep,
+                "sep2": self.sep2,
+            }
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+mm_default_conv = Conversation(
+    system="This is a chat between an inquisitive human and an AI assistant. "
+    "Assume the role of the AI assistant. "
+    "Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers. "
+    "这是一个好奇的人类和一个人工智能助手之间的对话。"
+    "假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。",
+    roles=("Human", "Assistant"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+
+default_conversation = mm_default_conv
+conv_templates = {
+    "mm_default": mm_default_conv,
+}
+
+
+if __name__ == "__main__":
+    print(default_conversation.get_prompt())
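This vendored conversation module drives prompt construction for the new Yi-VL chat model (see xinference/model/llm/pytorch/yi_vl.py in the file list). A quick illustration of how the template is used, based only on the code above:

from xinference.thirdparty.llava.conversation import conv_templates

conv = conv_templates["mm_default"].copy()
conv.append_message(conv.roles[0], "What is shown in this image?")
conv.append_message(conv.roles[1], None)  # leave the assistant turn open

# Produces a "<system>\n\n### Human: ...\n### Assistant:" style prompt.
print(conv.get_prompt())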
xinference/thirdparty/llava/mm_utils.py
ADDED
@@ -0,0 +1,122 @@
+import base64
+from io import BytesIO
+
+import torch
+from .model import LlavaLlamaForCausalLM
+from .model.constants import IMAGE_TOKEN_INDEX
+from PIL import Image
+from transformers import AutoTokenizer, StoppingCriteria
+
+
+def load_image_from_base64(image):
+    return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def process_images(images, image_processor, model_cfg):
+    return image_processor(images, return_tensors="pt")["pixel_values"]
+
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+
+def tokenizer_image_token(
+    prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
+):
+    prompt_chunks = [
+        tokenizer(chunk).input_ids for chunk in prompt.split("<image_placeholder>")
+    ]
+
+    def insert_separator(X, sep):
+        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
+
+    input_ids = []
+    offset = 0
+    if (
+        len(prompt_chunks) > 0
+        and len(prompt_chunks[0]) > 0
+        and prompt_chunks[0][0] == tokenizer.bos_token_id
+    ):
+        offset = 1
+        input_ids.append(prompt_chunks[0][0])
+
+    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+        input_ids.extend(x[offset:])
+
+    if return_tensors is not None:
+        if return_tensors == "pt":
+            return torch.tensor(input_ids, dtype=torch.long)
+        raise ValueError(f"Unsupported tensor type: {return_tensors}")
+    return input_ids
+
+
+def get_model_name_from_path(model_path):
+    model_path = model_path.strip("/")
+    model_paths = model_path.split("/")
+    if model_paths[-1].startswith("checkpoint-"):
+        return model_paths[-2] + "_" + model_paths[-1]
+    else:
+        return model_paths[-1]
+
+
+def load_pretrained_model(
+    model_path, load_8bit=False, load_4bit=False, device_map="auto", multimodal="IMAGE"
+):
+    kwargs = {"device_map": device_map}
+    kwargs["torch_dtype"] = torch.bfloat16
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+    model = LlavaLlamaForCausalLM.from_pretrained(
+        model_path, low_cpu_mem_usage=True, **kwargs
+    )
+    image_processor = None
+    model.resize_token_embeddings(len(tokenizer))
+    vision_tower = model.get_vision_tower()
+
+    if not vision_tower.is_loaded:
+        vision_tower.load_model()
+    vision_tower.to(device="cuda", dtype=torch.bfloat16)
+    image_processor = vision_tower.image_processor
+
+    if hasattr(model.config, "max_sequence_length"):
+        context_len = model.config.max_sequence_length
+    else:
+        context_len = 2048
+
+    return tokenizer, model, image_processor, context_len
+
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+    def __init__(self, keywords, tokenizer, input_ids):
+        self.keywords = keywords
+        self.tokenizer = tokenizer
+        self.start_len = None
+        self.input_ids = input_ids
+
+    def __call__(
+        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+    ) -> bool:
+        if self.start_len is None:
+            self.start_len = self.input_ids.shape[1]
+            return False
+        else:
+            outputs = self.tokenizer.batch_decode(
+                output_ids[:, self.start_len :], skip_special_tokens=True
+            )
+            flag = True
+            for output in outputs:
+                for keyword in self.keywords:
+                    if keyword not in output:
+                        flag = False
+                        return False
+            return flag
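tokenizer_image_token splits the prompt on the <image_placeholder> marker, tokenizes each chunk, and splices IMAGE_TOKEN_INDEX between the chunks so image features can be substituted at those positions. A sketch of the call, assuming a LLaMA-style tokenizer is already available (the checkpoint path is illustrative):

from transformers import AutoTokenizer

from xinference.thirdparty.llava.mm_utils import (
    KeywordsStoppingCriteria,
    tokenizer_image_token,
)

# Illustrative path; any LLaMA-style tokenizer works for this sketch.
tokenizer = AutoTokenizer.from_pretrained("/path/to/yi-vl-checkpoint", use_fast=False)

prompt = "### Human: <image_placeholder>\nDescribe the picture.\n### Assistant:"
input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt").unsqueeze(0)

# Stop generation once the "###" separator appears in the decoded output.
stopping = KeywordsStoppingCriteria(["###"], tokenizer, input_ids)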
xinference/thirdparty/llava/model/__init__.py
ADDED
@@ -0,0 +1 @@
+from .llava_llama import LlavaConfig, LlavaLlamaForCausalLM
xinference/thirdparty/llava/model/clip_encoder/__init__.py
File without changes
xinference/thirdparty/llava/model/clip_encoder/builder.py
ADDED
@@ -0,0 +1,11 @@
+from .clip_encoder import CLIPVisionTower
+
+
+def build_vision_tower(vision_tower_cfg, **kwargs):
+    vision_tower = getattr(
+        vision_tower_cfg,
+        "mm_vision_tower",
+        getattr(vision_tower_cfg, "vision_tower", None),
+    )
+
+    return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
xinference/thirdparty/llava/model/clip_encoder/clip_encoder.py
ADDED
@@ -0,0 +1,86 @@
+import torch
+import torch.nn as nn
+from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
+
+
+class CLIPVisionTower(nn.Module):
+    def __init__(self, vision_tower, args, delay_load=False):
+        super().__init__()
+
+        self.is_loaded = False
+
+        self.vision_tower_name = vision_tower
+        self.select_layer = args.mm_vision_select_layer
+        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
+
+        if not delay_load:
+            self.load_model()
+        else:
+            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
+
+    def load_model(self):
+        self.image_processor = CLIPImageProcessor.from_pretrained(
+            self.vision_tower_name
+        )
+        self.vision_tower = CLIPVisionModel.from_pretrained(
+            self.vision_tower_name, ignore_mismatched_sizes=True
+        )
+
+        self.is_loaded = True
+
+    def feature_select(self, image_forward_outs):
+        image_features = image_forward_outs.hidden_states[self.select_layer]
+        if self.select_feature == "patch":
+            image_features = image_features[:, 1:]
+        elif self.select_feature == "cls_patch":
+            image_features = image_features
+        else:
+            raise ValueError(f"Unexpected select feature: {self.select_feature}")
+        return image_features
+
+    # @torch.no_grad()
+    def forward(self, images):
+        if type(images) is list:
+            image_features = []
+            for image in images:
+                image_forward_out = self.vision_tower(
+                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
+                    output_hidden_states=True,
+                )
+                image_feature = self.feature_select(image_forward_out).to(image.dtype)
+                image_features.append(image_feature)
+        else:
+            image_forward_outs = self.vision_tower(
+                images.to(device=self.device, dtype=self.dtype),
+                output_hidden_states=True,
+            )
+            image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+        return image_features
+
+    @property
+    def dummy_feature(self):
+        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+    @property
+    def dtype(self):
+        return self.vision_tower.dtype
+
+    @property
+    def device(self):
+        return self.vision_tower.device
+
+    @property
+    def config(self):
+        if self.is_loaded:
+            return self.vision_tower.config
+        else:
+            return self.cfg_only
+
+    @property
+    def hidden_size(self):
+        return self.config.hidden_size
+
+    @property
+    def num_patches(self):
+        return (self.config.image_size // self.config.patch_size) ** 2
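CLIPVisionTower wraps a Hugging Face CLIP vision model and exposes per-patch features for the multimodal projector. A minimal sketch of loading the tower and extracting features; the config object here is a stand-in (in the real model these attributes come from the vision-language model's config):

from types import SimpleNamespace

from PIL import Image

from xinference.thirdparty.llava.model.clip_encoder.builder import build_vision_tower

# Stand-in config; the real values come from the checkpoint's config.
cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",
    mm_vision_select_layer=-2,
)

tower = build_vision_tower(cfg)  # loads CLIPVisionModel and its image processor

image = Image.new("RGB", (336, 336))  # dummy input image
pixel_values = tower.image_processor(images=image, return_tensors="pt")["pixel_values"]

features = tower(pixel_values)  # forward() moves inputs to the tower's device/dtype
print(features.shape)  # (1, num_patches, hidden_size); CLS token dropped for "patch"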