sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (107)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +9 -7
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +1 -0
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +48 -43
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +7 -2
  20. sglang/srt/disaggregation/fake/conn.py +1 -1
  21. sglang/srt/disaggregation/mooncake/conn.py +227 -120
  22. sglang/srt/disaggregation/nixl/conn.py +1 -0
  23. sglang/srt/disaggregation/prefill.py +7 -4
  24. sglang/srt/disaggregation/utils.py +7 -1
  25. sglang/srt/entrypoints/engine.py +17 -2
  26. sglang/srt/entrypoints/http_server.py +17 -2
  27. sglang/srt/function_call_parser.py +2 -2
  28. sglang/srt/layers/attention/flashattention_backend.py +1 -1
  29. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  30. sglang/srt/layers/attention/utils.py +4 -2
  31. sglang/srt/layers/dp_attention.py +71 -21
  32. sglang/srt/layers/layernorm.py +1 -1
  33. sglang/srt/layers/logits_processor.py +46 -11
  34. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  35. sglang/srt/layers/moe/ep_moe/layer.py +1 -1
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  37. sglang/srt/layers/moe/topk.py +1 -1
  38. sglang/srt/layers/quantization/__init__.py +1 -1
  39. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  40. sglang/srt/layers/quantization/deep_gemm.py +72 -71
  41. sglang/srt/layers/quantization/fp8.py +2 -2
  42. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  43. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  44. sglang/srt/layers/sampler.py +0 -4
  45. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  46. sglang/srt/lora/lora_manager.py +1 -1
  47. sglang/srt/lora/mem_pool.py +4 -4
  48. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  49. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  50. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  51. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  52. sglang/srt/lora/utils.py +1 -1
  53. sglang/srt/managers/data_parallel_controller.py +3 -3
  54. sglang/srt/managers/detokenizer_manager.py +21 -8
  55. sglang/srt/managers/io_struct.py +3 -1
  56. sglang/srt/managers/mm_utils.py +1 -1
  57. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  58. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  59. sglang/srt/managers/schedule_batch.py +76 -24
  60. sglang/srt/managers/schedule_policy.py +0 -3
  61. sglang/srt/managers/scheduler.py +113 -88
  62. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  63. sglang/srt/managers/tokenizer_manager.py +133 -34
  64. sglang/srt/managers/tp_worker.py +12 -9
  65. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  66. sglang/srt/mem_cache/memory_pool.py +2 -0
  67. sglang/srt/metrics/collector.py +312 -37
  68. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  69. sglang/srt/model_executor/forward_batch_info.py +1 -1
  70. sglang/srt/model_executor/model_runner.py +19 -14
  71. sglang/srt/models/deepseek_janus_pro.py +2 -2
  72. sglang/srt/models/deepseek_v2.py +23 -20
  73. sglang/srt/models/llama.py +2 -0
  74. sglang/srt/models/llama4.py +5 -6
  75. sglang/srt/models/llava.py +248 -5
  76. sglang/srt/models/mixtral.py +98 -34
  77. sglang/srt/models/pixtral.py +467 -0
  78. sglang/srt/models/roberta.py +1 -1
  79. sglang/srt/models/torch_native_llama.py +1 -1
  80. sglang/srt/openai_api/adapter.py +30 -4
  81. sglang/srt/openai_api/protocol.py +0 -8
  82. sglang/srt/reasoning_parser.py +3 -3
  83. sglang/srt/sampling/custom_logit_processor.py +18 -3
  84. sglang/srt/sampling/sampling_batch_info.py +4 -56
  85. sglang/srt/sampling/sampling_params.py +2 -2
  86. sglang/srt/server_args.py +34 -4
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +7 -7
  89. sglang/srt/speculative/eagle_worker.py +22 -19
  90. sglang/srt/utils.py +6 -5
  91. sglang/test/few_shot_gsm8k.py +2 -2
  92. sglang/test/few_shot_gsm8k_engine.py +2 -2
  93. sglang/test/run_eval.py +2 -2
  94. sglang/test/runners.py +8 -1
  95. sglang/test/send_one.py +13 -3
  96. sglang/test/simple_eval_common.py +1 -1
  97. sglang/test/simple_eval_humaneval.py +1 -1
  98. sglang/test/test_programs.py +5 -5
  99. sglang/test/test_utils.py +89 -14
  100. sglang/utils.py +1 -1
  101. sglang/version.py +1 -1
  102. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
  103. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
  104. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  105. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
  106. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral.py
@@ -16,13 +16,15 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
 
-from typing import Iterable, Optional, Tuple
+import logging
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import MixtralConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -38,14 +40,17 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)
 
 
 class MixtralMoE(nn.Module):
@@ -257,24 +262,32 @@ class MixtralModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = nn.ModuleList(
-            [
-                MixtralDecoderLayer(
-                    config,
-                    i,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{i}", prefix),
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: MixtralDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="layers",
+            return_tuple=True,
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def forward(
         self,
@@ -282,18 +295,35 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
         return hidden_states
 
 
@@ -306,6 +336,7 @@ class MixtralForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(
@@ -322,12 +353,31 @@ class MixtralForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -348,6 +398,17 @@ class MixtralForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue
 
@@ -398,11 +459,14 @@ class MixtralForCausalLM(nn.Module):
                     if name is None:
                         continue
 
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = MixtralForCausalLM
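
The mixtral.py changes above wire Mixtral into SGLang's pipeline parallelism: each PP rank builds only its slice of decoder layers, non-first ranks receive activations through `PPProxyTensors`, and only the last rank runs the final norm and logits processor; `load_weights` gains a matching guard that skips layers outside the local `[start_layer, end_layer)` range. The following is a minimal, self-contained sketch of that control flow using plain tensors; `ToyProxyTensors`, `partition_layers`, and `toy_layer` are illustrative stand-ins, not SGLang APIs.

```python
from typing import Optional

import torch


class ToyProxyTensors(dict):
    """Illustrative stand-in for PPProxyTensors: a named bundle of tensors handed between ranks."""


def partition_layers(num_layers: int, pp_rank: int, pp_size: int):
    # Mirrors the idea behind make_layers(..., pp_rank=..., pp_size=...):
    # each rank owns a contiguous [start, end) slice of the decoder layers.
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    return start, end


def forward_stage(hidden, layers, start, end, is_last_rank):
    residual: Optional[torch.Tensor] = None
    for i in range(start, end):
        hidden, residual = layers[i](hidden, residual)
    if not is_last_rank:
        # Hand activations to the next pipeline stage instead of computing logits,
        # mirroring the PPProxyTensors return path in MixtralModel.forward.
        return ToyProxyTensors(hidden_states=hidden, residual=residual)
    # Last rank: fold in the residual (the real model applies RMSNorm here).
    return hidden + (residual if residual is not None else 0)


if __name__ == "__main__":
    def toy_layer(h, res):
        # A fake decoder layer: returns updated hidden states and a residual.
        return h * 1.01, h if res is None else res

    layers = [toy_layer] * 8
    start, end = partition_layers(num_layers=8, pp_rank=0, pp_size=2)
    out = forward_stage(torch.ones(2, 4), layers, start, end, is_last_rank=False)
    print(type(out).__name__, out["hidden_states"].shape)  # ToyProxyTensors torch.Size([2, 4])
```

The same split is what `make_layers(..., return_tuple=True)` and the new `start_layer`/`end_layer` properties expose in the diff, which is why the weight loader can drop parameters belonging to layers the local rank does not own.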
sglang/srt/models/pixtral.py
@@ -0,0 +1,467 @@
+# Copyright 2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Using mistral-community/pixtral-12b as reference.
+"""
+
+import logging
+import math
+from typing import Iterable, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PixtralVisionConfig, PretrainedConfig
+from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding
+from transformers.models.pixtral.modeling_pixtral import (
+    generate_block_attention_mask as _get_pixtral_attention_mask,
+)
+from transformers.models.pixtral.modeling_pixtral import position_ids_in_meshgrid
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+
+
+class PixtralHFMLP(nn.Module):
+    """MLP for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert config.intermediate_size is not None
+
+        # Use MergedColumnParallelLinear for gate_up_proj to handle combined weights
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_sizes=[config.intermediate_size, config.intermediate_size],
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+
+        self.down_proj = RowParallelLinear(
+            input_size=config.intermediate_size,
+            output_size=config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_up_output, _ = self.gate_up_proj(x)
+
+        # Apply SiLU activation and multiply
+        gate_up = self.act_fn(gate_up_output)
+
+        # Project back to hidden size
+        out, _ = self.down_proj(gate_up)
+        return out
+
+
+class PixtralHFTransformerBlock(nn.Module):
+    """Transformer block for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+        # Use SGLang's VisionAttention instead of vLLM's PixtralHFAttention
+        self.attention = VisionAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            projection_size=config.hidden_size,
+            use_qkv_parallel=True,
+            quant_config=quant_config,
+            dropout=0.0,
+            use_context_forward=False,
+            softmax_in_single_precision=False,
+            flatten_batch=False,
+            prefix=f"{prefix}.attention",
+        )
+
+        self.feed_forward = PixtralHFMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
+
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> torch.Tensor:
+        # Ensure hidden_states has the batch dimension [batch, seq_len, hidden_dim]
+        batch_size, seq_len, hidden_dim = hidden_states.shape
+
+        # Apply attention norm - normalize along the last dimension
+        attn_normalized = self.attention_norm(hidden_states.view(-1, hidden_dim)).view(
+            batch_size, seq_len, hidden_dim
+        )
+
+        # Pass through attention layer
+        attention_output = self.attention(
+            attn_normalized,
+            attention_mask=attention_mask,
+            cu_seqlens=None,
+            position_embeddings=position_embeddings,
+        )
+
+        # Apply first residual connection
+        hidden_states = hidden_states + attention_output
+
+        # Apply feed-forward norm - normalize along the last dimension
+        ffn_normalized = self.ffn_norm(hidden_states.view(-1, hidden_dim)).view(
+            batch_size, seq_len, hidden_dim
+        )
+
+        # Pass through feed-forward layer
+        # First reshape to 2D for the feed-forward network, then reshape back
+        ffn_output = self.feed_forward(ffn_normalized)
+
+        # Apply second residual connection
+        output = hidden_states + ffn_output
+
+        return output
+
+
+class PixtralHFTransformer(nn.Module):
+    """Transformer for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        num_hidden_layers = config.num_hidden_layers
+        if num_hidden_layers_override is not None:
+            num_hidden_layers = num_hidden_layers_override
+
+        self.layers = nn.ModuleList(
+            [
+                PixtralHFTransformerBlock(
+                    config=config,
+                    layer_id=layer_idx,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        return_all_hidden_states: bool = False,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Forward pass through transformer layers.
+
+        Args:
+            x: Input tensor
+            attention_mask: Optional attention mask
+            position_embeddings: Optional position embeddings for rotary attention
+            return_all_hidden_states: Whether to return all hidden states
+
+        Returns:
+            Either the final hidden state, or a list of all hidden states if
+            return_all_hidden_states is True
+        """
+        # For HF model compatibility, always start with the input
+        hidden_states = x
+        all_hidden_states = [hidden_states] if return_all_hidden_states else None
+
+        for i, layer in enumerate(self.layers):
+            hidden_states = layer(hidden_states, attention_mask, position_embeddings)
+            if return_all_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+        if return_all_hidden_states:
+            return all_hidden_states
+        return hidden_states
+
+
+def resolve_visual_encoder_outputs(
+    outputs: Union[torch.Tensor, List[torch.Tensor]],
+    feature_sample_layers: Optional[List[int]],
+    post_norm: Optional[nn.Module],
+    num_hidden_layers: int,
+) -> torch.Tensor:
+    """Resolve outputs from visual encoder based on feature_sample_layers."""
+    if feature_sample_layers is None:
+        # Just use the last layer's output
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+        if post_norm is not None:
+            outputs = post_norm(outputs)
+        return outputs
+
+    # Handle the case where we want to use specific layers
+    if not isinstance(outputs, list):
+        raise ValueError(
+            "Expected outputs to be a list when feature_sample_layers is provided"
+        )
+
+    # Validate layer indices
+    for layer_idx in feature_sample_layers:
+        if layer_idx < 0 or layer_idx > num_hidden_layers:
+            raise ValueError(
+                f"Feature sample layer index {layer_idx} is out of range "
+                f"[0, {num_hidden_layers}]"
+            )
+
+    # Collect outputs from specified layers
+    selected_outputs = [outputs[layer_idx] for layer_idx in feature_sample_layers]
+
+    # Combine the outputs
+    combined_outputs = torch.cat(selected_outputs, dim=-1)
+
+    if post_norm is not None:
+        combined_outputs = post_norm(combined_outputs)
+
+    return combined_outputs
+
+
+class PixtralHFVisionModel(nn.Module):
+    """Hugging Face Pixtral Vision Model implemented using SGLang components."""
+
+    DEFAULT_IMAGE_TOKEN_ID = 10
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        return self.input_padder.pad_input_tokens(input_ids, image_inputs)
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_conv = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=config.hidden_size,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
+            bias=False,
+        )
+
+        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
+
+        self.transformer = PixtralHFTransformer(
+            config,
+            quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.transformer",
+        )
+
+        # Check that num_hidden_layers is valid
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.transformer.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.transformer.layers)} "
+                "layers."
+            )
+
+        # Initialize patch position embedding
+        self.image_token_id = image_token_id
+        self.patch_positional_embedding = PixtralRotaryEmbedding(config)
+        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
+            [self.image_token_id]
+        )
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        image_sizes: list[tuple[int, int]],
+        output_hidden_states: bool = False,
+        feature_sample_layers: Optional[list[int]] = None,
+    ) -> Union[torch.Tensor, tuple]:
+        """
+        Args:
+            pixel_values: [batch_size, C, H, W], padded if multiple images
+            image_sizes: list of (H, W) for each image in the batch
+            output_hidden_states: Whether to return all hidden states.
+            feature_sample_layers: Layer indices whose features should be
+                concatenated and used as the visual encoder output. If none
+                are provided, the last layer is used.
+
+        Returns:
+            A tuple containing:
+            - hidden_states: Final model outputs (or selected layers if feature_sample_layers given)
+            - hidden_states tuple (optional): All hidden states if output_hidden_states=True
+        """
+        # batch patch images
+        embeds_orig = self.patch_conv(
+            pixel_values.to(device=self.device, dtype=self.dtype)
+        )
+        # crop the embeddings
+        embeds_2d = [
+            embed[..., : h // self.patch_size, : w // self.patch_size]
+            for embed, (h, w) in zip(embeds_orig, image_sizes)
+        ]
+
+        # flatten to sequence
+        embeds_1d = torch.cat([p.flatten(1).T for p in embeds_2d], dim=0)
+        embeds_featurized = self.ln_pre(embeds_1d).unsqueeze(0)
+
+        # positional embeddings
+        position_ids = position_ids_in_meshgrid(
+            embeds_2d,
+            max_width=self.image_size // self.patch_size,
+        ).to(self.device)
+
+        # The original PixtralRotaryEmbedding expects 2D input but returns a tuple of tensors (cos, sin)
+        # These tensors are used by apply_rotary_pos_emb in the transformer blocks
+        position_embedding = self.patch_positional_embedding(
+            embeds_featurized, position_ids
+        )
+        attention_mask = _get_pixtral_attention_mask(
+            [p.shape[-2] * p.shape[-1] for p in embeds_2d], embeds_featurized
+        )
+
+        return_all_hidden_states = (
+            output_hidden_states or feature_sample_layers is not None
+        )
+
+        transformer_outputs = self.transformer(
+            embeds_featurized,  # add batch dimension
+            attention_mask,
+            position_embedding,
+            return_all_hidden_states=return_all_hidden_states,
+        )
+
+        # Store all hidden states if requested
+        all_hidden_states = None
+        if isinstance(transformer_outputs, list):
+            all_hidden_states = transformer_outputs
+            # Use the last layer by default if feature_sample_layers is not specified
+            if feature_sample_layers is None:
+                out = transformer_outputs[-1]
+            else:
+                # Resolve outputs based on feature sample layers
+                out = resolve_visual_encoder_outputs(
+                    transformer_outputs,
+                    feature_sample_layers,
+                    None,
+                    self.config.num_hidden_layers,
+                )
+        else:
+            out = transformer_outputs
+
+        # Format return to be compatible with HuggingFace vision models
+        if output_hidden_states:
+            return type(
+                "VisualOutput",
+                (),
+                {
+                    "last_hidden_state": out,
+                    "hidden_states": all_hidden_states,
+                },
+            )
+        else:
+            return out
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
+        """Load weights from a HuggingFace checkpoint with proper parameter mapping."""
+        params_dict = dict(self.named_parameters())
+
+        # for (param, weight, shard_id): load weight into param as param's shard_id part
+        stacked_params_mapping = [
+            (".attention.qkv_proj", ".attention.q_proj", "q"),
+            (".attention.qkv_proj", ".attention.k_proj", "k"),
+            (".attention.qkv_proj", ".attention.v_proj", "v"),
+            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
+            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
+        ]
+
+        # Process each weight
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name:
+                    # Replace the weight name part with the combined parameter name
+                    transformed_name = name.replace(weight_name, param_name)
+                    if transformed_name in params_dict:
+                        param = params_dict[transformed_name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight, shard_id)
+                    break
+            else:
+                if ".attention.o_proj" in name:
+                    alt_name = name.replace(".attention.o_proj", ".attention.proj")
+                    if alt_name in params_dict:
+                        name = alt_name
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+class PixtralVisionModel(PixtralHFVisionModel):
+    pass
+
+
+# Register the model classes for external access
+EntryClass = [PixtralVisionModel]
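
The new `PixtralHFVisionModel.forward` patch-embeds a padded image batch, crops each image's patch grid back to its true size, and flattens the grids into one token sequence before the transformer. The snippet below is an illustrative shape check of that bookkeeping only, with a toy `nn.Conv2d` and made-up sizes; it does not load Pixtral weights or call SGLang code.

```python
import torch
import torch.nn as nn

# Toy sizes; the real values come from PixtralVisionConfig.
patch_size, hidden = 16, 8
patch_conv = nn.Conv2d(3, hidden, kernel_size=patch_size, stride=patch_size, bias=False)

# Two images padded into one [batch, C, H, W] tensor; image_sizes keeps their true (H, W).
image_sizes = [(64, 48), (32, 32)]
pixel_values = torch.zeros(2, 3, 64, 64)

embeds_orig = patch_conv(pixel_values)  # [2, hidden, 4, 4] patch grid per padded image
embeds_2d = [
    e[..., : h // patch_size, : w // patch_size]  # crop away the padded patches
    for e, (h, w) in zip(embeds_orig, image_sizes)
]
# Flatten each grid to [num_patches, hidden] and concatenate into one sequence,
# the layout the transformer consumes after ln_pre in the model above.
embeds_1d = torch.cat([e.flatten(1).T for e in embeds_2d], dim=0)
print([tuple(e.shape) for e in embeds_2d], tuple(embeds_1d.shape))
# [(8, 4, 3), (8, 2, 2)] (16, 8)
```

In the model itself, the per-image patch counts computed this way also feed `generate_block_attention_mask`, which keeps patches of different images from attending to each other within the single concatenated sequence.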
sglang/srt/models/roberta.py
@@ -57,7 +57,7 @@ class RobertaEmbedding(nn.Module):
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
 
-        # adpated from vllm: https://github.com/vllm-project/vllm/commit/4a18fd14ba4a349291c798a16bf62fa8a9af0b6b/vllm/model_executor/models/roberta.py
+        # Adapted from vllm: https://github.com/vllm-project/vllm/commit/4a18fd14ba4a349291c798a16bf62fa8a9af0b6b/vllm/model_executor/models/roberta.py
 
         pos_list = []
         token_list = []
sglang/bench_one_batch.py
@@ -37,7 +37,7 @@ $ python3 -m sglang.bench_one_batch --correct \
     --tensor-parallel-size 2 \
     --disable-cuda-graph
 ```
-We will eanble CUDA Graph support soon.
+We will enable CUDA Graph support soon.
 """
 
 import types