sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/minicpmv.py (new file)
@@ -0,0 +1,1238 @@
1
+ # Adapted from
2
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
3
+ # Copyright 2023 The vLLM team.
4
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
7
+ # and OPT implementations in this library. It has been modified from its
8
+ # original forms to accommodate minor architectural differences compared
9
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
23
+ from functools import cached_property, partial
24
+ from typing import (
25
+ Any,
26
+ Callable,
27
+ Iterable,
28
+ List,
29
+ Literal,
30
+ Optional,
31
+ Tuple,
32
+ TypedDict,
33
+ Union,
34
+ )
35
+
36
+ import torch
37
+ import torch.types
38
+ from PIL import Image
39
+ from torch import nn
40
+ from torch.nn.init import trunc_normal_
41
+ from transformers import PretrainedConfig
42
+ from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
43
+ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
44
+ from vllm.model_executor.models.module_mapping import MultiModelKeys
45
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
46
+
47
+ from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
48
+ from sglang.srt.layers.activation import get_act_fn
49
+ from sglang.srt.layers.attention.vision import VisionAttention
50
+ from sglang.srt.layers.linear import (
51
+ ColumnParallelLinear,
52
+ ReplicatedLinear,
53
+ RowParallelLinear,
54
+ )
55
+ from sglang.srt.layers.logits_processor import LogitsProcessor
56
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
57
+ from sglang.srt.managers.schedule_batch import ImageInputs
58
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
59
+ from sglang.srt.model_loader.utils import set_default_torch_dtype
60
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
61
+ from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
62
+
63
+ RawImageType = Union[Image.Image, torch.Tensor]
64
+
65
+
66
+ class Idefics2VisionMLP(nn.Module):
67
+
68
+ def __init__(
69
+ self,
70
+ config: PretrainedConfig,
71
+ quant_config: Optional[QuantizationConfig] = None,
72
+ prefix: str = "",
73
+ ) -> None:
74
+ super().__init__()
75
+ self.config = config
76
+ self.activation_fn = get_act_fn(config.hidden_act)
77
+ self.fc1 = ColumnParallelLinear(
78
+ config.hidden_size,
79
+ config.intermediate_size,
80
+ bias=True,
81
+ quant_config=quant_config,
82
+ prefix=f"{prefix}.fc1",
83
+ )
84
+ self.fc2 = RowParallelLinear(
85
+ config.intermediate_size,
86
+ config.hidden_size,
87
+ bias=True,
88
+ quant_config=quant_config,
89
+ prefix=f"{prefix}.fc2",
90
+ )
91
+
92
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
93
+ hidden_states, _ = self.fc1(hidden_states)
94
+ hidden_states = self.activation_fn(hidden_states)
95
+ hidden_states, _ = self.fc2(hidden_states)
96
+ return hidden_states
97
+
98
+
99
+ class Idefics2EncoderLayer(nn.Module):
100
+
101
+ def __init__(
102
+ self,
103
+ config: PretrainedConfig,
104
+ quant_config: Optional[QuantizationConfig] = None,
105
+ prefix: str = "",
106
+ ) -> None:
107
+ super().__init__()
108
+ self.embed_dim = config.hidden_size
109
+
110
+ self.num_heads = config.num_attention_heads
111
+ tp_size = get_tensor_model_parallel_world_size()
112
+ num_heads_per_partition = divide(self.num_heads, tp_size)
113
+ self.self_attn = VisionAttention(
114
+ embed_dim=config.hidden_size,
115
+ num_heads=num_heads_per_partition,
116
+ projection_size=config.intermediate_size,
117
+ use_qkv_parallel=True,
118
+ quant_config=quant_config,
119
+ prefix=f"{prefix}.self_attn",
120
+ )
121
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
122
+ self.mlp = Idefics2VisionMLP(config, quant_config=quant_config)
123
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
124
+
125
+ def forward(
126
+ self,
127
+ hidden_states: torch.Tensor,
128
+ cu_seqlens: torch.Tensor,
129
+ forward_batch: ForwardBatch,
130
+ ) -> torch.Tensor:
131
+ """
132
+ Args:
133
+ hidden_states (`torch.FloatTensor`):
134
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
135
+
136
+ """
137
+ residual = hidden_states
138
+ hidden_states = self.layer_norm1(hidden_states)
139
+ hidden_states = self.self_attn(
140
+ hidden_states,
141
+ cu_seqlens=cu_seqlens,
142
+ # , forward_batch=forward_batch
143
+ )
144
+ hidden_states = residual + hidden_states
145
+ residual = hidden_states
146
+ hidden_states = self.layer_norm2(hidden_states)
147
+ hidden_states = self.mlp(hidden_states)
148
+ hidden_states = residual + hidden_states
149
+ return hidden_states
150
+
151
+
152
+ class Idefics2Encoder(nn.Module):
153
+ """
154
+ Transformer encoder consisting of `config.num_hidden_layers` self attention
155
+ layers. Each layer is an
156
+ [`Idefics2EncoderLayer`].
157
+
158
+ Args:
159
+ config: Idefics2Config
160
+ """
161
+
162
+ def __init__(
163
+ self,
164
+ config: PretrainedConfig,
165
+ quant_config: Optional[QuantizationConfig] = None,
166
+ ) -> None:
167
+ super().__init__()
168
+
169
+ self.config = config
170
+ self.layers = nn.ModuleList(
171
+ [
172
+ Idefics2EncoderLayer(
173
+ config,
174
+ quant_config=quant_config,
175
+ )
176
+ for _ in range(config.num_hidden_layers)
177
+ ]
178
+ )
179
+
180
+ def forward(
181
+ self,
182
+ inputs_embeds: torch.Tensor,
183
+ cu_seqlens: torch.Tensor,
184
+ forward_batch: ForwardBatch,
185
+ ) -> torch.Tensor:
186
+ r"""
187
+ Args:
188
+ inputs_embeds (torch.Tensor):
189
+ Optionally, instead of passing `input_ids` you can choose to
190
+ directly pass an embedded representation.
191
+ This is useful if you want more control over how to convert
192
+ `input_ids` indices into associated vectors than the model's
193
+ internal embedding lookup matrix.
194
+ """
195
+ hidden_states = inputs_embeds
196
+ for encoder_layer in self.layers:
197
+ layer_outputs = encoder_layer(
198
+ hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch
199
+ )
200
+ hidden_states = layer_outputs
201
+ return hidden_states
202
+
203
+
204
+ class Idefics2VisionEmbeddings(nn.Module):
205
+ """
206
+ This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings
207
+ ` to enable images of variable
208
+ resolution.
209
+
210
+ The modifications are adapted from [Patch n' Pack: NaViT, a Vision
211
+ Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
212
+ which allows treating images in their native aspect ratio and without the
213
+ need to resize them to the same fixed size. In particular, we start from the
214
+ original pre-trained SigLIP model (which uses fixed-size square
215
+ images) and adapt it by training on images of variable resolutions.
216
+ """
217
+
218
+ def __init__(self, config: PretrainedConfig):
219
+ super().__init__()
220
+ self.embed_dim = config.hidden_size
221
+ self.image_size = config.image_size
222
+ self.patch_size = config.patch_size
223
+ self.patch_embedding = nn.Conv2d(
224
+ in_channels=config.num_channels,
225
+ out_channels=self.embed_dim,
226
+ kernel_size=self.patch_size,
227
+ stride=self.patch_size,
228
+ padding="valid",
229
+ )
230
+ self.num_patches_per_side = self.image_size // self.patch_size
231
+ self.num_patches = self.num_patches_per_side**2
232
+ self.num_positions = self.num_patches
233
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
234
+
235
+ def forward(
236
+ self,
237
+ pixel_values: torch.FloatTensor,
238
+ patch_attention_mask: torch.BoolTensor,
239
+ tgt_sizes: Optional[torch.IntTensor] = None,
240
+ ) -> torch.Tensor:
241
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
242
+ target_dtype = self.patch_embedding.weight.dtype
243
+ pixel_values = pixel_values.to(
244
+ device=self.patch_embedding.weight.device, dtype=target_dtype
245
+ )
246
+ patch_embeds = self.patch_embedding(pixel_values)
247
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
248
+ max_nb_patches_h, max_nb_patches_w = (
249
+ max_im_h // self.patch_size,
250
+ max_im_w // self.patch_size,
251
+ )
252
+ boundaries = torch.arange(
253
+ 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
254
+ )
255
+ position_ids = torch.full(
256
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
257
+ )
258
+
259
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
260
+
261
+ if tgt_sizes is not None:
262
+ nb_patches_h = tgt_sizes[batch_idx][0]
263
+ nb_patches_w = tgt_sizes[batch_idx][1]
264
+ else:
265
+ nb_patches_h = p_attn_mask[:, 0].sum()
266
+ nb_patches_w = p_attn_mask[0].sum()
267
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
268
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
269
+ bucket_coords_h = torch.bucketize(
270
+ fractional_coords_h, boundaries, right=True
271
+ )
272
+ bucket_coords_w = torch.bucketize(
273
+ fractional_coords_w, boundaries, right=True
274
+ )
275
+ pos_ids = (
276
+ bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
277
+ ).flatten()
278
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
279
+ position_ids = position_ids.to(self.position_embedding.weight.device)
280
+ embeddings = embeddings + self.position_embedding(position_ids)
281
+ return embeddings
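For illustration only (not part of the packaged file): a rough worked example of the bucketing above, assuming a SigLIP grid of 4 patches per side and a single image that actually occupies a 2x3 patch region. All numbers are hypothetical.

    import torch

    num_patches_per_side = 4
    boundaries = torch.arange(1 / 4, 1.0, 1 / 4)               # tensor([0.25, 0.50, 0.75])

    nb_patches_h, nb_patches_w = 2, 3                           # actual patch grid of this image
    frac_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)        # tensor([0.0, 0.5])
    frac_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)        # roughly tensor([0.000, 0.333, 0.667])

    bucket_h = torch.bucketize(frac_h, boundaries, right=True)  # tensor([0, 2])
    bucket_w = torch.bucketize(frac_w, boundaries, right=True)  # tensor([0, 1, 2])

    # Each real patch is mapped onto the 4x4 position-embedding grid (row-major).
    pos_ids = (bucket_h[:, None] * num_patches_per_side + bucket_w).flatten()
    print(pos_ids)                                              # tensor([ 0,  1,  2,  8,  9, 10])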
282
+
283
+
284
+ class Idefics2VisionTransformer(nn.Module):
285
+
286
+ def __init__(
287
+ self,
288
+ config: PretrainedConfig,
289
+ quant_config: Optional[QuantizationConfig] = None,
290
+ prefix: str = "",
291
+ ) -> None:
292
+ super().__init__()
293
+
294
+ embed_dim = config.hidden_size
295
+ self.config = config
296
+ self.embeddings = Idefics2VisionEmbeddings(config)
297
+ self.encoder = Idefics2Encoder(config=config, quant_config=quant_config)
298
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
299
+
300
+ def get_input_embeddings(self):
301
+ return self.embeddings
302
+
303
+ def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
304
+ patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,)
305
+
306
+ # Compute a prefix sum to get cu_seqlens; note that a leading 0 is inserted as the offset
307
+ cu_seqlens = torch.cat(
308
+ [
309
+ torch.tensor([0], device=patch_len.device, dtype=torch.int32),
310
+ torch.cumsum(patch_len, dim=0, dtype=torch.int32),
311
+ ],
312
+ dim=0,
313
+ ).to(tgt_sizes.device)
314
+ return cu_seqlens
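A minimal sketch, not part of the diff, of what compute_cu_seqlens produces for two hypothetical image slices (tgt_sizes holds per-slice (height, width) patch counts):

    import torch

    tgt_sizes = torch.tensor([[2, 3], [4, 5]])              # hypothetical patch grids
    patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]            # tensor([ 6, 20])
    cu_seqlens = torch.cat(
        [torch.tensor([0], dtype=torch.int32),
         torch.cumsum(patch_len, dim=0, dtype=torch.int32)]
    )
    print(cu_seqlens)                                        # tensor([0, 6, 26], dtype=torch.int32)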
315
+
316
+ def forward(
317
+ self,
318
+ pixel_values,
319
+ forward_batch: ForwardBatch,
320
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
321
+ tgt_sizes: Optional[torch.IntTensor] = None,
322
+ ) -> torch.Tensor:
323
+ hidden_states = self.embeddings(
324
+ pixel_values=pixel_values,
325
+ patch_attention_mask=patch_attention_mask,
326
+ # forward_batch=forward_batch,
327
+ tgt_sizes=tgt_sizes,
328
+ )
329
+ cu_seqlens = self.compute_cu_seqlens(tgt_sizes)
330
+ encoder_outputs = self.encoder(
331
+ hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch
332
+ )
333
+ last_hidden_state = self.post_layernorm(encoder_outputs)
334
+ return last_hidden_state
335
+
336
+
337
+ class MiniCPMVImagePixelInputs(TypedDict):
338
+ type: Literal["pixel_values"]
339
+ data: List[torch.Tensor]
340
+ """
341
+ Shape: `(batch_size * num_images, num_channels, height, width)`
342
+
343
+ Note that the image size may vary, so we pass it as a list
344
+ instead of a batched tensor.
345
+ """
346
+
347
+ image_bounds: torch.Tensor
348
+ """
349
+ Shape: `(batch_size * num_images, 2)`
350
+
351
+ This should be in `(start, stop)` format.
352
+ """
353
+
354
+ tgt_sizes: torch.Tensor
355
+ """
356
+ Shape: `(batch_size * num_images, 2)`
357
+
358
+ This should be in `(height, width)` format.
359
+ """
360
+
361
+
362
+ class MiniCPMVImageEmbeddingInputs(TypedDict):
363
+ type: Literal["image_embeds"]
364
+ data: torch.Tensor
365
+ """
366
+ Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
367
+
368
+ `hidden_size` must match the hidden size of language model backbone.
369
+ instead of a batched tensor.
370
+ """
371
+
372
+ image_bounds: torch.Tensor
373
+ """
374
+ Shape: `(batch_size * num_images, 2)`
375
+
376
+ This should be in `(start, stop)` format.
377
+ """
378
+
379
+
380
+ MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs]
381
+
382
+ DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
383
+
384
+
385
+ class BaseResampler(nn.Module):
386
+ """
387
+ A 2D perceiver-resampler network with a single cross-attention layer over
388
+ (grid_size**2) learnable queries and 2D sincos pos_emb.
389
+ Outputs:
390
+ A tensor with the shape of (grid_size**2, embed_dim)
391
+ """
392
+
393
+ def __init__(
394
+ self,
395
+ num_queries: int,
396
+ embed_dim: int,
397
+ num_heads: int,
398
+ kv_dim: Optional[int] = None,
399
+ norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
400
+ do_post_projection: bool = True,
401
+ quant_config: Optional[QuantizationConfig] = None,
402
+ prefix: str = "",
403
+ ) -> None:
404
+ super().__init__()
405
+
406
+ self.num_queries = num_queries
407
+ self.embed_dim = embed_dim
408
+ self.num_heads = num_heads
409
+
410
+ self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
411
+ trunc_normal_(self.query, std=0.02)
412
+ if kv_dim is not None and kv_dim != embed_dim:
413
+ self.kv_proj = ReplicatedLinear(
414
+ kv_dim,
415
+ embed_dim,
416
+ bias=False,
417
+ quant_config=quant_config,
418
+ prefix=f"{prefix}.kv_proj",
419
+ )
420
+ else:
421
+ # Maintain the same return value with ReplicatedLinear.forward
422
+ self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
423
+ nn.Identity()(*args, **kwargs),
424
+ None,
425
+ )
426
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
427
+ self.ln_q = norm_layer(embed_dim)
428
+ self.ln_kv = norm_layer(embed_dim)
429
+ self.do_post_projection = do_post_projection
430
+ self.ln_post = norm_layer(embed_dim) if do_post_projection else None
431
+ self.proj = (
432
+ nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
433
+ if do_post_projection
434
+ else None
435
+ )
436
+
437
+ def _init_weights(self, m: nn.Module) -> None:
438
+ if isinstance(m, nn.Linear):
439
+ trunc_normal_(m.weight, std=0.02)
440
+ if isinstance(m, nn.Linear) and m.bias is not None:
441
+ nn.init.constant_(m.bias, 0)
442
+ elif isinstance(m, nn.LayerNorm):
443
+ nn.init.constant_(m.bias, 0)
444
+ nn.init.constant_(m.weight, 1.0)
445
+
446
+ def _repeat(self, query, N: int):
447
+ return query.unsqueeze(1).repeat(1, N, 1)
448
+
449
+
450
+ class Resampler2_5(BaseResampler):
451
+
452
+ def __init__(
453
+ self,
454
+ num_queries: int,
455
+ embed_dim: int,
456
+ num_heads: int,
457
+ kv_dim: Optional[int] = None,
458
+ norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
459
+ max_size: Tuple[int, int] = (70, 70),
460
+ quant_config: Optional[QuantizationConfig] = None,
461
+ prefix: str = "",
462
+ ) -> None:
463
+ super().__init__(
464
+ num_queries,
465
+ embed_dim,
466
+ num_heads,
467
+ kv_dim,
468
+ norm_layer,
469
+ quant_config=quant_config,
470
+ prefix=prefix,
471
+ )
472
+
473
+ self.max_size = max_size
474
+ self._set_2d_pos_cache(self.max_size)
475
+
476
+ self.apply(self._init_weights)
477
+
478
+ def _set_2d_pos_cache(
479
+ self, max_size: Tuple[int, int], device: torch.types.Device = "cpu"
480
+ ) -> None:
481
+ pos_embed_arr = get_2d_sincos_pos_embed(
482
+ self.embed_dim, max_size, version=(2, 5)
483
+ )
484
+ pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
485
+ self.register_buffer("pos_embed", pos_embed, persistent=False)
486
+
487
+ def _adjust_pos_cache(
488
+ self, tgt_sizes: torch.Tensor, device: torch.types.Device
489
+ ) -> None:
490
+ max_h = tgt_sizes[:, 0].max().item()
491
+ max_w = tgt_sizes[:, 1].max().item()
492
+ assert isinstance(max_h, int) and isinstance(max_w, int)
493
+
494
+ if max_h > self.max_size[0] or max_w > self.max_size[1]:
495
+ self.max_size = (
496
+ max(max_h, self.max_size[0]),
497
+ max(max_w, self.max_size[1]),
498
+ )
499
+ self._set_2d_pos_cache(self.max_size, device)
500
+
501
+ def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor:
502
+ assert x.shape[0] == tgt_sizes.shape[0]
503
+ bs = x.shape[0]
504
+
505
+ device = x.device
506
+ dtype = x.dtype
507
+
508
+ patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
509
+
510
+ self._adjust_pos_cache(tgt_sizes, device=device)
511
+
512
+ max_patch_len = patch_len.max().item()
513
+ assert isinstance(max_patch_len, int)
514
+
515
+ key_padding_mask = torch.zeros(
516
+ (bs, max_patch_len), dtype=torch.bool, device=device
517
+ )
518
+
519
+ pos_embed = []
520
+ for i in range(bs):
521
+ tgt_h, tgt_w = tgt_sizes[i].tolist()
522
+ pos_embed.append(
523
+ self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)
524
+ ) # patches * D
525
+ key_padding_mask[i, patch_len[i] :] = True
526
+ pos_embed = torch.nn.utils.rnn.pad_sequence(
527
+ pos_embed, batch_first=True, padding_value=0.0
528
+ ).permute(
529
+ 1, 0, 2
530
+ ) # BLD => L * B * D
531
+ x, _ = self.kv_proj(x) # B * L * D
532
+ x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
533
+
534
+ q = self.ln_q(self.query) # Q * D
535
+
536
+ out = self.attn(
537
+ self._repeat(q, bs), # Q * B * D
538
+ x + pos_embed, # L * B * D + L * B * D
539
+ x,
540
+ key_padding_mask=key_padding_mask,
541
+ )[0]
542
+ # out: Q * B * D
543
+ x = out.permute(1, 0, 2) # B * Q * D
544
+
545
+ x = self.ln_post(x)
546
+ x = x @ self.proj
547
+ return x
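A minimal shape sketch of the cross-attention call pattern used by the resampler above, using toy sizes and standard nn.MultiheadAttention semantics (all values hypothetical):

    import torch
    from torch import nn

    Q, B, L, D, H = 4, 2, 6, 8, 2          # queries, batch, key length, dim, heads (toy values)
    attn = nn.MultiheadAttention(D, H)      # default batch_first=False, i.e. (seq, batch, dim)

    q = torch.randn(Q, B, D)                # repeated learnable queries
    kv = torch.randn(L, B, D)               # projected + normalized vision features
    pos = torch.randn(L, B, D)              # 2D sincos position embeddings, padded per sample
    key_padding_mask = torch.zeros(B, L, dtype=torch.bool)
    key_padding_mask[0, 5:] = True          # mask padded patches of sample 0

    out, _ = attn(q, kv + pos, kv, key_padding_mask=key_padding_mask)
    print(out.shape)                        # torch.Size([4, 2, 8]) -> (Q, B, D), later permuted to (B, Q, D)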
548
+
549
+
550
+ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
551
+ version_float = getattr(config, "version", None)
552
+
553
+ # The old configs do not include version number
554
+ # TODO: Remove this after the HF repos are updated
555
+ if version_float is None:
556
+ if config.hidden_size == 2304 and config.query_num == 64:
557
+ return 2, 0
558
+ return 2, 5
559
+
560
+ version_str = str(version_float)
561
+ return tuple(int(x) for x in version_str.split("."))
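A hypothetical usage sketch of the version lookup above, assuming get_version_by_config is importable from this module; the configs below are invented and SimpleNamespace stands in for PretrainedConfig:

    from types import SimpleNamespace

    # Config with an explicit version number.
    cfg = SimpleNamespace(version=2.6, hidden_size=3584, query_num=64)
    print(get_version_by_config(cfg))       # (2, 6)

    # Old config without a version field: falls back to the hidden_size/query_num heuristic.
    old_cfg = SimpleNamespace(hidden_size=2304, query_num=64)
    print(get_version_by_config(old_cfg))   # (2, 0)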
562
+
563
+
564
+ class MiniCPMVBaseModel(nn.Module):
565
+ """
566
+ Abstract base class for MiniCPM-V models; it can only be inherited,
567
+ not instantiated directly.
568
+ """
569
+
570
+ def __init__(
571
+ self,
572
+ *,
573
+ config: PretrainedConfig,
574
+ quant_config: Optional[QuantizationConfig] = None,
575
+ ):
576
+ # multimodal_config = config.model_config.multimodal_config
577
+ super().__init__()
578
+ # All MiniCPM-V models disable `tie_word_embeddings` but
579
+ # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
580
+ # check `tie_word_embeddings` until vLLM integrates the MiniCPM-V model
581
+ # and config class
582
+ self.config = config
583
+ # self.multimodal_config = multimodal_config
584
+
585
+ self.version = get_version_by_config(self.config)
586
+ self.llm = self.init_llm(config=config, quant_config=quant_config)
587
+ self.vpm = self.init_vision_module(config, quant_config)
588
+ self.vision_dim = (
589
+ self.vpm.embed_dim
590
+ if self.version == (2, 0)
591
+ else self.vpm.embeddings.embed_dim
592
+ )
593
+ self.embed_dim = self.config.hidden_size
594
+
595
+ self.resampler = self.init_resampler(
596
+ self.embed_dim, self.vision_dim, quant_config=quant_config
597
+ )
598
+
599
+ self.logits_processor = LogitsProcessor(config)
600
+
601
+ @cached_property
602
+ def sampler(self):
603
+ if hasattr(self.llm, "sampler"):
604
+ return self.llm.sampler
605
+
606
+ return get_sampler()
607
+
608
+ def _get_image_bounds(
609
+ self,
610
+ input_ids: torch.Tensor,
611
+ pad_values: List[int],
612
+ im_start_id: torch.Tensor,
613
+ im_end_id: torch.Tensor,
614
+ slice_start_id: Optional[torch.Tensor] = None,
615
+ slice_end_id: Optional[torch.Tensor] = None,
616
+ ) -> torch.Tensor:
617
+ """
618
+ Returns a tensor indicating the bounds (start and end token ids) of the images
619
+ """
620
+ # All the images in the batch should share the same special image
621
+ # bound token ids.
622
+ start_cond = input_ids == im_start_id[0]
623
+ end_cond = input_ids == im_end_id[0]
624
+ if slice_start_id is not None:
625
+ start_cond |= input_ids == slice_start_id[0]
626
+ end_cond |= input_ids == slice_end_id[0]
627
+
628
+ (image_start_tokens,) = torch.where(start_cond)
629
+ image_start_tokens += 1
630
+ (image_end_tokens,) = torch.where(end_cond)
631
+
632
+ # The im_start_id can sometimes be cached as part of the prefix, but it is needed for the image embeddings
633
+ if len(image_start_tokens) != len(image_end_tokens):
634
+ if (
635
+ len(image_start_tokens) + 1 == len(image_end_tokens)
636
+ and input_ids[0] in pad_values
637
+ and image_end_tokens[0] < image_start_tokens[0]
638
+ ):
639
+ image_start_tokens = torch.cat(
640
+ [
641
+ torch.tensor([0], device=image_start_tokens.device),
642
+ image_start_tokens,
643
+ ]
644
+ )
645
+ valid_image_nums = min(len(image_start_tokens), len(image_end_tokens))
646
+
647
+ if valid_image_nums == 0:
648
+ return torch.zeros((0, 2), device=input_ids.device)
649
+
650
+ # Filter out pairs where start_token >= end_token
651
+ valid_pairs = []
652
+ for i in range(valid_image_nums):
653
+ start_token = image_start_tokens[i]
654
+ end_token = image_end_tokens[i]
655
+ if start_token < end_token:
656
+ valid_pairs.append((start_token, end_token))
657
+
658
+ if not valid_pairs:
659
+ return torch.zeros((0, 2), device=input_ids.device)
660
+
661
+ # Convert valid pairs to tensor
662
+ valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
663
+ return valid_pairs_tensor
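A toy walk-through of the bounds computation above (all token ids below are invented): each returned (start, stop) pair brackets the placeholder tokens strictly between an image-start and image-end marker.

    import torch

    IM_START, IM_END = 101, 102                      # hypothetical special-token ids
    input_ids = torch.tensor([9, 101, 7, 7, 7, 102, 3])

    (starts,) = torch.where(input_ids == IM_START)   # tensor([1])
    starts = starts + 1                              # first placeholder position -> tensor([2])
    (ends,) = torch.where(input_ids == IM_END)       # tensor([5])

    bounds = torch.stack([starts, ends], dim=1)      # tensor([[2, 5]])
    # get_embedding() later builds torch.arange(2, 5) from this pair and overwrites
    # positions 2, 3, 4 of the text embedding with the vision embeddings.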
664
+
665
+ def get_embedding(
666
+ self,
667
+ input_ids: torch.Tensor,
668
+ image_inputs: Optional[MiniCPMVImageInputs],
669
+ forward_batch: ForwardBatch,
670
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
671
+ vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
672
+
673
+ if image_inputs is None: # No image
674
+ vision_hidden_states = torch.tensor([], device=input_ids.device)
675
+ else:
676
+ if image_inputs["type"] == "image_embeds":
677
+ vision_hidden_states = (
678
+ image_inputs["data"]
679
+ .type(vlm_embedding.dtype)
680
+ .to(vlm_embedding.device)
681
+ )
682
+ else:
683
+ vision_hidden_states = self.get_vision_hidden_states(
684
+ forward_batch, image_inputs
685
+ )
686
+
687
+ # See NOTE in _parse_and_validate_inputs
688
+ image_bounds = image_inputs["image_bounds"]
689
+ if len(image_bounds) > 0:
690
+ image_indices = torch.stack(
691
+ [
692
+ torch.arange(start, end, dtype=torch.long)
693
+ for start, end in image_bounds.tolist()
694
+ ]
695
+ ).to(vlm_embedding.device)
696
+ vlm_embedding.scatter_(
697
+ 0,
698
+ image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
699
+ vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
700
+ )
701
+
702
+ return vlm_embedding, vision_hidden_states
703
+
704
+ def _parse_and_validate_inputs(
705
+ self,
706
+ input_ids: torch.Tensor,
707
+ **kwargs: object,
708
+ ) -> Optional[MiniCPMVImageInputs]:
709
+ pixel_values = kwargs.pop("pixel_values", [])
710
+ tgt_sizes = kwargs.pop("tgt_sizes", [])
711
+ im_start_id = kwargs.pop("im_start_id", None)
712
+ im_end_id = kwargs.pop("im_end_id", None)
713
+ slice_start_id = kwargs.pop("slice_start_id", None)
714
+ slice_end_id = kwargs.pop("slice_end_id", None)
715
+ image_embeds = kwargs.pop("image_embeds", None)
716
+ pad_values = kwargs.pop("pad_values", None)
717
+
718
+ if image_embeds is not None:
719
+ image_bounds = self._get_image_bounds(
720
+ input_ids=input_ids,
721
+ pad_values=pad_values,
722
+ im_start_id=im_start_id,
723
+ im_end_id=im_end_id,
724
+ slice_start_id=slice_start_id,
725
+ slice_end_id=slice_end_id,
726
+ )
727
+ if not isinstance(image_embeds, (torch.Tensor, list)):
728
+ raise ValueError(
729
+ f"Incorrect type of image embeds. "
730
+ f"Got type: {type(image_embeds)}"
731
+ )
732
+
733
+ if isinstance(image_embeds, list):
734
+ image_embeds = torch.concat(image_embeds)
735
+
736
+ return MiniCPMVImageEmbeddingInputs(
737
+ image_bounds=image_bounds,
738
+ data=image_embeds,
739
+ type="image_embeds",
740
+ )
741
+
742
+ if not isinstance(pixel_values, (torch.Tensor, list)):
743
+ raise ValueError(
744
+ "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
745
+ )
746
+
747
+ if not isinstance(tgt_sizes, (torch.Tensor, list)):
748
+ raise ValueError(
749
+ "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
750
+ )
751
+
752
+ if len(pixel_values) != len(tgt_sizes):
753
+ raise ValueError(
754
+ "Inconsistent batch lengths, found: "
755
+ f"{len(pixel_values)} vs. {len(tgt_sizes)}"
756
+ )
757
+
758
+ pixel_values_flat: List[torch.Tensor] = []
759
+ tgt_sizes_flat: List[torch.Tensor] = []
760
+ for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
761
+ if len(pixel_b) != len(tgt_b):
762
+ raise ValueError(
763
+ "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
764
+ )
765
+
766
+ for pixel_n, tgt_n in zip(pixel_b, tgt_b):
767
+ pixel_values_flat += pixel_n
768
+ tgt_sizes_flat += tgt_n
769
+
770
+ # NOTE: Input IDs do not contain image tokens during memory profiling,
771
+ # so we allow it to be empty
772
+ if len(pixel_values_flat) != len(tgt_sizes_flat):
773
+ raise ValueError(
774
+ "Inconsistent flattened lengths, found: "
775
+ f"{len(pixel_values_flat)} vs. "
776
+ f"{len(tgt_sizes_flat)}"
777
+ )
778
+
779
+ if len(pixel_values_flat) == 0:
780
+ return None
781
+
782
+ image_bounds = self._get_image_bounds(
783
+ input_ids=input_ids,
784
+ pad_values=pad_values,
785
+ im_start_id=im_start_id,
786
+ im_end_id=im_end_id,
787
+ slice_start_id=slice_start_id,
788
+ slice_end_id=slice_end_id,
789
+ )
790
+ return MiniCPMVImagePixelInputs(
791
+ image_bounds=image_bounds.to(device=input_ids.device),
792
+ data=pixel_values_flat,
793
+ tgt_sizes=torch.stack(tgt_sizes_flat),
794
+ type="pixel_values",
795
+ )
796
+
797
+ def forward(
798
+ self,
799
+ input_ids: torch.Tensor,
800
+ positions: torch.Tensor,
801
+ forward_batch: ForwardBatch,
802
+ **kwargs: Any,
803
+ ) -> torch.Tensor:
804
+ if forward_batch.image_inputs is not None and forward_batch.image_inputs != [
805
+ None
806
+ ]:
807
+ kwargs.update(
808
+ {
809
+ "pixel_values": (
810
+ None
811
+ if forward_batch.image_inputs is None
812
+ else [
813
+ i.pixel_values
814
+ for i in forward_batch.image_inputs
815
+ if i is not None
816
+ ]
817
+ ),
818
+ "tgt_sizes": (
819
+ None
820
+ if forward_batch.image_inputs is None
821
+ else [
822
+ i.tgt_sizes
823
+ for i in forward_batch.image_inputs
824
+ if i is not None
825
+ ]
826
+ ),
827
+ "im_start_id": forward_batch.image_inputs[0].im_start_id,
828
+ "im_end_id": forward_batch.image_inputs[0].im_end_id,
829
+ "slice_start_id": forward_batch.image_inputs[0].slice_start_id,
830
+ "slice_end_id": forward_batch.image_inputs[0].slice_end_id,
831
+ "pad_values": forward_batch.image_inputs[0].pad_values,
832
+ }
833
+ )
834
+
835
+ image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
836
+
837
+ # Clamp input ids. This is because the input_ids for the image tokens are
838
+ # filled with the hash values of the image for the prefix matching in the radix attention.
839
+ # These values are useless because their embeddings will be replaced by vision embeddings anyway.
840
+ input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
841
+
842
+ vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch)
843
+
844
+ # always pass the input via `inputs_embeds`
845
+ # to make sure the computation graph is consistent
846
+ # for `torch.compile` integration
847
+ input_ids = None
848
+
849
+ hidden_states = self.llm.model(
850
+ input_ids=input_ids,
851
+ positions=positions,
852
+ forward_batch=forward_batch,
853
+ input_embeds=vlm_embeddings,
854
+ )
855
+
856
+ return self.logits_processor(
857
+ input_ids, hidden_states, self.llm.lm_head, forward_batch
858
+ )
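As the comment in forward() notes, image placeholder slots carry image-hash values used for radix-cache prefix matching. A tiny sketch (vocab size and ids invented) of why the clamp is needed before the embedding lookup:

    import torch

    vocab_size = 32_000                                             # hypothetical
    input_ids = torch.tensor([9, 101, 918273645, 918273645, 102])   # hash values stand in for image tokens
    input_ids.clamp_(min=0, max=vocab_size - 1)
    print(input_ids)   # tensor([    9,   101, 31999, 31999,   102]) -> safe to embed, later overwritten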
859
+
860
+ def compute_logits(
861
+ self,
862
+ hidden_states: torch.Tensor,
863
+ sampling_metadata: SamplingMetadata,
864
+ ) -> Optional[torch.Tensor]:
865
+ return self.llm.compute_logits(hidden_states, sampling_metadata)
866
+
867
+ def sample(
868
+ self,
869
+ logits: torch.Tensor,
870
+ sampling_metadata: SamplingMetadata,
871
+ ) -> Optional[SamplerOutput]:
872
+ next_tokens = self.sampler(logits, sampling_metadata)
873
+ return next_tokens
874
+
875
+ def get_mm_mapping(self) -> MultiModelKeys:
876
+ """
877
+ Get the module prefix in multimodal models
878
+ """
879
+ return MultiModelKeys.from_string_field(
880
+ language_model="llm", connector="resampler", tower_model="vpm"
881
+ )
882
+
883
+ def init_llm(
884
+ self,
885
+ config: Qwen2Config,
886
+ quant_config: Optional[QuantizationConfig] = None,
887
+ ) -> nn.Module:
888
+ raise NotImplementedError
889
+
890
+ def init_vision_module(
891
+ self,
892
+ config: PretrainedConfig,
893
+ quant_config: Optional[QuantizationConfig],
894
+ ) -> nn.Module:
895
+ raise NotImplementedError
896
+
897
+ def init_resampler(
898
+ self,
899
+ embed_dim: int,
900
+ vision_dim: int,
901
+ quant_config: Optional[QuantizationConfig] = None,
902
+ ) -> nn.Module:
903
+ raise NotImplementedError
904
+
905
+ def get_vision_embedding(
906
+ self,
907
+ pixel_values: List[torch.Tensor],
908
+ patch_attn_mask: Optional[torch.Tensor] = None,
909
+ tgt_sizes: Optional[torch.Tensor] = None,
910
+ ) -> torch.Tensor:
911
+ raise NotImplementedError
912
+
913
+ def get_vision_hidden_states(
914
+ self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs
915
+ ) -> torch.Tensor:
916
+ raise NotImplementedError
917
+
918
+
919
+ class MiniCPMV2_6(MiniCPMVBaseModel):
920
+ packed_modules_mapping = {
921
+ "qkv_proj": [
922
+ "q_proj",
923
+ "k_proj",
924
+ "v_proj",
925
+ ],
926
+ "gate_up_proj": [
927
+ "gate_proj",
928
+ "up_proj",
929
+ ],
930
+ }
931
+ # LoRA specific attributes
932
+ supported_lora_modules = [
933
+ # vision encoder
934
+ "fc1",
935
+ "fc2",
936
+ "out_proj",
937
+ # language model
938
+ "qkv_proj", # same name with vision encoder
939
+ "o_proj",
940
+ "gate_up_proj",
941
+ "down_proj",
942
+ # resampler
943
+ "kv_proj",
944
+ ]
945
+
946
+ # BitandBytes specific attributes
947
+ bitsandbytes_stacked_params_mapping = {
948
+ # shard_name, weight_name, index
949
+ "q_proj": ("qkv_proj", 0),
950
+ "k_proj": ("qkv_proj", 1),
951
+ "v_proj": ("qkv_proj", 2),
952
+ "gate_proj": ("gate_up_proj", 0),
953
+ "up_proj": ("gate_up_proj", 1),
954
+ }
955
+
956
+ embedding_modules = {}
957
+ embedding_padding_modules = []
958
+
959
+ def __init__(
960
+ self,
961
+ config: PretrainedConfig,
962
+ quant_config: Optional[QuantizationConfig] = None,
963
+ ):
964
+ super().__init__(config=config, quant_config=quant_config)
965
+ assert self.version == (2, 6)
966
+
967
+ def init_llm(
968
+ self,
969
+ config: Qwen2Config,
970
+ quant_config: Optional[QuantizationConfig] = None,
971
+ ) -> nn.Module:
972
+ return Qwen2ForCausalLM(config=config, quant_config=quant_config)
973
+
974
+ def init_vision_module(
975
+ self,
976
+ config: PretrainedConfig,
977
+ quant_config: Optional[QuantizationConfig],
978
+ ) -> nn.Module:
979
+ model = Idefics2VisionTransformer(
980
+ config=config.vision_config, quant_config=quant_config
981
+ )
982
+ if self.config.drop_vision_last_layer:
983
+ model.encoder.layers = model.encoder.layers[:-1]
984
+
985
+ setattr(model, "embed_dim", model.embeddings.embed_dim)
986
+ setattr(model, "patch_size", model.embeddings.patch_size)
987
+ return model
988
+
989
+ def init_resampler(
990
+ self,
991
+ embed_dim: int,
992
+ vision_dim: int,
993
+ quant_config: Optional[QuantizationConfig] = None,
994
+ ) -> nn.Module:
995
+ with set_default_torch_dtype(torch.float16):
996
+ # The resampler in 2.6 remains consistent with the one in 2.5.
997
+ resampler = Resampler2_5(
998
+ num_queries=self.config.query_num,
999
+ embed_dim=embed_dim,
1000
+ num_heads=embed_dim // 128,
1001
+ kv_dim=vision_dim,
1002
+ quant_config=quant_config,
1003
+ )
1004
+
1005
+ return resampler.to(device="cuda", dtype=torch.get_default_dtype())
1006
+
1007
+ def get_vision_embedding(
1008
+ self,
1009
+ pixel_values: List[torch.Tensor],
1010
+ patch_attn_mask: Optional[torch.Tensor] = None,
1011
+ tgt_sizes: Optional[torch.Tensor] = None,
1012
+ ) -> torch.Tensor:
1013
+ vision_embedding = self.vpm(
1014
+ pixel_values,
1015
+ patch_attention_mask=patch_attn_mask,
1016
+ tgt_sizes=tgt_sizes,
1017
+ )
1018
+ return vision_embedding
1019
+
1020
+ def get_vision_hidden_states(
1021
+ self,
1022
+ forward_batch: ForwardBatch,
1023
+ data: MiniCPMVImageInputs,
1024
+ ) -> torch.Tensor:
1025
+ pixel_values = data["data"]
1026
+ tgt_sizes = data["tgt_sizes"]
1027
+
1028
+ device = self.vpm.embeddings.position_embedding.weight.device
1029
+ dtype = self.vpm.embeddings.position_embedding.weight.dtype
1030
+ all_pixel_values_lst = [
1031
+ i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
1032
+ ]
1033
+
1034
+ max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
1035
+ assert isinstance(max_patches, int)
1036
+
1037
+ all_pixel_values = torch.nn.utils.rnn.pad_sequence(
1038
+ all_pixel_values_lst, batch_first=True, padding_value=0.0
1039
+ )
1040
+ B, L, _ = all_pixel_values.shape
1041
+ all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
1042
+ patch_attn_mask = torch.zeros(
1043
+ (B, 1, max_patches), dtype=torch.bool, device=device
1044
+ )
1045
+ for i in range(B):
1046
+ patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
1047
+ vision_embedding = self.vpm(
1048
+ all_pixel_values.type(dtype),
1049
+ forward_batch=forward_batch,
1050
+ patch_attention_mask=patch_attn_mask,
1051
+ tgt_sizes=tgt_sizes,
1052
+ )
1053
+
1054
+ return self.resampler(vision_embedding, tgt_sizes)
1055
+
1056
+ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
1057
+ if not isinstance(image_inputs.im_start_id, list) or not isinstance(
1058
+ image_inputs.im_end_id, list
1059
+ ):
1060
+ return input_ids
1061
+
1062
+ new_input_ids = []
1063
+ last_idx = 0
1064
+ image_idx = -1
1065
+ image_inputs.image_offsets = []
1066
+
1067
+ # Get all special token IDs
1068
+ im_start_id = (
1069
+ image_inputs.im_start_id[0].item()
1070
+ if isinstance(image_inputs.im_start_id[0], torch.Tensor)
1071
+ else image_inputs.im_start_id[0]
1072
+ )
1073
+ im_end_id = (
1074
+ image_inputs.im_end_id[0].item()
1075
+ if isinstance(image_inputs.im_end_id[0], torch.Tensor)
1076
+ else image_inputs.im_end_id[0]
1077
+ )
1078
+ slice_start_id = (
1079
+ image_inputs.slice_start_id[0].item()
1080
+ if isinstance(image_inputs.slice_start_id[0], torch.Tensor)
1081
+ else image_inputs.slice_start_id[0]
1082
+ )
1083
+ slice_end_id = (
1084
+ image_inputs.slice_end_id[0].item()
1085
+ if isinstance(image_inputs.slice_end_id[0], torch.Tensor)
1086
+ else image_inputs.slice_end_id[0]
1087
+ )
1088
+
1089
+ # Find all start and end positions for both types
1090
+ start_indices = [
1091
+ i
1092
+ for i, x in enumerate(input_ids)
1093
+ if x == im_start_id or x == slice_start_id
1094
+ ]
1095
+ end_indices = [
1096
+ i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id
1097
+ ]
1098
+
1099
+ if len(start_indices) != len(end_indices):
1100
+ return input_ids
1101
+ # Process each region (both image and slice)
1102
+ for start_idx, end_idx in zip(start_indices, end_indices):
1103
+ # Add non-image tokens before this region
1104
+ new_input_ids.extend(
1105
+ input_ids[last_idx : start_idx + 1]
1106
+ ) # include start token
1107
+
1108
+ is_image_start = input_ids[start_idx] == im_start_id
1109
+
1110
+ if is_image_start:
1111
+ image_inputs.image_offsets += [start_idx]
1112
+ image_idx += 1
1113
+
1114
+ num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
1115
+
1116
+ # Generate pad_ids
1117
+ pad_values = [image_inputs.pad_values[image_idx]]
1118
+
1119
+ pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
1120
+ pad_ids = pad_ids[:num_tokens]
1121
+
1122
+ # Add pad_ids
1123
+ new_input_ids.extend(pad_ids)
1124
+
1125
+ # Update last_idx to after end token
1126
+ last_idx = end_idx
1127
+
1128
+ # Add remaining tokens after last region
1129
+ new_input_ids.extend(input_ids[last_idx:])
1130
+ assert len(input_ids) == len(new_input_ids)
1131
+ return new_input_ids
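An invented example of the rewrite performed by pad_input_ids: tokens between each image/slice marker pair are replaced by that image's pad value (its hash), while the markers themselves and the total length are preserved.

    # Hypothetical ids: 101 = <image_start>, 102 = <image_end>, 555 = hash of image 0
    input_ids     = [9, 101, 7, 7, 7, 102, 3]
    # after pad_input_ids(...):
    new_input_ids = [9, 101, 555, 555, 555, 102, 3]
    assert len(new_input_ids) == len(input_ids)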
1132
+
1133
+
1134
+ _SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
1135
+
1136
+
1137
+ class MiniCPMV:
1138
+ """
1139
+ Different versions of MiniCPMV use different visual encoders and LLMs,
1140
+ which is not conducive to the current integration logic of LoRA and
1141
+ bitsandbytes in vLLM. Therefore, it is necessary to separate them.
1142
+ """
1143
+
1144
+ # Ensure that the LoRA support check passes when the class is not
1145
+ # initialized, but set all these attributes to empty.
1146
+ packed_modules_mapping = {}
1147
+ supported_lora_modules = []
1148
+ embedding_modules = {}
1149
+ embedding_padding_modules = []
1150
+
1151
+ minicpmv: nn.Module
1152
+
1153
+ def __init__(
1154
+ self,
1155
+ config: PretrainedConfig,
1156
+ quant_config: Optional[QuantizationConfig] = None,
1157
+ ) -> None:
1158
+ super().__init__()
1159
+
1160
+ if not hasattr(config, "version"):
1161
+ version = (2, 6)
1162
+ else:
1163
+ version = str(config.version).split(".")
1164
+ version = tuple([int(x) for x in version])
1165
+ # Dispatch class based on version
1166
+ instance_class = _SUPPORT_VERSION.get(version)
1167
+ if instance_class is None:
1168
+ raise ValueError("Currently, MiniCPMV only supports version 2.6")
1169
+
1170
+ try:
1171
+ minicpmv = instance_class(config=config, quant_config=quant_config)
1172
+ self.minicpmv = minicpmv
1173
+ except Exception as e:
1174
+ print(f"Failed to instantiate MiniCPMV: {e}")
1175
+ raise e
1176
+ self.config = config
1177
+
1178
+ def __getattr__(self, name):
1179
+ if name == "minicpmv":
1180
+ return None
1181
+ return getattr(self.minicpmv, name)
1182
+
1183
+ def __call__(self, *args, **kwargs):
1184
+ return self.minicpmv(*args, **kwargs)
1185
+
1186
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
1187
+ stacked_params_mapping = [
1188
+ # (param_name, shard_name, shard_id)
1189
+ ("qkv_proj", "q_proj", "q"),
1190
+ ("qkv_proj", "k_proj", "k"),
1191
+ ("qkv_proj", "v_proj", "v"),
1192
+ ("gate_up_proj", "gate_proj", 0),
1193
+ ("gate_up_proj", "up_proj", 1),
1194
+ ]
1195
+
1196
+ params_dict = dict(self.minicpmv.named_parameters())
1197
+ for name, loaded_weight in weights:
1198
+ if "rotary_emb.inv_freq~" in name or "projector" in name:
1199
+ continue
1200
+ if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
1201
+ # Models trained using ColossalAI may include these tensors in
1202
+ # the checkpoint. Skip them.
1203
+ continue
1204
+ if name.startswith("model.vision_tower") and name not in params_dict:
1205
+ continue
1206
+
1207
+ # adapt to VisionAttention
1208
+ name = name.replace(r"self_attn.out_proj", r"self_attn.proj")
1209
+
1210
+ if "sampler" in name:
1211
+ param = params_dict[name]
1212
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
1213
+ weight_loader(param, loaded_weight)
1214
+ continue
1215
+
1216
+ for param_name, weight_name, shard_id in stacked_params_mapping:
1217
+ # replace the name and load with customized loader
1218
+ if weight_name not in name:
1219
+ continue
1220
+ name = name.replace(weight_name, param_name)
1221
+ # Skip loading extra bias for GPTQ models.
1222
+ if name.endswith(".bias") and name not in params_dict:
1223
+ continue
1224
+ param = params_dict[name]
1225
+ weight_loader = param.weight_loader
1226
+ weight_loader(param, loaded_weight, shard_id)
1227
+ break
1228
+ else:
1229
+ # Skip loading extra bias for GPTQ models.
1230
+ if name.endswith(".bias") and name not in params_dict:
1231
+ continue
1232
+
1233
+ param = params_dict[name]
1234
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
1235
+ weight_loader(param, loaded_weight)
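A sketch (checkpoint name invented) of how the stacked_params_mapping loop above rewrites per-shard checkpoint names onto the fused parameters before loading:

    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    name = "llm.model.layers.0.self_attn.q_proj.weight"   # hypothetical checkpoint entry
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            fused = name.replace(weight_name, param_name)
            print(fused, shard_id)   # llm.model.layers.0.self_attn.qkv_proj.weight q
            break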
1236
+
1237
+
1238
+ EntryClass = MiniCPMV