sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
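The per-file summary above can be reproduced locally: a wheel is a plain zip archive, so the two releases can be unpacked and compared member by member. A minimal sketch, assuming both wheels have been downloaded into the working directory (the file names are the two artifacts named in the title):

# Sketch: recompute a per-file "+added -removed" summary from the two wheels.
# A .whl file is a zip archive, so its members can be read and diffed directly.
import difflib
import zipfile

OLD_WHEEL = "sglang-0.4.6.post3-py3-none-any.whl"
NEW_WHEEL = "sglang-0.4.6.post5-py3-none-any.whl"


def read_members(path):
    """Return {member name: decoded text lines} for every file in a wheel."""
    with zipfile.ZipFile(path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace").splitlines()
            for name in zf.namelist()
            if not name.endswith("/")
        }


old_files, new_files = read_members(OLD_WHEEL), read_members(NEW_WHEEL)

for name in sorted(old_files.keys() | new_files.keys()):
    old_lines, new_lines = old_files.get(name, []), new_files.get(name, [])
    if old_lines == new_lines:
        continue
    added = removed = 0
    for line in difflib.unified_diff(old_lines, new_lines, lineterm=""):
        if line.startswith("+") and not line.startswith("+++"):
            added += 1
        elif line.startswith("-") and not line.startswith("---"):
            removed += 1
    print(f"{name} +{added} -{removed}")

Note that a pure rename such as sglang/srt/models/xiaomi_mimo.py → mimo.py shows up in this sketch as one removed file plus one added file, whereas the listing above reports it as a +0 -0 move.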
sglang/srt/models/qwen2_moe.py

@@ -16,7 +16,10 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""

-from typing import Any, Dict, Iterable, Optional, Tuple
+import logging
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -24,10 +27,20 @@ from torch import nn
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
+    dp_gather_partial,
+    dp_scatter,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -35,23 +48,28 @@ from sglang.srt.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers

-expert_distribution_recorder = ExpertDistributionRecorder()
+logger = logging.getLogger(__name__)


 class Qwen2MoeMLP(nn.Module):
@@ -82,8 +100,7 @@ class Qwen2MoeMLP(nn.Module):
         )
         if hidden_act != "silu":
             raise ValueError(
-                f"Unsupported activation: {hidden_act}. "
-                "Only silu is supported for now."
+                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
             )
         self.act_fn = SiluAndMul()

@@ -160,7 +177,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         )
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
         final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -182,20 +198,23 @@ class Qwen2MoeAttention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -210,6 +229,8 @@ class Qwen2MoeAttention(nn.Module):
             self.total_num_kv_heads,
             bias=qkv_bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("qkv_proj", prefix),
         )

@@ -218,6 +239,9 @@ class Qwen2MoeAttention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
             prefix=add_prefix("o_proj", prefix),
         )

@@ -252,6 +276,19 @@ class Qwen2MoeAttention(nn.Module):
         return output


+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
+
 class Qwen2MoeDecoderLayer(nn.Module):
     def __init__(
         self,
@@ -279,14 +316,21 @@ class Qwen2MoeDecoderLayer(nn.Module):
             prefix=add_prefix("self_attn", prefix),
         )

-        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
-        # `mlp_only_layers` in the config.
-        mlp_only_layers = (
-            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+        self.layer_id = layer_id
+
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.local_dp_size = get_local_attention_dp_size()
+
+        self.info = self._compute_info(config, layer_id=layer_id)
+        previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
+        self.input_is_scattered = (
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
-        if (layer_id not in mlp_only_layers) and (
-            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
-        ):
+        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+
+        if self.info.is_sparse:
             self.mlp = Qwen2MoeSparseMoeBlock(
                 config=config,
                 quant_config=quant_config,
@@ -305,28 +349,185 @@ class Qwen2MoeDecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )

+    @staticmethod
+    def _enable_moe_dense_fully_dp():
+        return global_server_args_dict["moe_dense_tp_size"] == 1
+
+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int):
+        # WARN: Qwen2MOE has no dense_layer, it is only for compatibility.
+        mlp_only_layers = (
+            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+        )
+        is_sparse = (layer_id not in mlp_only_layers) and (
+            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            or (Qwen2MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        # Self Attention
-        if residual is None:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
+                positions, hidden_states, forward_batch, residual
+            )
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
+                positions, hidden_states, forward_batch, residual
+            )
+        else:
+            raise NotImplementedError
+
+    def forward_ffn_with_full_input(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+            # Self Attention
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.local_dp_size != 1:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                # TODO extract this bugfix
+                if hidden_states.shape[0] != 0:
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                # TODO extract this bugfix
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+        elif hidden_states.shape[0] != 0:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )

         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+
+        # TODO: use reduce-scatter in MLP to avoid this scatter
+        # Scatter
+        if self.local_dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
+        return hidden_states, residual
+
+    def forward_ffn_with_scattered_input(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            attn_tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
+        # Self Attention
+        if hidden_states.shape[0] != 0:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        if self.attn_tp_size != 1:
+            if self.input_is_scattered:
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                attn_tp_reduce_scatter(hidden_states, tensor_list)
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+            else:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                attn_tp_reduce_scatter(hidden_states, tensor_list)
+                residual = hidden_states
+                if hidden_states.shape[0] != 0:
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+
+        if not (
+            self._enable_moe_dense_fully_dp()
+            and (not self.info.is_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+        if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states += residual
+            residual = None
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            attn_tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
         return hidden_states, residual

@@ -341,15 +542,21 @@ class Qwen2MoeModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()

-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
         # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
         decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
@@ -357,9 +564,14 @@ class Qwen2MoeModel(nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)

     def forward(
         self,
@@ -367,24 +579,42 @@ class Qwen2MoeModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
-            expert_distribution_recorder.set_current_layer(i)
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        else:
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states


 class Qwen2MoeForCausalLM(nn.Module):
-
     fall_back_to_pt_during_load = False

     def __init__(
@@ -394,6 +624,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2MoeModel(
@@ -414,11 +645,29 @@ class Qwen2MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -441,6 +690,16 @@ class Qwen2MoeForCausalLM(nn.Module):

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -489,11 +748,22 @@ class Qwen2MoeForCausalLM(nn.Module):
                     if name not in params_dict:
                         continue

-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None,
+        )


 EntryClass = Qwen2MoeForCausalLM
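The Qwen2MoeModel changes above wire the model into pipeline parallelism: the first rank embeds tokens, intermediate ranks receive and return a PPProxyTensors bundle carrying hidden_states and residual, and only the last rank applies the final norm and the logits processor. A stripped-down sketch of that hand-off, using stand-in classes rather than sglang's actual runner:

# Illustrative only: a stand-in for the per-rank contract shown in the diff above.
# ProxyBundle plays the role of sglang's PPProxyTensors; the runner function is hypothetical.
from typing import Dict, Optional

import torch


class ProxyBundle:
    """Carries the two tensors every non-last pipeline rank must hand forward."""

    def __init__(self, tensors: Dict[str, torch.Tensor]):
        self.tensors = tensors

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]


def run_pipeline_stage(
    rank: int,
    world_size: int,
    embed,          # embedding module, used on the first rank only
    layers,         # this rank's slice of decoder layers (start_layer..end_layer)
    norm,           # final norm, used on the last rank only
    input_ids: torch.Tensor,
    incoming: Optional[ProxyBundle],
):
    if rank == 0:
        hidden_states, residual = embed(input_ids), None
    else:
        # Matches the `assert pp_proxy_tensors is not None` branch above.
        hidden_states, residual = incoming["hidden_states"], incoming["residual"]

    for layer in layers:
        hidden_states, residual = layer(hidden_states, residual)

    if rank < world_size - 1:
        # Intermediate ranks return proxies instead of logits.
        return ProxyBundle({"hidden_states": hidden_states, "residual": residual})

    # Last rank folds in the residual and normalizes; logits would follow.
    if residual is not None:
        hidden_states = hidden_states + residual
    return norm(hidden_states)

This is a sketch of the data flow only; the real implementation also threads positions and the ForwardBatch through each layer and fuses the residual add into the norm.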
sglang/srt/models/qwen2_vl.py

@@ -486,6 +486,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        if any(item.precomputed_features is not None for item in items):
+            if not all(item.precomputed_features is not None for item in items):
+                raise NotImplementedError(
+                    "MM inputs where only some items are precomputed."
+                )
+            return torch.concat([item.precomputed_features for item in items])
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
sglang/srt/models/qwen3.py

@@ -1,5 +1,6 @@
 # Adapted from qwen2.py

+import logging
 from functools import partial
 from typing import Any, Dict, Iterable, Optional, Tuple

@@ -7,6 +8,7 @@ import torch
 from torch import nn

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     split_tensor_along_last_dim,
@@ -19,8 +21,9 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -28,6 +31,8 @@ from sglang.srt.utils import add_prefix

 Qwen3Config = None

+logger = logging.getLogger(__name__)
+


 class Qwen3Attention(nn.Module):
@@ -238,6 +243,7 @@ class Qwen3ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3Model(
@@ -266,14 +272,33 @@ class Qwen3ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            return self.pooler(hidden_states, forward_batch)
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -287,6 +312,17 @@ class Qwen3ForCausalLM(nn.Module):

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -313,9 +349,15 @@ class Qwen3ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
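Both load_weights implementations above now skip checkpoint tensors whose decoder layer falls outside this rank's [start_layer, end_layer) slice. A minimal sketch of that filtering rule; the regex-based get_layer_id below is a stand-in for the helper imported from sglang.srt.layers.utils:

# Sketch of the pipeline-parallel weight filter used in the load_weights changes above.
import re
from typing import Optional


def get_layer_id(weight_name: str) -> Optional[int]:
    # Stand-in: extract the layer index from names like "model.layers.7.mlp...".
    match = re.search(r"layers\.(\d+)\.", weight_name)
    return int(match.group(1)) if match else None


def should_load(weight_name: str, start_layer: int, end_layer: int) -> bool:
    """Keep a checkpoint tensor only if its layer lives on this pipeline rank."""
    layer_id = get_layer_id(weight_name)
    if layer_id is None:
        return True  # embeddings, lm_head, final norm, etc. are handled separately
    return start_layer <= layer_id < end_layer


# Example: a rank owning layers [12, 24) skips layer 3's weights but keeps layer 15's.
assert not should_load("model.layers.3.mlp.gate_proj.weight", 12, 24)
assert should_load("model.layers.15.self_attn.qkv_proj.weight", 12, 24)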