sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/minicpmv.py
@@ -41,7 +41,6 @@ from torch import nn
 from torch.nn.init import trunc_normal_
 from transformers import PretrainedConfig
 
-from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import (
@@ -51,6 +50,9 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.multi_modality_padding import (
+    MultiModalityDataPaddingPatternTokenPairs,
+)
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
@@ -186,19 +188,16 @@ class Idefics2EncoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.embed_dim = config.hidden_size
-
         self.num_heads = config.num_attention_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        num_heads_per_partition = divide(self.num_heads, tp_size)
         self.self_attn = VisionAttention(
             embed_dim=config.hidden_size,
-            num_heads=num_heads_per_partition,
+            num_heads=self.num_heads,
             projection_size=config.intermediate_size,
             use_qkv_parallel=True,
            quant_config=quant_config,
             dropout=config.attention_dropout,
             use_context_forward=False,
-            use_full_precision_softmax=True,
+            softmax_in_single_precision=True,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
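
This hunk moves tensor-parallel head splitting out of the caller: VisionAttention now receives the full config.num_attention_heads. A minimal sketch of where the division presumably lives now, assuming the attention layer derives the per-rank head count internally (the VisionAttentionSketch class and its internals below are illustrative, not sglang's actual implementation; divide mirrors the removed sglang.srt.distributed helper):

```python
import torch.nn as nn


def divide(numerator: int, denominator: int) -> int:
    # Exact integer division, as in sglang.srt.distributed.divide.
    assert numerator % denominator == 0, "heads must split evenly across ranks"
    return numerator // denominator


class VisionAttentionSketch(nn.Module):
    # Hypothetical internals: the caller passes the total head count and the
    # layer divides by the tensor-parallel world size itself.
    def __init__(self, embed_dim: int, num_heads: int, tp_size: int = 1):
        super().__init__()
        self.num_heads_per_partition = divide(num_heads, tp_size)
        self.head_dim = embed_dim // num_heads
```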
@@ -708,21 +707,21 @@ class MiniCPMVBaseModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         pad_values: List[int],
-        im_start_id: torch.Tensor,
-        im_end_id: torch.Tensor,
-        slice_start_id: Optional[torch.Tensor] = None,
-        slice_end_id: Optional[torch.Tensor] = None,
+        im_start_id: int,
+        im_end_id: int,
+        slice_start_id: Optional[int] = None,
+        slice_end_id: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Returns a tensor indicating the bounds (start and end token ids) of the images
         """
         # All the images in the batch should share the same special image
         # bound token ids.
-        start_cond = input_ids == im_start_id[0]
-        end_cond = input_ids == im_end_id[0]
+        start_cond = input_ids == im_start_id
+        end_cond = input_ids == im_end_id
         if slice_start_id is not None:
-            start_cond |= input_ids == slice_start_id[0]
-            end_cond |= input_ids == slice_end_id[0]
+            start_cond |= input_ids == slice_start_id
+            end_cond |= input_ids == slice_end_id
 
         (image_start_tokens,) = torch.where(start_cond)
         image_start_tokens += 1
@@ -733,6 +732,8 @@ class MiniCPMVBaseModel(nn.Module):
         if (
             len(image_start_tokens) + 1 == len(image_end_tokens)
             and input_ids[0] in pad_values
+            and len(image_start_tokens) != 0
+            and len(image_end_tokens) != 0
             and image_end_tokens[0] < image_start_tokens[0]
         ):
             image_start_tokens = torch.cat(
@@ -897,9 +898,12 @@ class MiniCPMVBaseModel(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-        if forward_batch.image_inputs is not None and forward_batch.image_inputs != [
-            None
-        ]:
+        if (
+            forward_batch.image_inputs is not None
+            and len(forward_batch.image_inputs) > 0
+            and forward_batch.image_inputs[0] is not None
+        ):
+            # TODO: batch
             kwargs.update(
                 {
                     "pixel_values": (
@@ -1135,81 +1139,16 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         return self.resampler(vision_embedding, tgt_sizes)
 
     def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
-        if not isinstance(image_inputs.im_start_id, list) or not isinstance(
-            image_inputs.im_end_id, list
-        ):
-            return input_ids
-
-        new_input_ids = []
-        last_idx = 0
-        image_idx = -1
-        image_inputs.image_offsets = []
-
         # Get all special token IDs
-        im_start_id = (
-            image_inputs.im_start_id[0].item()
-            if isinstance(image_inputs.im_start_id[0], torch.Tensor)
-            else image_inputs.im_start_id[0]
-        )
-        im_end_id = (
-            image_inputs.im_end_id[0].item()
-            if isinstance(image_inputs.im_end_id[0], torch.Tensor)
-            else image_inputs.im_end_id[0]
-        )
-        slice_start_id = (
-            image_inputs.slice_start_id[0].item()
-            if isinstance(image_inputs.slice_start_id[0], torch.Tensor)
-            else image_inputs.slice_start_id[0]
-        )
-        slice_end_id = (
-            image_inputs.slice_end_id[0].item()
-            if isinstance(image_inputs.slice_end_id[0], torch.Tensor)
-            else image_inputs.slice_end_id[0]
-        )
-
-        # Find all start and end positions for both types
-        start_indices = [
-            i
-            for i, x in enumerate(input_ids)
-            if x == im_start_id or x == slice_start_id
-        ]
-        end_indices = [
-            i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id
-        ]
-
-        if len(start_indices) != len(end_indices):
-            return input_ids
-        # Process each region (both image and slice)
-        for start_idx, end_idx in zip(start_indices, end_indices):
-            # Add non-image tokens before this region
-            new_input_ids.extend(
-                input_ids[last_idx : start_idx + 1]
-            )  # include start token
-
-            is_image_start = input_ids[start_idx] == im_start_id
-
-            if is_image_start:
-                image_inputs.image_offsets += [start_idx]
-                image_idx += 1
-
-            num_tokens = end_idx - start_idx - 1  # exclude start and end tokens
-
-            # Generate pad_ids
-            pad_values = [image_inputs.pad_values[image_idx]]
-
-            pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
-            pad_ids = pad_ids[:num_tokens]
-
-            # Add pad_ids
-            new_input_ids.extend(pad_ids)
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id
+        slice_start_id: int = image_inputs.slice_start_id
+        slice_end_id: int = image_inputs.slice_end_id
 
-            # Update last_idx to after end token
-            last_idx = end_idx
+        media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
 
-        # Add remaining tokens after last region
-        new_input_ids.extend(input_ids[last_idx:])
-        assert len(input_ids) == len(new_input_ids)
-        return new_input_ids
+        return pattern.pad_input_tokens(input_ids, image_inputs)
 
 
 _SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
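
The pad_input_ids rewrites in this diff (here and in the two Qwen2-VL files below) all delegate to the new MultiModalityDataPaddingPatternTokenPairs helper. Its implementation is not shown in these hunks; the following is a minimal sketch of the pattern it plausibly applies, assuming everything strictly between each (start, end) token pair is replaced with that image's pad value while the sequence length is preserved (as the removed assert used to guarantee):

```python
from typing import List, Tuple


def pad_between_token_pairs(
    input_ids: List[int],
    media_token_pairs: List[Tuple[int, int]],
    pad_values: List[int],
) -> List[int]:
    # Illustrative only; the real helper lives in
    # sglang/srt/managers/multi_modality_padding.py and may differ in detail.
    starts = {s for s, _ in media_token_pairs}
    ends = {e for _, e in media_token_pairs}
    out: List[int] = []
    in_media, media_idx = False, -1
    for tok in input_ids:
        if tok in starts:
            in_media, media_idx = True, media_idx + 1
            out.append(tok)
        elif tok in ends:
            in_media = False
            out.append(tok)
        elif in_media:
            # Replace in-region tokens with this image's pad value.
            out.append(pad_values[media_idx % len(pad_values)])
        else:
            out.append(tok)
    return out  # same length as input_ids
```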
sglang/srt/models/mllama.py
@@ -202,7 +202,7 @@ class MllamaVisionEncoderLayer(nn.Module):
             quant_config=None,
             dropout=0.0,
             use_context_forward=False,
-            use_full_precision_softmax=False,
+            softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
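
The use_full_precision_softmax → softmax_in_single_precision rename recurs at every VisionAttention call site in this diff. Assuming the flag's behavior is unchanged by the rename, it selects whether the attention softmax is computed in fp32 before casting back to the working dtype; a hedged sketch of what such a flag typically toggles:

```python
import torch

scores = torch.randn(2, 8, 64, 64, dtype=torch.bfloat16)  # raw attention logits
softmax_in_single_precision = True

if softmax_in_single_precision:
    # Upcast for a numerically stable softmax, then return to bf16.
    probs = torch.softmax(scores.float(), dim=-1).to(scores.dtype)
else:
    probs = torch.softmax(scores, dim=-1)
```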
sglang/srt/models/qwen2.py
@@ -15,7 +15,6 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from readline import add_history
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
sglang/srt/models/qwen2_5_vl.py
@@ -47,6 +47,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.multi_modality_padding import (
+    MultiModalityDataPaddingPatternTokenPairs,
+)
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -121,12 +124,12 @@ class Qwen2_5_VisionBlock(nn.Module):
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
         if attn_implementation == "sdpa":
             use_context_forward = False
-            use_full_precision_softmax = False
+            softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
-            use_full_precision_softmax = False
+            softmax_in_single_precision = False
             use_context_forward = True
         elif attn_implementation == "eager":
-            use_full_precision_softmax = True
+            softmax_in_single_precision = True
             use_context_forward = False
 
         self.attn = VisionAttention(
@@ -135,7 +138,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             projection_size=dim,
             use_qkv_parallel=False,
             use_context_forward=use_context_forward,
-            use_full_precision_softmax=use_full_precision_softmax,
+            softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
@@ -149,12 +152,17 @@ class Qwen2_5_VisionBlock(nn.Module):
         )
 
     def forward(
-        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
         hidden_states = self.norm1(x)
         hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
         attn = self.attn(
-            hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            hidden_states,
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
         )
         attn = rearrange(attn, "b s ... -> s b ...")
         x = x + attn
@@ -443,6 +451,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())
 
         # compute cu_seqlens
         cu_seqlens = torch.repeat_interleave(
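
Both vision transformers now precompute a single (cos, sin) pair once per forward pass and hand it to every block, instead of passing each block the raw rotary table. A minimal sketch of how such precomputed embeddings are typically consumed via the rotate-half formulation; the application code inside VisionAttention is not part of this diff, so treat the last three lines as an assumption:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard rotary helper: (x1, x2) -> (-x2, x1) on the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


seq_len, half_dim = 16, 40
rotary_pos_emb = torch.randn(seq_len, half_dim)

# Mirrors the added lines: duplicate the table, then take cos/sin once.
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())

# Each block can then rotate q/k without recomputing any trigonometry.
q = torch.randn(seq_len, 2 * half_dim)
cos, sin = position_embeddings
q_rotated = q * cos + rotate_half(q) * sin
```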
@@ -457,7 +467,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
                 cu_seqlens_now = cu_seqlens
             else:
                 cu_seqlens_now = cu_window_seqlens
-            x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)
+            x = blk(
+                x, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings
+            )
 
         # adapter
         x = self.merger(x)
@@ -522,50 +534,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         return num_image_tokens
 
     def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
-        new_input_ids = []
-        last_idx = 0
-        image_idx = -1
-        image_inputs.image_offsets = []
-
         # Get all special token IDs
-        im_start_id = image_inputs.im_start_id
-        im_end_id = image_inputs.im_end_id
-
-        # Find all start and end positions for both types
-        start_indices = [i for i, x in enumerate(input_ids) if x == im_start_id]
-        end_indices = [i for i, x in enumerate(input_ids) if x == im_end_id]
-
-        if len(start_indices) != len(end_indices):
-            return input_ids
-        # Process each region (both image and slice)
-        for start_idx, end_idx in zip(start_indices, end_indices):
-            # Add non-image tokens before this region
-            new_input_ids.extend(input_ids[last_idx : start_idx + 1])
-
-            is_image_start = input_ids[start_idx] == im_start_id
-
-            if is_image_start:
-                image_inputs.image_offsets += [start_idx]
-                image_idx += 1
-
-            num_tokens = end_idx - start_idx - 1  # exclude start and end tokens
-
-            # Generate pad_ids
-            pad_values = [image_inputs.pad_values[image_idx]]
-
-            pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
-            pad_ids = pad_ids[:num_tokens]
-
-            # Add pad_ids
-            new_input_ids.extend(pad_ids)
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id
 
-        # Update last_idx to after end token
-        last_idx = end_idx
+        media_token_pairs = [(im_start_id, im_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
 
-        # Add remaining tokens after last region
-        new_input_ids.extend(input_ids[last_idx:])
-        assert len(input_ids) == len(new_input_ids)
-        return new_input_ids
+        return pattern.pad_input_tokens(input_ids, image_inputs)
 
     def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
@@ -629,7 +605,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
         prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
         for i, image in enumerate(forward_batch.image_inputs):
-            if image is None:
+            if image is None or image.pixel_values is None:
                 continue
             start_idx = extend_start_loc_cpu[i]
             prefix_len = prefix_lens_cpu[i]
@@ -678,7 +654,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             )
             image_embeds_offset += num_image_tokens
 
-        input_ids = None
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
sglang/srt/models/qwen2_vl.py
@@ -42,6 +42,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.multi_modality_padding import (
+    MultiModalityDataPaddingPatternTokenPairs,
+)
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -137,12 +140,12 @@ class Qwen2VisionBlock(nn.Module):
         mlp_hidden_dim = int(dim * mlp_ratio)
         if attn_implementation == "sdpa":
             use_context_forward = False
-            use_full_precision_softmax = False
+            softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
-            use_full_precision_softmax = False
+            softmax_in_single_precision = False
             use_context_forward = True
         elif attn_implementation == "eager":
-            use_full_precision_softmax = True
+            softmax_in_single_precision = True
             use_context_forward = False
 
         self.attn = VisionAttention(
@@ -151,7 +154,7 @@ class Qwen2VisionBlock(nn.Module):
             projection_size=dim,
             use_qkv_parallel=False,
             use_context_forward=use_context_forward,
-            use_full_precision_softmax=use_full_precision_softmax,
+            softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
@@ -165,12 +168,17 @@ class Qwen2VisionBlock(nn.Module):
         )
 
     def forward(
-        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
         hidden_states = self.norm1(x)
         hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
         attn = self.attn(
-            hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            hidden_states,
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
        )
         attn = rearrange(attn, "b s ... -> s b ...")
         x = x + attn
@@ -392,7 +400,8 @@ class Qwen2VisionTransformer(nn.Module):
 
         # compute position embedding
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
-
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())
         # compute cu_seqlens
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -402,7 +411,7 @@ class Qwen2VisionTransformer(nn.Module):
         # transformers
         x = x.unsqueeze(1)
         for blk in self.blocks:
-            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
 
         # adapter
         x = self.merger(x)
@@ -425,40 +434,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )
         return num_image_tokens
 
-    # Use grid_t * grid_w * grid_h to pad tokens for each image
-    # add replaced padding by unique image hash
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
-        image_grid_thws = image_inputs.image_grid_thws
-        pad_values = image_inputs.pad_values
-
-        image_indices = [
-            idx
-            for idx, token in enumerate(input_ids)
-            if token == self.config.image_token_id
-        ]
-        image_inputs.image_offsets = []
-
-        input_ids_with_image = []
-        for image_cnt, _ in enumerate(image_grid_thws):
-            num_image_tokens = self.calculate_num_image_tokens(
-                image_grid_thws[image_cnt]
-            )
-            if image_cnt == 0:
-                non_image_tokens = input_ids[: image_indices[image_cnt]]
-            else:
-                non_image_tokens = input_ids[
-                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
-                ]
-            input_ids_with_image.extend(non_image_tokens)
-            image_inputs.image_offsets.append(len(input_ids_with_image))
-            pad_ids = pad_values * (
-                (num_image_tokens + len(pad_values)) // len(pad_values)
-            )
-            input_ids_with_image.extend(pad_ids[:num_image_tokens])
-        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
-
-        return input_ids_with_image
-
     def __init__(
         self,
         config: Qwen2VLConfig,
@@ -494,6 +469,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+    # Use grid_t * grid_w * grid_h to pad tokens for each image
+    # add replaced padding by unique image hash
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        # Get all special token IDs
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id
+
+        media_token_pairs = [(im_start_id, im_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        return pattern.pad_input_tokens(input_ids, image_inputs)
+
     def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
         image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
@@ -556,12 +542,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
         prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
         for i, image in enumerate(forward_batch.image_inputs):
-            if image is None:
+            if image is None or image.pixel_values is None:
                 continue
             start_idx = extend_start_loc_cpu[i]
             prefix_len = prefix_lens_cpu[i]
+            pixel_values = image.pixel_values.clone()
 
-            pixel_values = torch.tensor(image.pixel_values, device="cuda")
             image_grid_thws = torch.tensor(
                 np.array(image.image_grid_thws), device="cuda"
             )
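
The pixel_values change above suggests that the new per-model image processors (see the image_processors/ files in the list) now hand pixel_values over as a torch.Tensor rather than a nested list, so rebuilding a tensor each forward pass is no longer needed. A hedged illustration of the difference; device placement is presumably handled elsewhere:

```python
import torch

# Before: pixel values arrived as nested lists and were re-materialized as a
# new tensor every forward pass (device="cuda" in the removed line).
legacy_pixel_values = [[0.1, 0.2], [0.3, 0.4]]
old = torch.tensor(legacy_pixel_values)

# After (assumption): the image processor already produced a torch.Tensor,
# so clone() is enough to get a safe, independent copy.
processor_output = torch.randn(2, 2)
new = processor_output.clone()
```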
@@ -579,15 +565,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                     image_grid_thws[idx]
                 )
 
-                left_idx = start_idx + (image_offset - prefix_len)
-                right_idx = (
-                    start_idx + (image_offset - prefix_len) + num_image_tokens
-                )
-
+                left_idx = start_idx + (image_offset - prefix_len + 1)
+                right_idx = left_idx + num_image_tokens
                 inputs_embeds[left_idx:right_idx] = image_embeds[
                     image_embeds_offset : image_embeds_offset + num_image_tokens
                 ]
                 image_embeds_offset += num_image_tokens
+            input_ids = None
 
         hidden_states = self.model(
             input_ids=input_ids,
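
The +1 in the new left_idx pairs with the pad_input_ids rewrite above: the token-pair pattern appears to record each image offset at the image start token itself, so the first embedding slot sits one position later. A worked index example under that assumption, with hypothetical numbers:

```python
# An image whose start token sits at offset 5 in the request,
# with no cached prefix and 3 image tokens.
start_idx, image_offset, prefix_len, num_image_tokens = 0, 5, 0, 3

left_idx = start_idx + (image_offset - prefix_len + 1)  # 6: slot after <im_start>
right_idx = left_idx + num_image_tokens                 # 9: one past the last image slot
assert (left_idx, right_idx) == (6, 9)
```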