mineru 2.2.2__py3-none-any.whl → 2.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. mineru/backend/pipeline/pipeline_middle_json_mkcontent.py +3 -3
  2. mineru/backend/vlm/model_output_to_middle_json.py +123 -0
  3. mineru/backend/vlm/vlm_analyze.py +97 -16
  4. mineru/backend/vlm/vlm_magic_model.py +201 -135
  5. mineru/backend/vlm/vlm_middle_json_mkcontent.py +52 -11
  6. mineru/cli/client.py +6 -5
  7. mineru/cli/common.py +17 -16
  8. mineru/cli/fast_api.py +9 -7
  9. mineru/cli/gradio_app.py +15 -16
  10. mineru/cli/vlm_vllm_server.py +4 -0
  11. mineru/model/table/rec/unet_table/main.py +8 -0
  12. mineru/model/vlm_vllm_model/__init__.py +0 -0
  13. mineru/model/vlm_vllm_model/server.py +51 -0
  14. mineru/resources/header.html +10 -2
  15. mineru/utils/draw_bbox.py +32 -10
  16. mineru/utils/enum_class.py +16 -2
  17. mineru/utils/guess_suffix_or_lang.py +20 -0
  18. mineru/utils/span_block_fix.py +4 -2
  19. mineru/version.py +1 -1
  20. {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/METADATA +70 -25
  21. {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/RECORD +25 -38
  22. {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/entry_points.txt +1 -1
  23. mineru/backend/vlm/base_predictor.py +0 -186
  24. mineru/backend/vlm/hf_predictor.py +0 -217
  25. mineru/backend/vlm/predictor.py +0 -111
  26. mineru/backend/vlm/sglang_client_predictor.py +0 -443
  27. mineru/backend/vlm/sglang_engine_predictor.py +0 -246
  28. mineru/backend/vlm/token_to_middle_json.py +0 -122
  29. mineru/backend/vlm/utils.py +0 -40
  30. mineru/cli/vlm_sglang_server.py +0 -4
  31. mineru/model/vlm_hf_model/__init__.py +0 -9
  32. mineru/model/vlm_hf_model/configuration_mineru2.py +0 -38
  33. mineru/model/vlm_hf_model/image_processing_mineru2.py +0 -269
  34. mineru/model/vlm_hf_model/modeling_mineru2.py +0 -449
  35. mineru/model/vlm_sglang_model/__init__.py +0 -14
  36. mineru/model/vlm_sglang_model/engine.py +0 -264
  37. mineru/model/vlm_sglang_model/image_processor.py +0 -213
  38. mineru/model/vlm_sglang_model/logit_processor.py +0 -90
  39. mineru/model/vlm_sglang_model/model.py +0 -453
  40. mineru/model/vlm_sglang_model/server.py +0 -75
  41. {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/WHEEL +0 -0
  42. {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/licenses/LICENSE.md +0 -0
  43. {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/top_level.txt +0 -0
mineru/model/vlm_sglang_model/model.py (removed)
@@ -1,453 +0,0 @@
- import math
- import re
- from typing import Iterable, List, Optional, Tuple
-
- import numpy as np
- import torch
- from sglang.srt.layers.quantization.base_config import QuantizationConfig
-
- from sglang.version import __version__ as sglang_version
- from packaging import version
- if version.parse(sglang_version) >= version.parse("0.4.9"):
-     # sglang >= 0.4.9
-     from sglang.srt.multimodal.mm_utils import (
-         get_anyres_image_grid_shape,
-     )
- else:
-     # 0.4.7 <= sglang < 0.4.9
-     from sglang.srt.mm_utils import (
-         get_anyres_image_grid_shape,
-     )
-
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.models.qwen2 import Qwen2ForCausalLM
- from sglang.srt.utils import add_prefix
- from torch import nn
- from transformers import (
-     CLIPVisionConfig,
-     CLIPVisionModel,
-     SiglipVisionConfig,
-     SiglipVisionModel,
- )
-
- from ..vlm_hf_model.configuration_mineru2 import Mineru2QwenConfig
- from ..vlm_hf_model.modeling_mineru2 import build_vision_projector
- from ...utils.models_download_utils import auto_download_and_get_model_root_path
-
-
- def flatten_nested_list(nested_list):
-     if isinstance(nested_list, list):
-         return [item for sublist in nested_list for item in flatten_nested_list(sublist)]
-     else:
-         return [nested_list]
-
-
- def downgrade_modality(modality):
-     modality_str = str(modality)
-     if "MULTI_IMAGES" in modality_str:
-         return "multi-images"
-     if "IMAGE" in modality_str:
-         return "image"
-     if "VIDEO" in modality_str:
-         return "video"
-     if "AUDIO" in modality_str:
-         return "audio"
-     raise ValueError(f"Unexpected modality: {modality_str}")
-
-
- class Mineru2QwenForCausalLM(nn.Module):
-     def __init__(
-         self,
-         config: Mineru2QwenConfig,
-         quant_config: Optional[QuantizationConfig] = None,
-         prefix: str = "",
-     ) -> None:
-         super().__init__()
-         self.config = config
-
-         if getattr(self.config, "projector_hidden_act", None) is None:
-             self.config.projector_hidden_act = "gelu"
-         if getattr(self.config, "image_token_index", None) is None:
-             self.config.image_token_index = 151646
-
-         # load vision tower
-         mm_vision_tower = self.config.mm_vision_tower
-         model_root_path = auto_download_and_get_model_root_path(mm_vision_tower, "vlm")
-         mm_vision_tower = f"{model_root_path}/{mm_vision_tower}"
-
-         if "clip" in mm_vision_tower:
-             vision_config = CLIPVisionConfig.from_pretrained(mm_vision_tower)
-             self.vision_tower = CLIPVisionModel(vision_config) # type: ignore
-         elif "siglip" in mm_vision_tower:
-             vision_config = SiglipVisionConfig.from_pretrained(mm_vision_tower)
-             self.vision_tower = SiglipVisionModel(vision_config) # type: ignore
-             # Siglip needs all feature tokens
-             self.config.mm_vision_select_feature = "full"
-         else:
-             raise ValueError(f"Unexpected mm_vision_tower: {mm_vision_tower}")
-
-         ### EDIT: change projector
-         # the name `projector` contains `proj` which is often used in attention layers, which can cause bugs in quantization.
-         self.multi_modal_mlp = build_vision_projector(config)
-
-         self.language_model = Qwen2ForCausalLM(
-             config,
-             quant_config=quant_config,
-             prefix=add_prefix("language_model", prefix),
-         )
-
-         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-             self.language_model.model.image_newline = nn.Parameter(torch.empty(config.hidden_size))
-
-         language_model_device = next(self.language_model.parameters()).device
-         self.vision_tower = self.vision_tower.to(language_model_device)
-         self.vision_tower.eval()
-
-         self.vision_feature_layer = self.config.mm_vision_select_layer
-         self.vision_feature_select_strategy = self.config.mm_vision_select_feature
-         self.image_size = self.vision_tower.config.image_size
-         self.patch_size = self.vision_tower.config.patch_size
-
-         self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
-         self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
-         self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
-
-         self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
-         if self.vision_feature_select_strategy in ("patch", "full"):
-             pass
-         elif self.vision_feature_select_strategy == "cls_patch":
-             self.image_feature_len += 1
-         else:
-             raise ValueError(f"Unexpected select feature: {self.select_feature}")
-
-     def pad_input_ids(self, input_ids: List[int], image_inputs):
-
-         image_sizes = flatten_nested_list([item.image_sizes for item in image_inputs.mm_items])
-         pad_values = [item.pad_value for item in image_inputs.mm_items]
-
-         # hardcode for spatial_unpad + anyres
-         # if image_inputs.modalities is not None and (
-         #     "multi-images" in image_inputs.modalities or "video" in image_inputs.modalities
-         # ):
-         #     image_aspect_ratio = "pad"
-         # else:
-         #     image_aspect_ratio = "anyres"
-
-         offset_list = []
-         image_inputs.image_pad_len = []
-         for image_idx, image_s in enumerate(image_sizes):
-             if len(image_sizes) > 16:
-                 # 2x2 pooling with stride 2
-                 new_image_feature_len = math.ceil(self.image_size / self.patch_size / 2) ** 2
-             else:
-                 new_image_feature_len = self.image_feature_len # multiimage
-
-             height = width = self.num_patches_per_side
-             if "anyres" in self.config.image_aspect_ratio:
-                 num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                     image_s,
-                     self.image_grid_pinpoints,
-                     self.vision_tower.config.image_size,
-                 )
-                 h = num_patch_height * height
-                 w = num_patch_width * width
-
-                 ### EDIT: remove `unpad_image_shape`
-                 # new_h, new_w = unpad_image_shape(h, w, image_s)
-                 new_h, new_w = h, w
-
-                 if "anyres_max" in self.config.image_aspect_ratio:
-                     matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", self.config.image_aspect_ratio)
-                     if matched_anyres_max_num_patches:
-                         max_num_patches = int(matched_anyres_max_num_patches.group(1))
-                         times = math.sqrt(new_h * new_w / (max_num_patches * self.image_feature_len))
-                         if times > 1.1:
-                             new_h = int(new_h // times)
-                             new_w = int(new_w // times)
-                 new_image_feature_len += new_h * (new_w + 1)
-
-             try:
-                 offset = input_ids.index(self.config.image_token_index)
-             except ValueError:
-                 offset = 0
-             # old_len + pad_len - 1, because we need to remove image_token_id
-             input_ids = input_ids[:offset] + [pad_values[image_idx]] * new_image_feature_len + input_ids[offset + 1 :]
-             offset_list.append(offset)
-             image_inputs.image_pad_len.append(new_image_feature_len)
-
-         image_inputs.image_offsets = offset_list
-         return input_ids
-
-     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
-         pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype)
-         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
-         # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
-
-         selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
-         if self.vision_feature_select_strategy in ["default", "patch"]:
-             selected_image_feature = selected_image_feature[:, 1:]
-         elif self.vision_feature_select_strategy == "full":
-             selected_image_feature = selected_image_feature
-         else:
-             raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
-
-         image_features = self.multi_modal_mlp(selected_image_feature)
-         return image_features
-
-     @torch.no_grad()
-     def forward(
-         self,
-         input_ids: torch.LongTensor,
-         positions: torch.Tensor,
-         forward_batch: ForwardBatch,
-     ) -> torch.Tensor:
-
-         image_inputs = forward_batch.mm_inputs
-
-         if image_inputs is None:
-             image_inputs = []
-
-         if forward_batch.forward_mode.is_extend():
-             # Clamp input ids. This is because the input_ids for the image tokens are
-             # filled with the hash values of the image for the prefix matching in the radix attention.
-             # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-             input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-
-             # Embed text inputs
-             input_embeds = self.language_model.model.embed_tokens(input_ids)
-
-             # Got List[List[str]] extend it to List[str]
-             # The length of the List should be equal to batch size
-             modalities_list = []
-             max_image_offset = []
-             for im in image_inputs:
-                 if im:
-                     modalities_list.extend([downgrade_modality(item.modality) for item in im.mm_items])
-                 if im and im.image_offsets:
-                     max_image_offset.append(np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)))
-                 else:
-                     max_image_offset.append(-1)
-
-             start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
-             need_vision = start_positions <= np.array(max_image_offset)
-
-             if need_vision.any():
-                 bs = forward_batch.batch_size
-
-                 if version.parse(sglang_version) >= version.parse("0.4.9.post3"):
-                     # sglang >= 0.4.9.post3
-                     pixel_values = flatten_nested_list(
-                         [[item.feature for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
-                     ) # image_inputs[batch_idx].mm_items[item_idx].pixel_values is Tensor
-                     image_sizes = [
-                         flatten_nested_list([item.model_specific_data["image_sizes"] for item in image_inputs[i].mm_items])
-                         for i in range(bs)
-                         if need_vision[i]
-                     ] # image_inputs[batch_idx].mm_items[item_idx].image_sizes should be tuple, but is list of tuple for now.
-                 else:
-                     # 0.4.7 <= sglang <= 0.4.9.post2
-                     pixel_values = flatten_nested_list(
-                         [[item.pixel_values for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
-                     ) # image_inputs[batch_idx].mm_items[item_idx].pixel_values is Tensor
-                     image_sizes = [
-                         flatten_nested_list([item.image_sizes for item in image_inputs[i].mm_items])
-                         for i in range(bs)
-                         if need_vision[i]
-                     ] # image_inputs[batch_idx].mm_items[item_idx].image_sizes should be tuple, but is list of tuple for now.
-
-                 ########## Encode Image ########
-
-                 if pixel_values[0].ndim == 4:
-                     # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
-                     np.concatenate(pixel_values, axis=0)
-                     # ndim=4
-                     concat_images = torch.tensor(
-                         np.concatenate(pixel_values, axis=0),
-                         device=self.vision_tower.device,
-                     )
-                     image_features = self.encode_images(concat_images)
-                     split_sizes = [image.shape[0] for image in pixel_values]
-                     image_features = torch.split(image_features, split_sizes, dim=0)
-                     # hd image_features: BS, num_patch, 576, 4096
-                 else:
-                     # normal pixel: BS, C=3, H=336, W=336
-                     pixel_values = torch.tensor(np.array(pixel_values), device=self.vision_tower.device)
-                     image_features = self.encode_images(pixel_values)
-                     # image_features: BS, 576, 4096
-
-                 if self.mm_patch_merge_type.startswith("spatial"):
-                     new_image_features = []
-                     height = width = self.num_patches_per_side
-                     for image_idx, image_feature in enumerate(image_features):
-                         if modalities_list[image_idx] == "image":
-                             image_aspect_ratio = self.config.image_aspect_ratio # single image
-                         elif modalities_list[image_idx] == "multi-images" or modalities_list[image_idx] == "video":
-                             image_aspect_ratio = "pad" # multi image
-                             # image_aspect_ratio = (
-                             #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
-                             # )
-                         if (
-                             image_feature.shape[0] > 1
-                             and "anyres" in image_aspect_ratio
-                             and modalities_list[image_idx] == "image"
-                         ):
-                             base_image_feature = image_feature[0]
-                             image_feature = image_feature[1:]
-                             assert height * width == base_image_feature.shape[0]
-
-                             if "anyres_max" in image_aspect_ratio:
-                                 matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", image_aspect_ratio)
-                                 if matched_anyres_max_num_patches:
-                                     max_num_patches = int(matched_anyres_max_num_patches.group(1))
-
-                             if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
-                                 vision_tower_image_size = self.image_size
-                                 try:
-                                     num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                                         image_sizes[image_idx][0],
-                                         self.config.image_grid_pinpoints,
-                                         vision_tower_image_size,
-                                     )
-                                 except Exception as e:
-                                     print(f"Error: {e}")
-                                     num_patch_width, num_patch_height = 2, 2
-                                 image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
-                             else:
-                                 image_feature = image_feature.view(2, 2, height, width, -1)
-
-                             if "unpad" in self.mm_patch_merge_type:
-                                 unit = image_feature.shape[2]
-                                 image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                                 image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-
-                                 ### EDIT: remove `unpad_image`
-                                 # image_feature = unpad_image(image_feature, image_sizes[image_idx][0])
-
-                                 if "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
-                                     c, h, w = image_feature.shape
-                                     times = math.sqrt(h * w / (max_num_patches * unit**2))
-                                     if times > 1.1:
-                                         image_feature = image_feature[None]
-                                         image_feature = nn.functional.interpolate(
-                                             image_feature,
-                                             [int(h // times), int(w // times)],
-                                             mode="bilinear",
-                                         )[0]
-                                 image_feature = torch.cat(
-                                     (
-                                         image_feature,
-                                         self.language_model.model.image_newline[:, None, None].expand(
-                                             *image_feature.shape[:-1], 1
-                                         ),
-                                     ),
-                                     dim=-1,
-                                 )
-                                 image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                             else:
-                                 image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
-                                 image_feature = image_feature.flatten(0, 3)
-                             image_feature = torch.cat((base_image_feature, image_feature), dim=0)
-                             image_feature = image_feature.unsqueeze(0)
-                         else:
-                             if modalities_list[image_idx] == "video": # video
-                                 # 2x2 pooling
-                                 num_of_frames = image_feature.shape[0]
-                                 image_feature = image_feature.view(num_of_frames, height, width, -1)
-                                 image_feature = image_feature.permute(0, 3, 1, 2).contiguous() # N, C, H, W
-                                 height, weight = image_feature.shape[2:]
-                                 scaled_shape = [
-                                     math.ceil(height / 2),
-                                     math.ceil(weight / 2),
-                                 ]
-                                 image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode="bilinear")
-                                 image_feature = image_feature.flatten(2).transpose(1, 2).contiguous() # N, C, H*W
-                             if "unpad" in self.mm_patch_merge_type:
-                                 image_feature = torch.cat(
-                                     (
-                                         image_feature,
-                                         # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
-                                         self.language_model.model.image_newline[None, None].expand(
-                                             image_feature.shape[0],
-                                             1,
-                                             image_feature.shape[-1],
-                                         ),
-                                     ),
-                                     dim=1,
-                                 )
-
-                         new_image_features.append(image_feature)
-                     image_features = new_image_features
-
-                 # Fill in the placeholder for the image
-                 extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-                 extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
-                 prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
-                 pt = 0
-                 for i in range(bs):
-                     if not need_vision[i]:
-                         continue
-
-                     start_idx = extend_start_loc_cpu[i]
-                     seq_len = extend_seq_lens[i]
-                     prefix_len = prefix_lens_cpu[i]
-
-                     # Multiple images
-                     for image_idx, image_offset in enumerate(image_inputs[i].image_offsets):
-                         if image_offset + image_inputs[i].image_pad_len[image_idx] <= prefix_len:
-                             continue
-                         if image_offset >= prefix_len + seq_len:
-                             break
-
-                         tmp_image_feature = image_features[pt][image_idx]
-                         pad_len = tmp_image_feature.shape[0]
-
-                         input_offset = image_offset - prefix_len
-                         left_idx = start_idx + input_offset
-                         right_idx = left_idx + pad_len
-                         assert right_idx > start_idx
-                         if input_offset < 0:
-                             left_idx = start_idx
-                             tmp_image_feature = tmp_image_feature[-input_offset:]
-                         if right_idx > start_idx + seq_len:
-                             tmp_image_feature = tmp_image_feature[: start_idx + seq_len - right_idx]
-                             right_idx = start_idx + seq_len
-                         try:
-                             input_embeds[left_idx:right_idx] = tmp_image_feature
-                         except RuntimeError as e:
-                             print(f"RuntimeError in image encoding: {e}")
-                             print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
-                             print(f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}")
-                     pt += 1
-
-             return self.language_model(input_ids, positions, forward_batch, input_embeds=input_embeds)
-         elif forward_batch.forward_mode.is_decode():
-             return self.language_model(input_ids, positions, forward_batch)
-         else:
-             raise ValueError(f"Unexpected forward mode: {forward_batch.forward_mode}")
-
-     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-         projector_weights = {
-             "model.mm_projector": "multi_modal_mlp",
-             "model.vision_tower.vision_tower": "vision_tower",
-             # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
-             "model.image_newline": "language_model.model.image_newline",
-         }
-         params_dict = dict(self.named_parameters())
-         for name, loaded_weight in weights:
-             if "projector" in name or "vision_tower" in name or "image_newline" in name:
-                 for weight_name, param_name in projector_weights.items():
-                     if weight_name in name:
-                         name = name.replace(weight_name, param_name)
-                 param = params_dict[name]
-                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                 weight_loader(param, loaded_weight)
-             else:
-                 self.language_model.load_weights([(name, loaded_weight)])
-
-     @property
-     def num_patches_per_side(self):
-         return self.image_size // self.patch_size
-
-
- EntryClass = [Mineru2QwenForCausalLM]
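
Note: the removed model.py selects its sglang helper imports at runtime with packaging version checks (the version.parse branches near the top and inside forward). A minimal sketch of that version-gating pattern is below; the version strings passed in are illustrative, and the two module paths simply mirror the ones imported above.

# Sketch of the runtime version-gating pattern used by the removed module.
# The input version strings are examples; the returned paths mirror the imports above.
from packaging import version

def pick_mm_utils_module(sglang_version: str) -> str:
    """Choose the import path for get_anyres_image_grid_shape for a given sglang version."""
    if version.parse(sglang_version) >= version.parse("0.4.9"):
        return "sglang.srt.multimodal.mm_utils"  # sglang >= 0.4.9
    return "sglang.srt.mm_utils"  # 0.4.7 <= sglang < 0.4.9

print(pick_mm_utils_module("0.4.9.post3"))  # -> sglang.srt.multimodal.mm_utils
print(pick_mm_utils_module("0.4.8"))        # -> sglang.srt.mm_utils

Post-releases such as 0.4.9.post3 compare as newer than 0.4.9, which is why the module also uses a separate ">= 0.4.9.post3" check when deciding how to read pixel values from mm_items.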
mineru/model/vlm_sglang_model/server.py (removed)
@@ -1,75 +0,0 @@
- import os
- import sys
-
- from fastapi import Request
- from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
- from sglang.srt.managers.io_struct import GenerateReqInput
- from sglang.srt.server_args import prepare_server_args
- from sglang.srt.utils import kill_process_tree
- from sglang.srt.conversation import Conversation
-
- from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
- from .logit_processor import Mineru2LogitProcessor
-
- # mineru2.0's chat_template differs slightly from chatml in how newlines are handled
- def custom_get_prompt(self) -> str:
-     system_prompt = self.system_template.format(system_message=self.system_message)
-     if self.system_message == "":
-         ret = ""
-     else:
-         ret = system_prompt + self.sep
-
-     for role, message in self.messages:
-         if message:
-             ret += role + "\n" + message + self.sep
-         else:
-             ret += role + "\n"
-     return ret
-
- _custom_logit_processor_str = Mineru2LogitProcessor().to_str()
-
- # remove the existing /generate route
- for route in app.routes[:]:
-     if hasattr(route, "path") and getattr(route, "path") == "/generate":
-         app.routes.remove(route)
-
-
- # add the custom /generate route
- @app.api_route("/generate", methods=["POST", "PUT"])
- async def custom_generate_request(obj: GenerateReqInput, request: Request):
-     if obj.custom_logit_processor is None:
-         obj.custom_logit_processor = _custom_logit_processor_str
-     return await generate_request(obj, request)
-
-
- def main():
-     # check whether the command-line arguments already contain --model-path
-     args = sys.argv[1:]
-     has_model_path_arg = False
-
-     for i, arg in enumerate(args):
-         if arg == "--model-path" or arg.startswith("--model-path="):
-             has_model_path_arg = True
-             break
-
-     # if --model-path was not given, append it to the argument list
-     if not has_model_path_arg:
-         default_path = auto_download_and_get_model_root_path("/", "vlm")
-         args.extend(["--model-path", default_path])
-
-     server_args = prepare_server_args(args)
-
-     if server_args.chat_template is None:
-         server_args.chat_template = "chatml"
-         Conversation.get_prompt = custom_get_prompt
-
-     server_args.enable_custom_logit_processor = True
-
-     try:
-         launch_server(server_args)
-     finally:
-         kill_process_tree(os.getpid(), include_parent=False)
-
-
- if __name__ == "__main__":
-     main()
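
Note: the removed server.py injects a default custom logit processor by swapping out the stock /generate route on sglang's FastAPI app at import time. Below is a minimal, self-contained sketch of that route-replacement pattern using a toy FastAPI app; the request model, handler names, and DEFAULT_PROCESSOR value are placeholders, not mineru's or sglang's actual objects.

# Sketch: replace an existing FastAPI route with a wrapper that fills in a default value
# before delegating to the original handler. All names here are illustrative.
from typing import Optional

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class GenerateRequest(BaseModel):
    prompt: str
    custom_logit_processor: Optional[str] = None

@app.post("/generate")
async def generate(req: GenerateRequest):
    # Stand-in for the upstream handler.
    return {"prompt": req.prompt, "processor": req.custom_logit_processor}

DEFAULT_PROCESSOR = "serialized-logit-processor"  # placeholder payload

# Drop the stock /generate route from the app's route table...
for route in app.routes[:]:
    if getattr(route, "path", None) == "/generate":
        app.routes.remove(route)

# ...and register a replacement that supplies the default before delegating.
@app.post("/generate")
async def generate_with_default(req: GenerateRequest):
    if req.custom_logit_processor is None:
        req.custom_logit_processor = DEFAULT_PROCESSOR
    return await generate(req)

Mutating app.routes like this works because FastAPI keeps its routes in a plain list; the removed server.py relied on the same trick so that every request hits the wrapper even though the upstream module already registered /generate.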