sglang 0.4.2__py3-none-any.whl → 0.4.2.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -149,6 +149,7 @@ class Scheduler:
             if not self.spec_algorithm.is_none()
             else 1
         )
+        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache

         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -831,10 +832,16 @@ class Scheduler:
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-        if available_size != self.max_total_num_tokens:
+        protected_size = self.tree_cache.protected_size()
+        memory_leak = available_size != (
+            self.max_total_num_tokens
+            if not self.enable_hierarchical_cache
+            else self.max_total_num_tokens - protected_size
+        )
+        if memory_leak:
             msg = (
                 "KV cache pool leak detected!"
-                f"{available_size=}, {self.max_total_num_tokens=}\n"
+                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
             )
             warnings.warn(msg)
             if crash_on_warnings():
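The reworked check above treats KV tokens that the hierarchical cache has pinned ("protected") as expected to be absent from the free pool rather than leaked. A minimal standalone sketch of that invariant; the function name and the numbers are illustrative, not part of sglang:

```python
def looks_like_kv_leak(
    available_size: int,
    max_total_num_tokens: int,
    protected_size: int,
    enable_hierarchical_cache: bool,
) -> bool:
    """Same rule as the patched Scheduler check: locked (protected) tokens
    may legitimately be missing from the free pool when hierarchical caching is on."""
    expected = (
        max_total_num_tokens
        if not enable_hierarchical_cache
        else max_total_num_tokens - protected_size
    )
    return available_size != expected

# 1000-token pool, 40 tokens currently locked by in-flight requests:
assert not looks_like_kv_leak(960, 1000, 40, enable_hierarchical_cache=True)
assert looks_like_kv_leak(960, 1000, 0, enable_hierarchical_cache=False)
```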
@@ -949,7 +956,14 @@ class Scheduler:
             res = adder.add_one_req(req)
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
-                    self.batch_is_full = True
+                    if self.enable_hierarchical_cache:
+                        # Set batch_is_full after making sure there are requests that can be served
+                        self.batch_is_full = len(adder.can_run_list) > 0 or (
+                            self.running_batch is not None
+                            and not self.running_batch.is_empty()
+                        )
+                    else:
+                        self.batch_is_full = True
                 break
             if self.server_args.prefill_only_one_req:
                 break
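With the hierarchical cache enabled, `NO_TOKEN` can be a transient state while KV data is still being loaded back from host memory, so the scheduler now only marks the batch as full once something can actually make progress. A hedged sketch of the same decision as a pure function (the name and parameters are illustrative):

```python
def should_mark_batch_full(
    num_can_run: int,
    running_batch_nonempty: bool,
    enable_hierarchical_cache: bool,
) -> bool:
    # Without the hierarchical cache, NO_TOKEN always means the batch is full.
    if not enable_hierarchical_cache:
        return True
    # With it, only stop admitting requests if some request can still be served.
    return num_can_run > 0 or running_batch_nonempty
```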
@@ -41,6 +41,10 @@ class BasePrefixCache(ABC):
     def evictable_size(self):
         pass

+    @abstractmethod
+    def protected_size(self):
+        raise NotImplementedError()
+
     def total_size(self):
         raise NotImplementedError()

@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache):

     def evictable_size(self):
         return 0
+
+    def protected_size(self):
+        return 0
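`protected_size()` joins `evictable_size()` in the `BasePrefixCache` interface, so every cache backend reports how many cached tokens are currently pinned: `ChunkCache` shares nothing and returns 0, while `RadixCache` maintains a counter (see the following hunks). A simplified sketch of the contract; the class names here are illustrative stand-ins, not the real sglang classes:

```python
from abc import ABC, abstractmethod


class PrefixCacheLike(ABC):
    """Stand-in for the BasePrefixCache interface after this change."""

    @abstractmethod
    def evictable_size(self) -> int: ...

    @abstractmethod
    def protected_size(self) -> int: ...


class NoSharingCache(PrefixCacheLike):
    """Like ChunkCache: nothing is shared, so nothing is evictable or protected."""

    def evictable_size(self) -> int:
        return 0

    def protected_size(self) -> int:
        return 0
```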
@@ -34,7 +34,10 @@ if TYPE_CHECKING:


 class TreeNode:
-    def __init__(self):
+
+    counter = 0
+
+    def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
         self.parent = None
         self.key = None
@@ -42,6 +45,23 @@ class TreeNode:
         self.lock_ref = 0
         self.last_access_time = time.time()

+        self.hit_count = 0
+        # indicating the node is loading KV cache from host
+        self.loading = False
+        # store the host indices of KV cache
+        self.host_value = None
+
+        self.id = TreeNode.counter if id is None else id
+        TreeNode.counter += 1
+
+    @property
+    def evicted(self):
+        return self.value is None
+
+    @property
+    def backuped(self):
+        return self.host_value is not None
+
     def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time

@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.value = []
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
+        self.protected_size_ = 0

     def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
         """Find the matching prefix from the radix tree.
@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache):
         while node != self.root_node:
             if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
+                self.protected_size_ += len(node.value)
                 delta -= len(node.value)
             node.lock_ref += 1
             node = node.parent
@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache):
         while node != self.root_node:
             if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
+                self.protected_size_ -= len(node.value)
                 delta += len(node.value)
             node.lock_ref -= 1
             node = node.parent
@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache):
     def evictable_size(self):
         return self.evictable_size_

+    def protected_size(self):
+        # protected size refers to the size of the cache that is locked
+        return self.protected_size_
+
     ##### Internal Helper Functions #####

     def _match_prefix_helper(
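Locking a node for an in-flight request moves its tokens from the evictable pool to the protected pool, and the final unlock moves them back, so `evictable_size_ + protected_size_` stays equal to the number of cached tokens and the scheduler's leak check can subtract `protected_size()`. A toy sketch of that bookkeeping; the class and method names are illustrative:

```python
class LockBookkeeping:
    """Simplified version of RadixCache's evictable/protected accounting."""

    def __init__(self):
        self.evictable_size_ = 0
        self.protected_size_ = 0

    def insert(self, num_tokens: int):
        self.evictable_size_ += num_tokens   # newly cached tokens start unlocked

    def lock(self, num_tokens: int):
        self.evictable_size_ -= num_tokens   # first lock: no longer evictable
        self.protected_size_ += num_tokens

    def unlock(self, num_tokens: int):
        self.evictable_size_ += num_tokens   # last unlock: evictable again
        self.protected_size_ -= num_tokens


bk = LockBookkeeping()
bk.insert(64)
bk.lock(64)
assert (bk.evictable_size_, bk.protected_size_) == (0, 64)
bk.unlock(64)
assert (bk.evictable_size_, bk.protected_size_) == (64, 0)
```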
@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache):
         self.evictable_size_ -= len(node.key)

     def _total_size_helper(self, node: TreeNode):
+        if node.evicted:
+            return 0
         x = len(node.value)
         for child in node.children.values():
             x += self._total_size_helper(child)
@@ -1,6 +1,6 @@
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The vLLM team.
+# Copyright 2023 The SGLang team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
-from functools import cached_property, partial
+from functools import partial
 from typing import (
     Any,
     Callable,
@@ -33,16 +33,13 @@ from typing import (
     Union,
 )

+import numpy as np
 import torch
 import torch.types
 from PIL import Image
 from torch import nn
 from torch.nn.init import trunc_normal_
 from transformers import PretrainedConfig
-from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
-from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.sampling_metadata import SamplingMetadata

 from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import get_act_fn
@@ -63,6 +60,88 @@ from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 RawImageType = Union[Image.Image, torch.Tensor]


+# sin/cos positional embedding helpers are adapted from:
+# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
+def get_1d_sincos_pos_embed_from_grid(
+    embed_dim: int, pos: np.ndarray, version: Tuple[int, int] = (2, 0)
+) -> torch.Tensor:
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,) / (H, W)
+    out: (M, D) / (H, W, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float32)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    if version == (2, 0):
+        pos = pos.reshape(-1)  # (M,)
+        out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+        emb_sin = np.sin(out)  # (M, D/2)
+        emb_cos = np.cos(out)  # (M, D/2)
+        emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    else:
+        out = np.einsum("hw,d->hwd", pos, omega)  # (H, W, D/2), outer product
+        emb_sin = np.sin(out)  # (H, W, D/2)
+        emb_cos = np.cos(out)  # (H, W, D/2)
+        emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (H, W, D)
+    return emb
+
+
+def get_2d_sincos_pos_embed_from_grid(
+    embed_dim: int, grid: np.ndarray, version: Tuple[int, int] = (2, 0)
+) -> torch.Tensor:
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(
+        embed_dim // 2, grid[0], version
+    )  # (H*W, D/2) or (H, W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(
+        embed_dim // 2, grid[1], version
+    )  # (H*W, D/2) or (H, W, D/2)
+
+    if version == (2, 0):
+        emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    else:
+        emb = np.concatenate([emb_h, emb_w], axis=-1)  # (H, W, D)
+    return emb
+
+
+def get_2d_sincos_pos_embed(
+    embed_dim: int,
+    grid_size: Union[int, Tuple[int, int]],
+    cls_token: bool = False,
+    version: Tuple[int, int] = (2, 0),
+) -> torch.Tensor:
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or
+        [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    if isinstance(grid_size, int):
+        grid_h_size, grid_w_size = grid_size, grid_size
+    else:
+        grid_h_size, grid_w_size = grid_size[0], grid_size[1]
+
+    grid_h = np.arange(grid_h_size, dtype=np.float32)
+    grid_w = np.arange(grid_w_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+    assert isinstance(grid, np.ndarray) and grid.shape == (2, grid_h_size, grid_w_size)
+
+    if version == (2, 0):
+        grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
+        pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
+        if cls_token:
+            pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    else:
+        pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
+    return pos_embed
+
+
 class Idefics2VisionMLP(nn.Module):

     def __init__(
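These helpers replace the `get_2d_sincos_pos_embed` that was previously imported from `vllm.model_executor.layers.resampler`; they build the standard MAE-style 2D sin/cos position table. An illustrative shape check of the default `version=(2, 0)` path, assuming the functions above are in scope:

```python
import numpy as np

# For a 4x4 patch grid and 64-dim embeddings, the (2, 0) path flattens the grid:
pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=4)
assert isinstance(pos, np.ndarray)
assert pos.shape == (16, 64)  # H*W rows, one embedding per patch

# With a class token, a zero row is prepended:
pos_cls = get_2d_sincos_pos_embed(embed_dim=64, grid_size=4, cls_token=True)
assert pos_cls.shape == (17, 64)
```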
@@ -116,6 +195,10 @@ class Idefics2EncoderLayer(nn.Module):
             projection_size=config.intermediate_size,
             use_qkv_parallel=True,
             quant_config=quant_config,
+            dropout=config.attention_dropout,
+            use_context_forward=False,
+            use_full_precision_softmax=True,
+            flatten_batch=False,
             prefix=f"{prefix}.self_attn",
         )
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -126,7 +209,6 @@ class Idefics2EncoderLayer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         """
         Args:
@@ -136,11 +218,8 @@ class Idefics2EncoderLayer(nn.Module):
         """
         residual = hidden_states
         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states = self.self_attn(
-            hidden_states,
-            cu_seqlens=cu_seqlens,
-            # , forward_batch=forward_batch
-        )
+        hidden_states = self.self_attn(hidden_states, cu_seqlens=cu_seqlens)
+
         hidden_states = residual + hidden_states
         residual = hidden_states
         hidden_states = self.layer_norm2(hidden_states)
@@ -181,7 +260,6 @@ class Idefics2Encoder(nn.Module):
         self,
         inputs_embeds: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         r"""
         Args:
@@ -195,7 +273,8 @@ class Idefics2Encoder(nn.Module):
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
             layer_outputs = encoder_layer(
-                hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch
+                hidden_states,
+                cu_seqlens=cu_seqlens,
             )
             hidden_states = layer_outputs
         return hidden_states
@@ -232,19 +311,14 @@ class Idefics2VisionEmbeddings(nn.Module):
         self.num_positions = self.num_patches
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

-    def forward(
+    def get_position_ids(
         self,
         pixel_values: torch.FloatTensor,
         patch_attention_mask: torch.BoolTensor,
         tgt_sizes: Optional[torch.IntTensor] = None,
-    ) -> torch.Tensor:
+    ):
         batch_size, _, max_im_h, max_im_w = pixel_values.shape
-        target_dtype = self.patch_embedding.weight.dtype
-        pixel_values = pixel_values.to(
-            device=self.patch_embedding.weight.device, dtype=target_dtype
-        )
-        patch_embeds = self.patch_embedding(pixel_values)
-        embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
         max_nb_patches_h, max_nb_patches_w = (
             max_im_h // self.patch_size,
             max_im_w // self.patch_size,
@@ -277,6 +351,24 @@ class Idefics2VisionEmbeddings(nn.Module):
             ).flatten()
             position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
         position_ids = position_ids.to(self.position_embedding.weight.device)
+        return position_ids
+
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        patch_attention_mask: torch.BoolTensor,
+        tgt_sizes: Optional[torch.IntTensor] = None,
+    ) -> torch.Tensor:
+        target_dtype = self.patch_embedding.weight.dtype
+        pixel_values = pixel_values.to(
+            device=self.patch_embedding.weight.device, dtype=target_dtype
+        )
+        patch_embeds = self.patch_embedding(pixel_values)
+        embeddings = patch_embeds.flatten(2).transpose(1, 2)
+        position_ids = self.get_position_ids(
+            pixel_values, patch_attention_mask, tgt_sizes
+        )
+
         embeddings = embeddings + self.position_embedding(position_ids)
         return embeddings

@@ -287,7 +379,6 @@ class Idefics2VisionTransformer(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -302,8 +393,6 @@ class Idefics2VisionTransformer(nn.Module):

     def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
         patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]  # shape: (batch_size,)
-
-        # Take a prefix sum to get cu_seqlens; note that a 0 is inserted at the front as the offset.
         cu_seqlens = torch.cat(
             [
                 torch.tensor([0], device=patch_len.device, dtype=torch.int32),
@@ -316,19 +405,18 @@ class Idefics2VisionTransformer(nn.Module):
     def forward(
         self,
         pixel_values,
-        forward_batch: ForwardBatch,
         patch_attention_mask: Optional[torch.BoolTensor] = None,
         tgt_sizes: Optional[torch.IntTensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.embeddings(
             pixel_values=pixel_values,
             patch_attention_mask=patch_attention_mask,
-            # forward_batch=forward_batch,
             tgt_sizes=tgt_sizes,
         )
         cu_seqlens = self.compute_cu_seqlens(tgt_sizes)
         encoder_outputs = self.encoder(
-            hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch
+            hidden_states,
+            cu_seqlens=cu_seqlens,
         )
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
@@ -573,14 +661,12 @@ class MiniCPMVBaseModel(nn.Module):
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
     ):
-        # multimodal_config = config.model_config.multimodal_config
         super().__init__()
         # All MiniCPM-V models disable `tie_word_embeddings` but
         # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
-        # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
+        # check `tie_word_embeddings` until SGLang integrate MiniCPM-V model
         # and config class
         self.config = config
-        # self.multimodal_config = multimodal_config

         self.version = get_version_by_config(self.config)
         self.llm = self.init_llm(config=config, quant_config=quant_config)
@@ -598,13 +684,6 @@ class MiniCPMVBaseModel(nn.Module):

         self.logits_processor = LogitsProcessor(config)

-    @cached_property
-    def sampler(self):
-        if hasattr(self.llm, "sampler"):
-            return self.llm.sampler
-
-        return get_sampler()
-
     def _get_image_bounds(
         self,
         input_ids: torch.Tensor,
@@ -666,7 +745,6 @@ class MiniCPMVBaseModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         image_inputs: Optional[MiniCPMVImageInputs],
-        forward_batch: ForwardBatch,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)

@@ -680,10 +758,7 @@ class MiniCPMVBaseModel(nn.Module):
                     .to(vlm_embedding.device)
                 )
             else:
-                vision_hidden_states = self.get_vision_hidden_states(
-                    forward_batch, image_inputs
-                )
-
+                vision_hidden_states = self.get_vision_hidden_states(image_inputs)
             # See NOTE in _parse_and_validate_inputs
             image_bounds = image_inputs["image_bounds"]
             if len(image_bounds) > 0:
@@ -693,6 +768,7 @@ class MiniCPMVBaseModel(nn.Module):
                         for start, end in image_bounds.tolist()
                     ]
                 ).to(vlm_embedding.device)
+
                 vlm_embedding.scatter_(
                     0,
                     image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
@@ -839,7 +915,7 @@ class MiniCPMVBaseModel(nn.Module):
         # There values are useless because their embeddings will be replaced by vision embeddings anyway.
         input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

-        vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch)
+        vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)

         # always pass the input via `inputs_embeds`
         # to make sure the computation graph is consistent
@@ -857,29 +933,6 @@ class MiniCPMVBaseModel(nn.Module):
             input_ids, hidden_states, self.llm.lm_head, forward_batch
         )

-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        return self.llm.compute_logits(hidden_states, sampling_metadata)
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
-    def get_mm_mapping(self) -> MultiModelKeys:
-        """
-        Get the module prefix in multimodal models
-        """
-        return MultiModelKeys.from_string_field(
-            language_model="llm", connector="resampler", tower_model="vpm"
-        )
-
     def init_llm(
         self,
         config: Qwen2Config,
@@ -910,9 +963,7 @@ class MiniCPMVBaseModel(nn.Module):
     ) -> torch.Tensor:
         raise NotImplementedError

-    def get_vision_hidden_states(
-        self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs
-    ) -> torch.Tensor:
+    def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor:
         raise NotImplementedError


@@ -1019,7 +1070,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel):

     def get_vision_hidden_states(
         self,
-        forward_batch: ForwardBatch,
         data: MiniCPMVImageInputs,
     ) -> torch.Tensor:
         pixel_values = data["data"]
@@ -1042,15 +1092,18 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         patch_attn_mask = torch.zeros(
             (B, 1, max_patches), dtype=torch.bool, device=device
         )
-        for i in range(B):
-            patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+
+        tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
+        mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
+        patch_attn_mask[:, 0, :] = torch.arange(
+            patch_attn_mask.size(2), device=patch_attn_mask.device
+        ).unsqueeze(0) < mask_shapes.unsqueeze(1)
+
         vision_embedding = self.vpm(
             all_pixel_values.type(dtype),
-            forward_batch=forward_batch,
             patch_attention_mask=patch_attn_mask,
             tgt_sizes=tgt_sizes,
         )
-
         return self.resampler(vision_embedding, tgt_sizes)

     def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
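The per-image Python loop that built `patch_attn_mask` is replaced by a single broadcasted comparison against the number of valid patches per image. A small self-contained check that the two formulations agree; the tensor values are made up for the example:

```python
import torch

B, max_patches = 3, 12
tgt_sizes = torch.tensor([[2, 3], [1, 4], [3, 4]])  # (h, w) patches per image

# Old behaviour: per-image loop.
loop_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool)
for i in range(B):
    loop_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

# New behaviour: one broadcasted comparison.
mask_shapes = tgt_sizes[:, 0] * tgt_sizes[:, 1]  # valid patches per image
vec_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool)
vec_mask[:, 0, :] = torch.arange(max_patches).unsqueeze(0) < mask_shapes.unsqueeze(1)

assert torch.equal(loop_mask, vec_mask)
```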
@@ -1138,7 +1191,7 @@ class MiniCPMV:
     """
    Different versions of MiniCPMV use different visual encoders and LLMs,
    which is not conducive to the current integration logic of LoRA and
-    bitsandbytes in vLLM. Therefore, it is necessary to separate them.
+    bitsandbytes in SGLang. Therefore, it is necessary to separate them.
     """

     # Ensure that the LoRA support check passes when the class is not
@@ -17,6 +17,7 @@ from transformers.models.mllama.modeling_mllama import (
 import sglang.srt.distributed.parallel_state as ps
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -145,61 +146,6 @@ class MllamaPrecomputedPositionEmbedding(nn.Module):
         return hidden_state


-class MllamaVisionSdpaAttention(nn.Module):
-    def __init__(self, config: config_mllama.MllamaVisionConfig):
-        super().__init__()
-
-        model_parallel_size = get_tensor_model_parallel_world_size()
-        self.embed_dim = config.hidden_size
-        self.num_heads = config.attention_heads
-        self.head_dim = config.hidden_size // config.attention_heads
-        self.num_local_heads = self.num_heads // model_parallel_size
-        self.q_size = self.num_local_heads * self.head_dim
-        self.kv_size = self.num_local_heads * self.head_dim
-
-        self.qkv_proj = QKVParallelLinear(
-            self.embed_dim,
-            self.head_dim,
-            self.num_heads,
-            bias=False,
-        )
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.head_dim,
-            self.embed_dim,
-            bias=False,
-            input_is_parallel=True,
-        )
-
-    def forward(
-        self,
-        hidden_state: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_state)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q = q.view(
-            q.shape[0], q.shape[1], self.num_local_heads, self.head_dim
-        ).transpose(1, 2)
-        k = k.view(
-            k.shape[0], k.shape[1], self.num_local_heads, self.head_dim
-        ).transpose(1, 2)
-        v = v.view(
-            v.shape[0], v.shape[1], self.num_local_heads, self.head_dim
-        ).transpose(1, 2)
-
-        # TODO: remove padding in image encoder
-        attn_output = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attention_mask, dropout_p=0.0
-        )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(
-            attn_output.shape[0], attn_output.shape[1], -1
-        )
-        output, _ = self.o_proj(attn_output)
-        return output
-
-
 class MllamaVisionMLP(nn.Module):
     def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
@@ -237,7 +183,17 @@ class MllamaVisionEncoderLayer(nn.Module):
         self.is_gated = is_gated
         self.intermediate_size = config.intermediate_size

-        self.self_attn = MllamaVisionSdpaAttention(config)
+        self.self_attn = VisionAttention(
+            self.hidden_size,
+            self.num_attention_heads,
+            self.hidden_size,
+            use_qkv_parallel=True,
+            quant_config=None,
+            dropout=0.0,
+            use_context_forward=False,
+            use_full_precision_softmax=False,
+            flatten_batch=False,
+        )
         self.mlp = MllamaVisionMLP(config)

         self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
@@ -992,6 +948,10 @@ class MllamaForConditionalGeneration(nn.Module):
                     weight_loader(param, loaded_weight, shard_id)
                     break
             else:
+                if "vision_model" in name:
+                    # adapt to VisionAttention
+                    name = name.replace("self_attn.o_proj", "self_attn.proj")
+
                 param = params_dict.pop(name)
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
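Because the Mllama vision tower now uses sglang's `VisionAttention`, whose output projection is named `proj` instead of the checkpoint's `o_proj`, weight names are rewritten while loading. A tiny illustration of the renaming; the example key is hypothetical:

```python
name = "vision_model.transformer.layers.0.self_attn.o_proj.weight"
if "vision_model" in name:
    # adapt to VisionAttention
    name = name.replace("self_attn.o_proj", "self_attn.proj")
assert name == "vision_model.transformer.layers.0.self_attn.proj.weight"
```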
@@ -249,7 +249,10 @@ class Qwen2Model(nn.Module):
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.embed_tokens(input_ids)
+        if hasattr(self.config, "scale_emb"):
+            return self.embed_tokens(input_ids) * self.config.scale_emb
+        else:
+            return self.embed_tokens(input_ids)

     def forward(
         self,
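MiniCPM-style language models carry a `scale_emb` factor in their config that must multiply the token embeddings, which is why `Qwen2Model.get_input_embeddings` now applies it when the attribute is present. A minimal sketch of the same conditional scaling; the module and config classes are illustrative:

```python
import torch
from torch import nn


class ToyEmbedder(nn.Module):
    def __init__(self, config, vocab_size: int = 100, hidden_size: int = 16):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Scale only when the config defines `scale_emb` (as MiniCPM-V's LLM does).
        if hasattr(self.config, "scale_emb"):
            return self.embed_tokens(input_ids) * self.config.scale_emb
        return self.embed_tokens(input_ids)


class Cfg:  # stand-in for a HF PretrainedConfig with MiniCPM's extra field
    scale_emb = 12.0


emb = ToyEmbedder(Cfg())
out = emb.get_input_embeddings(torch.tensor([1, 2, 3]))
assert out.shape == (3, 16)
```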