mineru-2.2.2-py3-none-any.whl → mineru-2.5.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- mineru/backend/pipeline/pipeline_middle_json_mkcontent.py +3 -3
- mineru/backend/vlm/model_output_to_middle_json.py +123 -0
- mineru/backend/vlm/vlm_analyze.py +97 -16
- mineru/backend/vlm/vlm_magic_model.py +201 -135
- mineru/backend/vlm/vlm_middle_json_mkcontent.py +52 -11
- mineru/cli/client.py +6 -5
- mineru/cli/common.py +17 -16
- mineru/cli/fast_api.py +9 -7
- mineru/cli/gradio_app.py +15 -16
- mineru/cli/vlm_vllm_server.py +4 -0
- mineru/model/table/rec/unet_table/main.py +8 -0
- mineru/model/vlm_vllm_model/__init__.py +0 -0
- mineru/model/vlm_vllm_model/server.py +51 -0
- mineru/resources/header.html +10 -2
- mineru/utils/draw_bbox.py +32 -10
- mineru/utils/enum_class.py +16 -2
- mineru/utils/guess_suffix_or_lang.py +20 -0
- mineru/utils/span_block_fix.py +4 -2
- mineru/version.py +1 -1
- {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/METADATA +70 -25
- {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/RECORD +25 -38
- {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/entry_points.txt +1 -1
- mineru/backend/vlm/base_predictor.py +0 -186
- mineru/backend/vlm/hf_predictor.py +0 -217
- mineru/backend/vlm/predictor.py +0 -111
- mineru/backend/vlm/sglang_client_predictor.py +0 -443
- mineru/backend/vlm/sglang_engine_predictor.py +0 -246
- mineru/backend/vlm/token_to_middle_json.py +0 -122
- mineru/backend/vlm/utils.py +0 -40
- mineru/cli/vlm_sglang_server.py +0 -4
- mineru/model/vlm_hf_model/__init__.py +0 -9
- mineru/model/vlm_hf_model/configuration_mineru2.py +0 -38
- mineru/model/vlm_hf_model/image_processing_mineru2.py +0 -269
- mineru/model/vlm_hf_model/modeling_mineru2.py +0 -449
- mineru/model/vlm_sglang_model/__init__.py +0 -14
- mineru/model/vlm_sglang_model/engine.py +0 -264
- mineru/model/vlm_sglang_model/image_processor.py +0 -213
- mineru/model/vlm_sglang_model/logit_processor.py +0 -90
- mineru/model/vlm_sglang_model/model.py +0 -453
- mineru/model/vlm_sglang_model/server.py +0 -75
- {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/WHEEL +0 -0
- {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/licenses/LICENSE.md +0 -0
- {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/top_level.txt +0 -0
mineru/model/vlm_hf_model/modeling_mineru2.py (deleted)
@@ -1,449 +0,0 @@
-import math
-import re
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-from transformers import (
-    Qwen2ForCausalLM,
-    Qwen2Model,
-    SiglipVisionConfig,
-    SiglipVisionModel,
-)
-from transformers.generation.utils import GenerateOutput
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
-from .configuration_mineru2 import Mineru2QwenConfig
-from .image_processing_mineru2 import Mineru2ImageProcessor, get_anyres_image_grid_shape
-
-
-class SiglipVisionTower(nn.Module):
-    def __init__(self, vision_tower):
-        super().__init__()
-
-        self.config = SiglipVisionConfig.from_pretrained(vision_tower)
-        assert isinstance(self.config, SiglipVisionConfig)
-        self.config.num_hidden_layers -= 1  # drop the last hidden layer
-        self.config.vision_use_head = False
-
-        self.vision_tower = SiglipVisionModel(self.config)
-        self.vision_tower.requires_grad_(False)
-
-        self.image_processor = Mineru2ImageProcessor()
-
-    def forward(self, images):
-        if type(images) is list:
-            image_features = []
-            for image in images:
-                image_forward_out = self.vision_tower(
-                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
-                )
-                image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
-                image_features.append(image_feature)
-        else:
-            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
-            image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
-
-        return image_features
-
-    @property
-    def dummy_feature(self):
-        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
-
-    @property
-    def dtype(self):
-        for p in self.vision_tower.parameters():
-            return p.dtype
-
-    @property
-    def device(self):
-        for p in self.vision_tower.parameters():
-            return p.device
-
-    @property
-    def hidden_size(self):
-        return self.config.hidden_size
-
-    @property
-    def num_patches(self):
-        return (self.config.image_size // self.config.patch_size) ** 2
-
-    @property
-    def num_patches_per_side(self):
-        return self.config.image_size // self.config.patch_size
-
-    @property
-    def image_size(self):
-        return self.config.image_size
-
-
-def build_vision_tower(config: Mineru2QwenConfig):
-    vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
-    model_path = getattr(config, "_name_or_path", "")
-    if "siglip" in vision_tower.lower():
-        if model_path:
-            return SiglipVisionTower(f"{model_path}/{vision_tower}")
-        else:
-            return SiglipVisionTower(vision_tower)
-    raise ValueError(f"Unknown vision tower: {vision_tower}")
-
-
-def build_vision_projector(config: Mineru2QwenConfig):
-    projector_type = getattr(config, "mm_projector_type", "linear")
-
-    if projector_type == "linear":
-        return nn.Linear(config.mm_hidden_size, config.hidden_size)
-
-    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
-    if mlp_gelu_match:
-        mlp_depth = int(mlp_gelu_match.group(1))
-        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
-        for _ in range(1, mlp_depth):
-            modules.append(nn.GELU())  # type: ignore
-            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
-        return nn.Sequential(*modules)
-
-    if projector_type == "identity":
-        return nn.Identity()
-
-    raise ValueError(f"Unknown projector type: {projector_type}")
-
-
-class Mineru2QwenModel(Qwen2Model):
-    config_class = Mineru2QwenConfig
-
-    def __init__(self, config: Mineru2QwenConfig):
-        super(Mineru2QwenModel, self).__init__(config)
-
-        self.vision_tower = build_vision_tower(config)
-        self.mm_projector = build_vision_projector(config)
-
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
-
-
-class Mineru2QwenForCausalLM(Qwen2ForCausalLM):
-    config_class = Mineru2QwenConfig
-
-    def __init__(self, config: Mineru2QwenConfig):
-        super(Qwen2ForCausalLM, self).__init__(config)
-        config.rope_scaling = None
-        self.model = Mineru2QwenModel(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        self.ignore_index = config.ignore_index
-        self.image_token_index = config.image_token_index
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_model(self):
-        return self.model
-
-    def encode_images(self, images: torch.Tensor):
-        image_features = self.get_model().vision_tower(images)
-        image_features = self.get_model().mm_projector(image_features)
-        return image_features
-
-    def prepare_inputs_labels_for_multimodal(
-        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
-    ):
-        vision_tower = self.get_model().vision_tower
-        if vision_tower is None or images is None or input_ids.shape[1] == 1:
-            return input_ids, position_ids, attention_mask, past_key_values, None, labels
-
-        if type(images) is list or images.ndim == 5:
-            if type(images) is list:
-                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
-            concat_images = torch.cat([image for image in images], dim=0)
-            image_features = self.encode_images(concat_images)
-            split_sizes = [image.shape[0] for image in images]
-            image_features = torch.split(image_features, split_sizes, dim=0)
-            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
-            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
-            if mm_patch_merge_type == "flat":
-                image_features = [x.flatten(0, 1) for x in image_features]
-            elif mm_patch_merge_type.startswith("spatial"):
-                new_image_features = []
-                for image_idx, image_feature in enumerate(image_features):
-                    if image_feature.shape[0] > 1:
-                        base_image_feature = image_feature[0]
-                        image_feature = image_feature[1:]
-                        height = width = self.get_model().vision_tower.num_patches_per_side
-                        assert height * width == base_image_feature.shape[0]
-
-                        if "anyres_max" in image_aspect_ratio:
-                            matched_anyres_max_num_patches = re.match(r"square_anyres_max_(\d+)", image_aspect_ratio)
-                            if matched_anyres_max_num_patches:
-                                max_num_patches = int(matched_anyres_max_num_patches.group(1))
-
-                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
-                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                                image_sizes[image_idx],
-                                self.config.image_grid_pinpoints,
-                                self.get_model().vision_tower.config.image_size,
-                            )
-                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
-                        else:
-                            raise NotImplementedError
-                        if (
-                            "unpad" in mm_patch_merge_type
-                            and "anyres_max" in image_aspect_ratio
-                            and matched_anyres_max_num_patches
-                        ):
-                            unit = image_feature.shape[2]
-                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                            c, h, w = image_feature.shape
-                            times = math.sqrt(h * w / (max_num_patches * unit**2))
-                            if times > 1.1:
-                                image_feature = image_feature[None]
-                                image_feature = nn.functional.interpolate(
-                                    image_feature, [int(h // times), int(w // times)], mode="bilinear"
-                                )[0]
-                            image_feature = torch.cat(
-                                (
-                                    image_feature,
-                                    self.model.image_newline[:, None, None]
-                                    .expand(*image_feature.shape[:-1], 1)
-                                    .to(image_feature.device),
-                                ),
-                                dim=-1,
-                            )
-                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                        elif "unpad" in mm_patch_merge_type:
-                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                            image_feature = torch.cat(
-                                (
-                                    image_feature,
-                                    self.model.image_newline[:, None, None]
-                                    .expand(*image_feature.shape[:-1], 1)
-                                    .to(image_feature.device),
-                                ),
-                                dim=-1,
-                            )
-                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                        else:
-                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
-                            image_feature = image_feature.flatten(0, 3)
-                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
-                    else:
-                        image_feature = image_feature[0]
-                        if "unpad" in mm_patch_merge_type:
-                            image_feature = torch.cat(
-                                (image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
-                            )
-                    new_image_features.append(image_feature)
-                image_features = new_image_features
-            else:
-                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
-        else:
-            image_features = self.encode_images(images)
-
-        _labels = labels
-        _position_ids = position_ids
-        _attention_mask = attention_mask
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
-        else:
-            attention_mask = attention_mask.bool()
-        if position_ids is None:
-            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
-        if labels is None:
-            labels = torch.full_like(input_ids, self.ignore_index)
-
-        # remove the padding using attention_mask -- FIXME
-        _input_ids = input_ids
-        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
-        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
-
-        new_input_embeds = []
-        new_labels = []
-        cur_image_idx = 0
-        for batch_idx, cur_input_ids in enumerate(input_ids):
-            num_images = (cur_input_ids == self.image_token_index).sum()
-            if num_images == 0:
-                cur_image_features = image_features[cur_image_idx]
-                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
-                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
-                new_input_embeds.append(cur_input_embeds)
-                new_labels.append(labels[batch_idx])
-                cur_image_idx += 1
-                continue
-
-            image_token_indices = (
-                [-1] + torch.where(cur_input_ids == self.image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
-            )
-            cur_input_ids_noim = []
-            cur_labels = labels[batch_idx]
-            cur_labels_noim = []
-            for i in range(len(image_token_indices) - 1):
-                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
-                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
-            split_sizes = [x.shape[0] for x in cur_labels_noim]
-            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
-            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
-            cur_new_input_embeds = []
-            cur_new_labels = []
-
-            for i in range(num_images + 1):
-                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
-                cur_new_labels.append(cur_labels_noim[i])
-                if i < num_images:
-                    cur_image_features = image_features[cur_image_idx]
-                    cur_image_idx += 1
-                    cur_new_input_embeds.append(cur_image_features)
-                    cur_new_labels.append(
-                        torch.full(
-                            (cur_image_features.shape[0],), self.ignore_index, device=cur_labels.device, dtype=cur_labels.dtype
-                        )
-                    )
-
-            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
-
-            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
-            cur_new_labels = torch.cat(cur_new_labels)
-
-            new_input_embeds.append(cur_new_input_embeds)
-            new_labels.append(cur_new_labels)
-
-        # Truncate sequences to max length as image embeddings can make the sequence longer
-        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
-        if tokenizer_model_max_length is not None:
-            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
-            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
-
-        # Combine them
-        max_len = max(x.shape[0] for x in new_input_embeds)
-        batch_size = len(new_input_embeds)
-
-        new_input_embeds_padded = []
-        new_labels_padded = torch.full(
-            (batch_size, max_len), self.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device
-        )
-        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
-        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
-
-        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
-            cur_len = cur_new_embed.shape[0]
-            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
-                new_input_embeds_padded.append(
-                    torch.cat(
-                        (
-                            torch.zeros(
-                                (max_len - cur_len, cur_new_embed.shape[1]),
-                                dtype=cur_new_embed.dtype,
-                                device=cur_new_embed.device,
-                            ),
-                            cur_new_embed,
-                        ),
-                        dim=0,
-                    )
-                )
-                if cur_len > 0:
-                    new_labels_padded[i, -cur_len:] = cur_new_labels
-                    attention_mask[i, -cur_len:] = True
-                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
-            else:
-                new_input_embeds_padded.append(
-                    torch.cat(
-                        (
-                            cur_new_embed,
-                            torch.zeros(
-                                (max_len - cur_len, cur_new_embed.shape[1]),
-                                dtype=cur_new_embed.dtype,
-                                device=cur_new_embed.device,
-                            ),
-                        ),
-                        dim=0,
-                    )
-                )
-                if cur_len > 0:
-                    new_labels_padded[i, :cur_len] = cur_new_labels
-                    attention_mask[i, :cur_len] = True
-                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
-
-        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
-
-        if _labels is None:
-            new_labels = None
-        else:
-            new_labels = new_labels_padded
-
-        if _attention_mask is None:
-            attention_mask = None
-        else:
-            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
-
-        if _position_ids is None:
-            position_ids = None
-
-        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        images: Optional[torch.FloatTensor] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-
-        if inputs_embeds is None:
-            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
-                self.prepare_inputs_labels_for_multimodal(
-                    input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
-                )
-            )
-        return super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            labels=labels,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-    @torch.no_grad()
-    def generate(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        images: Optional[torch.Tensor] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        **kwargs,
-    ) -> Union[GenerateOutput, torch.LongTensor]:
-        position_ids = kwargs.pop("position_ids", None)
-        attention_mask = kwargs.pop("attention_mask", None)
-        if "inputs_embeds" in kwargs:
-            raise NotImplementedError("`inputs_embeds` is not supported")
-
-        inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
-            inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
-        )
-
-        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
-
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
-        images = kwargs.pop("images", None)
-        image_sizes = kwargs.pop("image_sizes", None)
-        inputs = super().prepare_inputs_for_generation(
-            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
-        )
-        if images is not None:
-            inputs["images"] = images
-        if image_sizes is not None:
-            inputs["image_sizes"] = image_sizes
-        return inputs
mineru/model/vlm_sglang_model/__init__.py (deleted)
@@ -1,14 +0,0 @@
-from sglang.srt.configs.model_config import multimodal_model_archs
-from sglang.srt.models.registry import ModelRegistry
-
-from sglang.srt.managers.multimodal_processor import (
-    PROCESSOR_MAPPING as PROCESSOR_MAPPING,
-)
-
-from .. import vlm_hf_model as _
-from .image_processor import Mineru2ImageProcessor
-from .model import Mineru2QwenForCausalLM
-
-ModelRegistry.models[Mineru2QwenForCausalLM.__name__] = Mineru2QwenForCausalLM
-PROCESSOR_MAPPING[Mineru2QwenForCausalLM] = Mineru2ImageProcessor
-multimodal_model_archs.append(Mineru2QwenForCausalLM.__name__)