sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/pixtral.py (new file)
@@ -0,0 +1,467 @@
+# Copyright 2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Using mistral-community/pixtral-12b as reference.
+"""
+
+import logging
+import math
+from typing import Iterable, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PixtralVisionConfig, PretrainedConfig
+from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding
+from transformers.models.pixtral.modeling_pixtral import (
+    generate_block_attention_mask as _get_pixtral_attention_mask,
+)
+from transformers.models.pixtral.modeling_pixtral import position_ids_in_meshgrid
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+
+
+class PixtralHFMLP(nn.Module):
+    """MLP for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert config.intermediate_size is not None
+
+        # Use MergedColumnParallelLinear for gate_up_proj to handle combined weights
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_sizes=[config.intermediate_size, config.intermediate_size],
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+
+        self.down_proj = RowParallelLinear(
+            input_size=config.intermediate_size,
+            output_size=config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_up_output, _ = self.gate_up_proj(x)
+
+        # Apply SiLU activation and multiply
+        gate_up = self.act_fn(gate_up_output)
+
+        # Project back to hidden size
+        out, _ = self.down_proj(gate_up)
+        return out
+
+
+class PixtralHFTransformerBlock(nn.Module):
+    """Transformer block for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+        # Use SGLang's VisionAttention instead of vLLM's PixtralHFAttention
+        self.attention = VisionAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            projection_size=config.hidden_size,
+            use_qkv_parallel=True,
+            quant_config=quant_config,
+            dropout=0.0,
+            use_context_forward=False,
+            softmax_in_single_precision=False,
+            flatten_batch=False,
+            prefix=f"{prefix}.attention",
+        )
+
+        self.feed_forward = PixtralHFMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
+
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> torch.Tensor:
+        # Ensure hidden_states has the batch dimension [batch, seq_len, hidden_dim]
+        batch_size, seq_len, hidden_dim = hidden_states.shape
+
+        # Apply attention norm - normalize along the last dimension
+        attn_normalized = self.attention_norm(hidden_states.view(-1, hidden_dim)).view(
+            batch_size, seq_len, hidden_dim
+        )
+
+        # Pass through attention layer
+        attention_output = self.attention(
+            attn_normalized,
+            attention_mask=attention_mask,
+            cu_seqlens=None,
+            position_embeddings=position_embeddings,
+        )
+
+        # Apply first residual connection
+        hidden_states = hidden_states + attention_output
+
+        # Apply feed-forward norm - normalize along the last dimension
+        ffn_normalized = self.ffn_norm(hidden_states.view(-1, hidden_dim)).view(
+            batch_size, seq_len, hidden_dim
+        )
+
+        # Pass through feed-forward layer
+        # First reshape to 2D for the feed-forward network, then reshape back
+        ffn_output = self.feed_forward(ffn_normalized)
+
+        # Apply second residual connection
+        output = hidden_states + ffn_output
+
+        return output
+
+
+class PixtralHFTransformer(nn.Module):
+    """Transformer for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        num_hidden_layers = config.num_hidden_layers
+        if num_hidden_layers_override is not None:
+            num_hidden_layers = num_hidden_layers_override
+
+        self.layers = nn.ModuleList(
+            [
+                PixtralHFTransformerBlock(
+                    config=config,
+                    layer_id=layer_idx,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        return_all_hidden_states: bool = False,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Forward pass through transformer layers.
+
+        Args:
+            x: Input tensor
+            attention_mask: Optional attention mask
+            position_embeddings: Optional position embeddings for rotary attention
+            return_all_hidden_states: Whether to return all hidden states
+
+        Returns:
+            Either the final hidden state, or a list of all hidden states if
+            return_all_hidden_states is True
+        """
+        # For HF model compatibility, always start with the input
+        hidden_states = x
+        all_hidden_states = [hidden_states] if return_all_hidden_states else None
+
+        for i, layer in enumerate(self.layers):
+            hidden_states = layer(hidden_states, attention_mask, position_embeddings)
+            if return_all_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+        if return_all_hidden_states:
+            return all_hidden_states
+        return hidden_states
+
+
+def resolve_visual_encoder_outputs(
+    outputs: Union[torch.Tensor, List[torch.Tensor]],
+    feature_sample_layers: Optional[List[int]],
+    post_norm: Optional[nn.Module],
+    num_hidden_layers: int,
+) -> torch.Tensor:
+    """Resolve outputs from visual encoder based on feature_sample_layers."""
+    if feature_sample_layers is None:
+        # Just use the last layer's output
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+        if post_norm is not None:
+            outputs = post_norm(outputs)
+        return outputs
+
+    # Handle the case where we want to use specific layers
+    if not isinstance(outputs, list):
+        raise ValueError(
+            "Expected outputs to be a list when feature_sample_layers is provided"
+        )
+
+    # Validate layer indices
+    for layer_idx in feature_sample_layers:
+        if layer_idx < 0 or layer_idx > num_hidden_layers:
+            raise ValueError(
+                f"Feature sample layer index {layer_idx} is out of range "
+                f"[0, {num_hidden_layers}]"
+            )
+
+    # Collect outputs from specified layers
+    selected_outputs = [outputs[layer_idx] for layer_idx in feature_sample_layers]
+
+    # Combine the outputs
+    combined_outputs = torch.cat(selected_outputs, dim=-1)
+
+    if post_norm is not None:
+        combined_outputs = post_norm(combined_outputs)
+
+    return combined_outputs
+
+
+class PixtralHFVisionModel(nn.Module):
+    """Hugging Face Pixtral Vision Model implemented using SGLang components."""
+
+    DEFAULT_IMAGE_TOKEN_ID = 10
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        return self.input_padder.pad_input_tokens(input_ids, image_inputs)
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_conv = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=config.hidden_size,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
+            bias=False,
+        )
+
+        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
+
+        self.transformer = PixtralHFTransformer(
+            config,
+            quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.transformer",
+        )
+
+        # Check that num_hidden_layers is valid
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.transformer.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.transformer.layers)} "
+                "layers."
+            )
+
+        # Initialize patch position embedding
+        self.image_token_id = image_token_id
+        self.patch_positional_embedding = PixtralRotaryEmbedding(config)
+        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
+            [self.image_token_id]
+        )
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        image_sizes: list[tuple[int, int]],
+        output_hidden_states: bool = False,
+        feature_sample_layers: Optional[list[int]] = None,
+    ) -> Union[torch.Tensor, tuple]:
+        """
+        Args:
+            pixel_values: [batch_size, C, H, W], padded if multiple images
+            image_sizes: list of (H, W) for each image in the batch
+            output_hidden_states: Whether to return all hidden states.
+            feature_sample_layers: Layer indices whose features should be
+                concatenated and used as the visual encoder output. If none
+                are provided, the last layer is used.
+
+        Returns:
+            A tuple containing:
+            - hidden_states: Final model outputs (or selected layers if feature_sample_layers given)
+            - hidden_states tuple (optional): All hidden states if output_hidden_states=True
+        """
+        # batch patch images
+        embeds_orig = self.patch_conv(
+            pixel_values.to(device=self.device, dtype=self.dtype)
+        )
+        # crop the embeddings
+        embeds_2d = [
+            embed[..., : h // self.patch_size, : w // self.patch_size]
+            for embed, (h, w) in zip(embeds_orig, image_sizes)
+        ]
+
+        # flatten to sequence
+        embeds_1d = torch.cat([p.flatten(1).T for p in embeds_2d], dim=0)
+        embeds_featurized = self.ln_pre(embeds_1d).unsqueeze(0)
+
+        # positional embeddings
+        position_ids = position_ids_in_meshgrid(
+            embeds_2d,
+            max_width=self.image_size // self.patch_size,
+        ).to(self.device)
+
+        # The original PixtralRotaryEmbedding expects 2D input but returns a tuple of tensors (cos, sin)
+        # These tensors are used by apply_rotary_pos_emb in the transformer blocks
+        position_embedding = self.patch_positional_embedding(
+            embeds_featurized, position_ids
+        )
+        attention_mask = _get_pixtral_attention_mask(
+            [p.shape[-2] * p.shape[-1] for p in embeds_2d], embeds_featurized
+        )
+
+        return_all_hidden_states = (
+            output_hidden_states or feature_sample_layers is not None
+        )
+
+        transformer_outputs = self.transformer(
+            embeds_featurized,  # add batch dimension
+            attention_mask,
+            position_embedding,
+            return_all_hidden_states=return_all_hidden_states,
+        )
+
+        # Store all hidden states if requested
+        all_hidden_states = None
+        if isinstance(transformer_outputs, list):
+            all_hidden_states = transformer_outputs
+            # Use the last layer by default if feature_sample_layers is not specified
+            if feature_sample_layers is None:
+                out = transformer_outputs[-1]
+            else:
+                # Resolve outputs based on feature sample layers
+                out = resolve_visual_encoder_outputs(
+                    transformer_outputs,
+                    feature_sample_layers,
+                    None,
+                    self.config.num_hidden_layers,
+                )
+        else:
+            out = transformer_outputs
+
+        # Format return to be compatible with HuggingFace vision models
+        if output_hidden_states:
+            return type(
+                "VisualOutput",
+                (),
+                {
+                    "last_hidden_state": out,
+                    "hidden_states": all_hidden_states,
+                },
+            )
+        else:
+            return out
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
+        """Load weights from a HuggingFace checkpoint with proper parameter mapping."""
+        params_dict = dict(self.named_parameters())
+
+        # for (param, weight, shard_id): load weight into param as param's shard_id part
+        stacked_params_mapping = [
+            (".attention.qkv_proj", ".attention.q_proj", "q"),
+            (".attention.qkv_proj", ".attention.k_proj", "k"),
+            (".attention.qkv_proj", ".attention.v_proj", "v"),
+            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
+            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
+        ]
+
+        # Process each weight
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name:
+                    # Replace the weight name part with the combined parameter name
+                    transformed_name = name.replace(weight_name, param_name)
+                    if transformed_name in params_dict:
+                        param = params_dict[transformed_name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight, shard_id)
+                    break
+            else:
+                if ".attention.o_proj" in name:
+                    alt_name = name.replace(".attention.o_proj", ".attention.proj")
+                    if alt_name in params_dict:
+                        name = alt_name
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+class PixtralVisionModel(PixtralHFVisionModel):
+    pass
+
+
+# Register the model classes for external access
+EntryClass = [PixtralVisionModel]
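The most distinctive part of the new vision tower is the variable-resolution patch bookkeeping in PixtralHFVisionModel.forward: each image keeps its own (H // patch_size) x (W // patch_size) patch grid, the padded region is cropped away, and the per-image grids are flattened and concatenated into one sequence before ln_pre. A minimal stand-alone sketch of that crop-and-flatten step, using plain torch and hypothetical sizes rather than a real PixtralVisionConfig:

import torch
import torch.nn as nn

# Hypothetical config values standing in for PixtralVisionConfig.
hidden_size, patch_size, num_channels = 64, 16, 3
patch_conv = nn.Conv2d(
    num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False
)

# Two images padded to a common 64x64 canvas; true sizes are (64, 48) and (32, 32).
pixel_values = torch.randn(2, num_channels, 64, 64)
image_sizes = [(64, 48), (32, 32)]

embeds_orig = patch_conv(pixel_values)  # [2, hidden_size, 4, 4] patch grid per padded image
embeds_2d = [
    embed[..., : h // patch_size, : w // patch_size]  # crop away patches that came from padding
    for embed, (h, w) in zip(embeds_orig, image_sizes)
]
embeds_1d = torch.cat([p.flatten(1).T for p in embeds_2d], dim=0)
print(embeds_1d.shape)  # torch.Size([16, 64]): 4*3 patches from image 1 + 2*2 from image 2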
sglang/srt/models/qwen2.py
@@ -15,12 +15,14 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, Optional, Tuple
+import logging
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
@@ -36,11 +38,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
@@ -50,6 +53,9 @@ from sglang.srt.utils import add_prefix, make_layers
 Qwen2Config = None
 
 
+logger = logging.getLogger(__name__)
+
+
 class Qwen2MLP(nn.Module):
     def __init__(
         self,
@@ -245,15 +251,21 @@ class Qwen2Model(nn.Module):
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         # Use the provided decoder layer type or default to Qwen2DecoderLayer
         decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: decoder_layer_type(
                layer_id=idx,
@@ -261,9 +273,14 @@ class Qwen2Model(nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
@@ -280,13 +297,20 @@ class Qwen2Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -294,7 +318,15 @@ class Qwen2Model(nn.Module):
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
     # If this function is called, it should always initialize KV cache scale
@@ -348,6 +380,7 @@ class Qwen2ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2Model(
@@ -379,14 +412,33 @@ class Qwen2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            return self.pooler(hidden_states, forward_batch)
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -400,6 +452,17 @@ class Qwen2ForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -426,9 +489,15 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
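The pipeline-parallel changes to Qwen2ForCausalLM.load_weights skip any decoder-layer weight that falls outside this rank's [start_layer, end_layer) slice. A rough stand-alone illustration of that filtering rule, with a hypothetical regex standing in for sglang's get_layer_id helper:

import re
from typing import Optional

def get_layer_id(name: str) -> Optional[int]:
    # Hypothetical stand-in for sglang.srt.layers.utils.get_layer_id:
    # extract the decoder-layer index from names like "model.layers.12.mlp.up_proj.weight".
    match = re.search(r"layers\.(\d+)\.", name)
    return int(match.group(1)) if match else None

start_layer, end_layer = 8, 16  # this pipeline rank owns decoder layers [8, 16)

for name in [
    "model.embed_tokens.weight",
    "model.layers.3.self_attn.qkv_proj.weight",
    "model.layers.12.mlp.gate_up_proj.weight",
]:
    layer_id = get_layer_id(name)
    skip = layer_id is not None and not (start_layer <= layer_id < end_layer)
    print(name, "-> skip" if skip else "-> load")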
sglang/srt/models/qwen2_5_vl.py
@@ -146,6 +146,8 @@ class Qwen2_5_VisionBlock(nn.Module):
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
+            rotary_embed="normal",
+            proj_bias=True,
             qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
@@ -497,6 +499,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        if any(item.precomputed_features is not None for item in items):
+            if not all(item.precomputed_features is not None for item in items):
+                raise NotImplementedError(
+                    "MM inputs where only some items are precomputed."
+                )
+            return torch.concat([item.precomputed_features for item in items])
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
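The precomputed_features branch added to get_image_feature is all-or-nothing: if any item in the batch carries precomputed features, every item must, and the features are simply concatenated instead of running the vision tower. A small sketch of that guard with a hypothetical stand-in for sglang's MultimodalDataItem:

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class Item:  # hypothetical stand-in for MultimodalDataItem
    precomputed_features: Optional[torch.Tensor] = None
    pixel_values: Optional[torch.Tensor] = None

def get_image_feature(items):
    if any(item.precomputed_features is not None for item in items):
        if not all(item.precomputed_features is not None for item in items):
            raise NotImplementedError("MM inputs where only some items are precomputed.")
        return torch.concat([item.precomputed_features for item in items])
    return None  # the real method would run the vision tower on pixel_values here

features = get_image_feature([Item(torch.randn(4, 8)), Item(torch.randn(2, 8))])
print(features.shape)  # torch.Size([6, 8])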