sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
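
The version bump itself is the one-line change in sglang/version.py (item 144). A quick way to confirm which build is installed is the small sketch below; it assumes the package re-exports __version__ from sglang/version.py, as current releases do.

    # Print the installed sglang build; __version__ comes from sglang/version.py.
    import sglang

    print(sglang.__version__)  # expected to read "0.4.6.post4" after upgrading
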
sglang/srt/models/phi3_small.py
@@ -6,7 +6,7 @@ from torch import nn
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -17,6 +17,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
@@ -294,13 +295,24 @@ class Phi3SmallModel(nn.Module):
         super().__init__()
 
         self.config = config
+
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
-        self.start_layer, self.end_layer, self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: Phi3SmallDecoderLayer(
                 config,
@@ -308,6 +320,8 @@ class Phi3SmallModel(nn.Module):
                 quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
 
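
The phi3_small.py hunks above add pipeline-parallel plumbing: only the first pipeline rank materializes the embedding (other ranks get a PPMissingLayer placeholder), and make_layers is now called with pp_rank/pp_size and unpacked as (layers, start_layer, end_layer). The partitioning pattern can be sketched roughly as below; this is an illustrative approximation, not SGLang's actual make_layers or PPMissingLayer (the diff only shows that PPMissingLayer is imported from sglang.srt.layers.utils).

    # Hypothetical sketch of the pipeline-parallel layer partitioning used above;
    # SGLang's real make_layers / PPMissingLayer differ in detail.
    import torch.nn as nn


    class PPMissingLayer(nn.Identity):
        """Placeholder for a layer that lives on another pipeline rank (sketch)."""


    def make_layers(num_hidden_layers, layer_fn, pp_rank, pp_size, prefix=""):
        """Instantiate real layers only for this rank's slice; keep placeholders
        elsewhere so layer indices stay aligned across ranks."""
        per_rank = (num_hidden_layers + pp_size - 1) // pp_size
        start_layer = pp_rank * per_rank
        end_layer = min(start_layer + per_rank, num_hidden_layers)
        layers = nn.ModuleList(
            [PPMissingLayer() for _ in range(start_layer)]
            + [layer_fn(idx, f"{prefix}.{idx}") for idx in range(start_layer, end_layer)]
            + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]
        )
        # Matches the new unpacking order in the diff: layers, start_layer, end_layer
        return layers, start_layer, end_layer
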
sglang/srt/models/pixtral.py (new file)
@@ -0,0 +1,467 @@
+# Copyright 2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Using mistral-community/pixtral-12b as reference.
+"""
+
+import logging
+import math
+from typing import Iterable, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PixtralVisionConfig, PretrainedConfig
+from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding
+from transformers.models.pixtral.modeling_pixtral import (
+    generate_block_attention_mask as _get_pixtral_attention_mask,
+)
+from transformers.models.pixtral.modeling_pixtral import position_ids_in_meshgrid
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+
+
+class PixtralHFMLP(nn.Module):
+    """MLP for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert config.intermediate_size is not None
+
+        # Use MergedColumnParallelLinear for gate_up_proj to handle combined weights
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_sizes=[config.intermediate_size, config.intermediate_size],
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+
+        self.down_proj = RowParallelLinear(
+            input_size=config.intermediate_size,
+            output_size=config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_up_output, _ = self.gate_up_proj(x)
+
+        # Apply SiLU activation and multiply
+        gate_up = self.act_fn(gate_up_output)
+
+        # Project back to hidden size
+        out, _ = self.down_proj(gate_up)
+        return out
+
+
+class PixtralHFTransformerBlock(nn.Module):
+    """Transformer block for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+        # Use SGLang's VisionAttention instead of vLLM's PixtralHFAttention
+        self.attention = VisionAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            projection_size=config.hidden_size,
+            use_qkv_parallel=True,
+            quant_config=quant_config,
+            dropout=0.0,
+            use_context_forward=False,
+            softmax_in_single_precision=False,
+            flatten_batch=False,
+            prefix=f"{prefix}.attention",
+        )
+
+        self.feed_forward = PixtralHFMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
+
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> torch.Tensor:
+        # Ensure hidden_states has the batch dimension [batch, seq_len, hidden_dim]
+        batch_size, seq_len, hidden_dim = hidden_states.shape
+
+        # Apply attention norm - normalize along the last dimension
+        attn_normalized = self.attention_norm(hidden_states.view(-1, hidden_dim)).view(
+            batch_size, seq_len, hidden_dim
+        )
+
+        # Pass through attention layer
+        attention_output = self.attention(
+            attn_normalized,
+            attention_mask=attention_mask,
+            cu_seqlens=None,
+            position_embeddings=position_embeddings,
+        )
+
+        # Apply first residual connection
+        hidden_states = hidden_states + attention_output
+
+        # Apply feed-forward norm - normalize along the last dimension
+        ffn_normalized = self.ffn_norm(hidden_states.view(-1, hidden_dim)).view(
+            batch_size, seq_len, hidden_dim
+        )
+
+        # Pass through feed-forward layer
+        # First reshape to 2D for the feed-forward network, then reshape back
+        ffn_output = self.feed_forward(ffn_normalized)
+
+        # Apply second residual connection
+        output = hidden_states + ffn_output
+
+        return output
+
+
+class PixtralHFTransformer(nn.Module):
+    """Transformer for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        num_hidden_layers = config.num_hidden_layers
+        if num_hidden_layers_override is not None:
+            num_hidden_layers = num_hidden_layers_override
+
+        self.layers = nn.ModuleList(
+            [
+                PixtralHFTransformerBlock(
+                    config=config,
+                    layer_id=layer_idx,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        return_all_hidden_states: bool = False,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Forward pass through transformer layers.
+
+        Args:
+            x: Input tensor
+            attention_mask: Optional attention mask
+            position_embeddings: Optional position embeddings for rotary attention
+            return_all_hidden_states: Whether to return all hidden states
+
+        Returns:
+            Either the final hidden state, or a list of all hidden states if
+            return_all_hidden_states is True
+        """
+        # For HF model compatibility, always start with the input
+        hidden_states = x
+        all_hidden_states = [hidden_states] if return_all_hidden_states else None
+
+        for i, layer in enumerate(self.layers):
+            hidden_states = layer(hidden_states, attention_mask, position_embeddings)
+            if return_all_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+        if return_all_hidden_states:
+            return all_hidden_states
+        return hidden_states
+
+
+def resolve_visual_encoder_outputs(
+    outputs: Union[torch.Tensor, List[torch.Tensor]],
+    feature_sample_layers: Optional[List[int]],
+    post_norm: Optional[nn.Module],
+    num_hidden_layers: int,
+) -> torch.Tensor:
+    """Resolve outputs from visual encoder based on feature_sample_layers."""
+    if feature_sample_layers is None:
+        # Just use the last layer's output
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+        if post_norm is not None:
+            outputs = post_norm(outputs)
+        return outputs
+
+    # Handle the case where we want to use specific layers
+    if not isinstance(outputs, list):
+        raise ValueError(
+            "Expected outputs to be a list when feature_sample_layers is provided"
+        )
+
+    # Validate layer indices
+    for layer_idx in feature_sample_layers:
+        if layer_idx < 0 or layer_idx > num_hidden_layers:
+            raise ValueError(
+                f"Feature sample layer index {layer_idx} is out of range "
+                f"[0, {num_hidden_layers}]"
+            )
+
+    # Collect outputs from specified layers
+    selected_outputs = [outputs[layer_idx] for layer_idx in feature_sample_layers]
+
+    # Combine the outputs
+    combined_outputs = torch.cat(selected_outputs, dim=-1)
+
+    if post_norm is not None:
+        combined_outputs = post_norm(combined_outputs)
+
+    return combined_outputs
+
+
+class PixtralHFVisionModel(nn.Module):
+    """Hugging Face Pixtral Vision Model implemented using SGLang components."""
+
+    DEFAULT_IMAGE_TOKEN_ID = 10
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        return self.input_padder.pad_input_tokens(input_ids, image_inputs)
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_conv = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=config.hidden_size,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
+            bias=False,
+        )
+
+        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
+
+        self.transformer = PixtralHFTransformer(
+            config,
+            quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.transformer",
+        )
+
+        # Check that num_hidden_layers is valid
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.transformer.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.transformer.layers)} "
+                "layers."
+            )
+
+        # Initialize patch position embedding
+        self.image_token_id = image_token_id
+        self.patch_positional_embedding = PixtralRotaryEmbedding(config)
+        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
+            [self.image_token_id]
+        )
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        image_sizes: list[tuple[int, int]],
+        output_hidden_states: bool = False,
+        feature_sample_layers: Optional[list[int]] = None,
+    ) -> Union[torch.Tensor, tuple]:
+        """
+        Args:
+            pixel_values: [batch_size, C, H, W], padded if multiple images
+            image_sizes: list of (H, W) for each image in the batch
+            output_hidden_states: Whether to return all hidden states.
+            feature_sample_layers: Layer indices whose features should be
+                concatenated and used as the visual encoder output. If none
+                are provided, the last layer is used.
+
+        Returns:
+            A tuple containing:
+            - hidden_states: Final model outputs (or selected layers if feature_sample_layers given)
+            - hidden_states tuple (optional): All hidden states if output_hidden_states=True
+        """
+        # batch patch images
+        embeds_orig = self.patch_conv(
+            pixel_values.to(device=self.device, dtype=self.dtype)
+        )
+        # crop the embeddings
+        embeds_2d = [
+            embed[..., : h // self.patch_size, : w // self.patch_size]
+            for embed, (h, w) in zip(embeds_orig, image_sizes)
+        ]
+
+        # flatten to sequence
+        embeds_1d = torch.cat([p.flatten(1).T for p in embeds_2d], dim=0)
+        embeds_featurized = self.ln_pre(embeds_1d).unsqueeze(0)
+
+        # positional embeddings
+        position_ids = position_ids_in_meshgrid(
+            embeds_2d,
+            max_width=self.image_size // self.patch_size,
+        ).to(self.device)
+
+        # The original PixtralRotaryEmbedding expects 2D input but returns a tuple of tensors (cos, sin)
+        # These tensors are used by apply_rotary_pos_emb in the transformer blocks
+        position_embedding = self.patch_positional_embedding(
+            embeds_featurized, position_ids
+        )
+        attention_mask = _get_pixtral_attention_mask(
+            [p.shape[-2] * p.shape[-1] for p in embeds_2d], embeds_featurized
+        )
+
+        return_all_hidden_states = (
+            output_hidden_states or feature_sample_layers is not None
+        )
+
+        transformer_outputs = self.transformer(
+            embeds_featurized,  # add batch dimension
+            attention_mask,
+            position_embedding,
+            return_all_hidden_states=return_all_hidden_states,
+        )
+
+        # Store all hidden states if requested
+        all_hidden_states = None
+        if isinstance(transformer_outputs, list):
+            all_hidden_states = transformer_outputs
+            # Use the last layer by default if feature_sample_layers is not specified
+            if feature_sample_layers is None:
+                out = transformer_outputs[-1]
+            else:
+                # Resolve outputs based on feature sample layers
+                out = resolve_visual_encoder_outputs(
+                    transformer_outputs,
+                    feature_sample_layers,
+                    None,
+                    self.config.num_hidden_layers,
+                )
+        else:
+            out = transformer_outputs
+
+        # Format return to be compatible with HuggingFace vision models
+        if output_hidden_states:
+            return type(
+                "VisualOutput",
+                (),
+                {
+                    "last_hidden_state": out,
+                    "hidden_states": all_hidden_states,
+                },
+            )
+        else:
+            return out
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
+        """Load weights from a HuggingFace checkpoint with proper parameter mapping."""
+        params_dict = dict(self.named_parameters())
+
+        # for (param, weight, shard_id): load weight into param as param's shard_id part
+        stacked_params_mapping = [
+            (".attention.qkv_proj", ".attention.q_proj", "q"),
+            (".attention.qkv_proj", ".attention.k_proj", "k"),
+            (".attention.qkv_proj", ".attention.v_proj", "v"),
+            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
+            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
+        ]
+
+        # Process each weight
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name:
+                    # Replace the weight name part with the combined parameter name
+                    transformed_name = name.replace(weight_name, param_name)
+                    if transformed_name in params_dict:
+                        param = params_dict[transformed_name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight, shard_id)
+                    break
+            else:
+                if ".attention.o_proj" in name:
+                    alt_name = name.replace(".attention.o_proj", ".attention.proj")
+                    if alt_name in params_dict:
+                        name = alt_name
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+class PixtralVisionModel(PixtralHFVisionModel):
+    pass
+
+
+# Register the model classes for external access
+EntryClass = [PixtralVisionModel]
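
One part of the new pixtral.py worth calling out is load_weights(): separate q_proj/k_proj/v_proj and gate_proj/up_proj tensors from the Hugging Face checkpoint are folded into the fused qkv_proj and gate_up_proj parameters. The renaming step can be exercised on its own; the checkpoint key in the example below is hypothetical.

    # Name remapping used by load_weights() above: each entry is
    # (fused param suffix, checkpoint weight suffix, shard id).
    stacked_params_mapping = [
        (".attention.qkv_proj", ".attention.q_proj", "q"),
        (".attention.qkv_proj", ".attention.k_proj", "k"),
        (".attention.qkv_proj", ".attention.v_proj", "v"),
        (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
        (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
    ]


    def target_param(checkpoint_name: str):
        """Return (fused parameter name, shard id) for a checkpoint tensor name."""
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name in checkpoint_name:
                return checkpoint_name.replace(weight_name, param_name), shard_id
        return checkpoint_name, None


    # Example with a hypothetical checkpoint key:
    print(target_param("vision_tower.transformer.layers.0.attention.k_proj.weight"))
    # -> ('vision_tower.transformer.layers.0.attention.qkv_proj.weight', 'k')
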
sglang/srt/models/qwen2_5_vl.py
@@ -125,16 +125,20 @@ class Qwen2_5_VisionBlock(nn.Module):
         self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
         if attn_implementation == "sdpa":
-            use_context_forward = False
             softmax_in_single_precision = False
+            qkv_backend = "sdpa"
             flatten_batch = True
         elif attn_implementation == "flash_attention_2":
             softmax_in_single_precision = False
-            use_context_forward = True
+            qkv_backend = "triton_attn"
             flatten_batch = True
         elif attn_implementation == "eager":
             softmax_in_single_precision = True
-            use_context_forward = False
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_3":
+            softmax_in_single_precision = False
+            qkv_backend = "fa3"
             flatten_batch = True
 
         self.attn = VisionAttention(
@@ -142,7 +146,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
             quant_config=quant_config,
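
The net effect of the Qwen2.5-VL hunks is that the vision blocks now select a qkv_backend string instead of a use_context_forward flag, with flash_attention_3 added as a new option. Restated below as a lookup table; this is only a sketch of the same mapping (the dict name is invented), while the code keeps the if/elif chain shown above.

    # attn_implementation -> (qkv_backend, softmax_in_single_precision, flatten_batch)
    QWEN2_5_VL_VISION_ATTN = {
        "sdpa": ("sdpa", False, True),
        "flash_attention_2": ("triton_attn", False, True),
        "eager": ("sdpa", True, True),
        "flash_attention_3": ("fa3", False, True),
    }

    qkv_backend, softmax_in_single_precision, flatten_batch = QWEN2_5_VL_VISION_ATTN[
        "flash_attention_2"
    ]
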
sglang/srt/models/qwen2_vl.py
@@ -139,21 +139,21 @@ class Qwen2VisionBlock(nn.Module):
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         if attn_implementation == "sdpa":
-            use_context_forward = False
+            qkv_backend = "sdpa"
             softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
+            qkv_backend = "triton_attn"
             softmax_in_single_precision = False
-            use_context_forward = True
         elif attn_implementation == "eager":
+            qkv_backend = "sdpa"
             softmax_in_single_precision = True
-            use_context_forward = False
 
         self.attn = VisionAttention(
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
sglang/srt/models/roberta.py
@@ -57,7 +57,7 @@ class RobertaEmbedding(nn.Module):
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
 
-        # adpated from vllm: https://github.com/vllm-project/vllm/commit/4a18fd14ba4a349291c798a16bf62fa8a9af0b6b/vllm/model_executor/models/roberta.py
+        # Adapted from vllm: https://github.com/vllm-project/vllm/commit/4a18fd14ba4a349291c798a16bf62fa8a9af0b6b/vllm/model_executor/models/roberta.py
 
         pos_list = []
         token_list = []
sglang/bench_one_batch.py
@@ -37,7 +37,7 @@ $ python3 -m sglang.bench_one_batch --correct \
 --tensor-parallel-size 2 \
 --disable-cuda-graph
 ```
-We will eanble CUDA Graph support soon.
+We will enable CUDA Graph support soon.
 """
 
 import types