sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/granite.py
@@ -363,31 +363,6 @@ class GraniteForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
sglang/srt/models/llama.py
@@ -91,10 +91,18 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, forward_batch=None):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
         return x
 
 
@@ -532,31 +540,6 @@ class LlamaForCausalLM(nn.Module):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
sglang/srt/models/llama4.py
@@ -32,6 +32,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -45,7 +46,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -131,14 +131,19 @@ class Llama4MoE(nn.Module):
             reduce_results=False,  # We need to do scatter before reduce
         )
 
-    def forward(self, hidden_states, forward_batch: ForwardBatch):
+    def forward(
+        self,
+        hidden_states,
+        forward_batch: ForwardBatch,
+        use_reduce_scatter: bool = False,
+    ):
         shared_out, routed_out = self._forward_core(
             hidden_states, forward_batch.forward_mode
         )
 
         out_aD = routed_out + shared_out
 
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             out_aD = tensor_model_parallel_all_reduce(out_aD)
 
         return out_aD
@@ -204,7 +209,7 @@ class Llama4Attention(nn.Module):
         super().__init__()
         self.layer_id = layer_id
         self.hidden_size = hidden_size
-        self.use_rope = int((layer_id + 1) % 4 != 0)
+        self.use_rope = (layer_id + 1) % 4 != 0
         self.use_qk_norm = config.use_qk_norm and self.use_rope
 
         attn_tp_rank = get_attention_tp_rank()
@@ -412,6 +417,7 @@ class Llama4DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def _is_moe_layer(self, layer_id: int) -> bool:
@@ -441,8 +447,15 @@
             hidden_states, residual, forward_batch
         )
 
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states, forward_batch)
+        hidden_states = self.feed_forward(
+            hidden_states, forward_batch, use_reduce_scatter
+        )
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
@@ -466,7 +479,7 @@ class Llama4Model(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
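The hunks above (and the matching ones in qwen2_moe.py and qwen3_moe.py below) thread a use_reduce_scatter flag from the decoder layer's LayerCommunicator into the MLP so that down_proj can skip its tensor-parallel all-reduce when the communicator will reduce-scatter the padded DP output instead. The short standalone Python sketch below is illustrative only (simulated ranks, no real collectives, not SGLang code); it shows why a reduce-scatter suffices when each rank only needs its own shard of the summed MLP output.

import torch

# Simulate 2 tensor-parallel ranks, each holding a partial MLP output.
world_size = 2
partials = [torch.arange(4.0) + r for r in range(world_size)]

# all-reduce: every rank would receive the full element-wise sum.
full_sum = torch.stack(partials).sum(dim=0)

# reduce-scatter: rank r receives only its shard of that sum, which is
# enough when DP padding keeps per-rank shapes uniform.
shards = full_sum.chunk(world_size)
for r in range(world_size):
    print(f"rank {r}: all_reduce={full_sum.tolist()} reduce_scatter_shard={shards[r].tolist()}")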
sglang/srt/models/qwen2.py
@@ -27,6 +27,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -43,7 +44,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -273,7 +273,7 @@ class Qwen2Model(nn.Module):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
sglang/srt/models/qwen2_5_vl.py
@@ -114,7 +114,7 @@ class Qwen2_5_VisionBlock(nn.Module):
         num_heads: int,
         hidden_act="silu",
         norm_layer: Type[nn.Module] = None,
-        attn_implementation: Optional[str] = "sdpa",
+        attn_implementation: Optional[str] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -123,7 +123,12 @@
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
         self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
-        if attn_implementation == "sdpa":
+
+        if attn_implementation is None:
+            softmax_in_single_precision = False
+            qkv_backend = None
+            flatten_batch = True
+        elif attn_implementation == "sdpa":
             softmax_in_single_precision = False
             qkv_backend = "sdpa"
             flatten_batch = True
@@ -268,7 +273,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
                 num_heads=num_heads,
                 hidden_act=vision_config.hidden_act,
                 norm_layer=norm_layer,
-                attn_implementation="sdpa",
                 quant_config=quant_config,
                 prefix=add_prefix(f"blocks.{i}", prefix),
             )
sglang/srt/models/qwen2_audio.py
@@ -52,7 +52,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -106,15 +110,10 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
         self.language_model = Qwen2ForCausalLM(
             config.text_config, quant_config, prefix=add_prefix("model", prefix)
         )
+        self.pattern = MultiModalityDataPaddingPatternMultimodalTokens()
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs for audio
-        audio_token_id: int = getattr(
-            mm_inputs, "audio_token_id", mm_inputs.im_token_id
-        )
-
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
-        return pattern.pad_input_tokens(input_ids, mm_inputs)
+        return self.pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # Extract audio features from input items
@@ -143,7 +142,9 @@
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            audio_data_embedding_func=self.get_audio_feature,
+            data_embedding_funcs={
+                Modality.AUDIO: self.get_audio_feature,
+            },
             positions=positions,
         )
 
sglang/srt/models/qwen2_moe.py
@@ -46,6 +46,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -107,10 +108,14 @@ class Qwen2MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
         return x
 
 
@@ -175,7 +180,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -193,6 +201,7 @@
         final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1 and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -367,6 +376,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def forward(
@@ -392,7 +402,12 @@
             hidden_states, residual, forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
 
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
@@ -420,7 +435,7 @@ class Qwen2MoeModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
sglang/srt/models/qwen3.py
@@ -330,30 +330,6 @@ class Qwen3ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
-    def get_hidden_dim(self, module_name: str) -> Tuple[int]:
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
     @torch.no_grad()
     def forward(
         self,
sglang/srt/models/qwen3_classification.py (new file)
@@ -0,0 +1,78 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Qwen2Config  # Qwen3 uses Qwen2Config
+
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen3 import Qwen3ForCausalLM, Qwen3Model
+from sglang.srt.utils import add_prefix
+
+
+class Qwen3ForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen3Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+        # Use normalize=True for qwen3 embedding based on official implementation
+        # Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
+        # Official code: output = F.normalize(output, p=2, dim=1)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+        self.eos_token_id = config.eos_token_id
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "Qwen3ForSequenceClassification is only used for embedding"
+
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        logits = self.score(hidden_states)
+        pooled_logits = self.pooler(logits, forward_batch).embeddings
+
+        return EmbeddingPoolerOutput(pooled_logits)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # Filter out lm_head weights of Qwen3ForCausalLM
+        filtered_weights = [
+            (name, w) for name, w in weights if not name.startswith("lm_head")
+        ]
+        return Qwen3ForCausalLM.load_weights(self, filtered_weights)
+
+
+EntryClass = [
+    Qwen3ForSequenceClassification,
+]
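The new Qwen3ForSequenceClassification pools with Pooler(pooling_type=PoolingType.LAST, normalize=True), and its comments cite the official Qwen3-Embedding code (F.normalize(output, p=2, dim=1)). As a rough standalone illustration of that pooling step (not SGLang's Pooler implementation), last-token pooling plus L2 normalization looks like:

import torch
import torch.nn.functional as F

hidden = torch.randn(2, 5, 8)             # [batch, seq_len, hidden_size]
seq_lens = torch.tensor([5, 3])           # true lengths of the two sequences
last_tok = hidden[torch.arange(2), seq_lens - 1]  # last real token per sequence
pooled = F.normalize(last_tok, p=2, dim=1)        # matches the referenced official code
print(pooled.shape, pooled.norm(dim=1))   # torch.Size([2, 8]), norms ~1.0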
sglang/srt/models/qwen3_moe.py
@@ -144,11 +144,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self.top_k = config.num_experts_per_tok
 
     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         if not global_server_args_dict["moe_a2a_backend"].is_deepep():
-            return self.forward_normal(hidden_states)
+            return self.forward_normal(hidden_states, use_reduce_scatter)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -159,7 +162,11 @@
             if name not in ["correction_bias"]
         ]
 
-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -167,7 +174,7 @@
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -521,6 +528,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def forward(
@@ -546,7 +554,12 @@
             hidden_states, residual, forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
 
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
sglang/srt/models/registry.py
@@ -83,7 +83,7 @@ def import_model_classes():
         try:
             module = importlib.import_module(name)
         except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}. " f"{e}")
+            logger.warning(f"Ignore import error when loading {name}: {e}")
             continue
         if hasattr(module, "EntryClass"):
            entry = module.EntryClass
sglang/srt/models/step3_vl.py
@@ -25,7 +25,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
-from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -437,7 +441,7 @@ class Step3TextModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
 
sglang/srt/models/torch_native_llama.py
@@ -416,30 +416,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def get_hidden_dim(self, module_name):
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id, num_shard)
sglang/srt/multimodal/processors/base_processor.py
@@ -22,13 +22,19 @@ class BaseMultiModalProcessorOutput:
    input_text: str
 
    # frames loaded from image, in given order
-    images: Optional[list[Union[Image.Image, dict]]] = None
+    images: Optional[list[Union[Image.Image, dict]]] = dataclasses.field(
+        default_factory=list
+    )
 
    # videos
-    videos: Optional[list[Union[torch.Tensor, dict]]] = None
+    videos: Optional[list[Union[torch.Tensor, dict]]] = dataclasses.field(
+        default_factory=list
+    )
 
    # audios
-    audios: Optional[list[Union[np.ndarray, dict]]] = None
+    audios: Optional[list[Union[np.ndarray, dict]]] = dataclasses.field(
+        default_factory=list
+    )
 
    def organize_results(self) -> List[Tuple[Modality, Any]]:
        """
@@ -202,7 +208,7 @@ class BaseMultimodalProcessor(ABC):
 
     def process_mm_data(
         self, input_text, images=None, videos=None, audios=None, **kwargs
-    ):
+    ) -> dict:
         """
         process multimodal data with transformers AutoProcessor
         """
@@ -211,10 +217,14 @@
         if videos:
             kwargs["videos"] = videos
         if audios:
-            kwargs["audios"] = audios
-            if self.__class__.__name__ == "Gemma3nSGLangProcessor":
+            if self.arch in {
+                "Gemma3nForConditionalGeneration",
+                "Qwen2AudioForConditionalGeneration",
+            }:
                 # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
                 kwargs["audio"] = audios
+            else:
+                kwargs["audios"] = audios
 
         processor = self._processor
         if (
@@ -601,12 +611,6 @@
         all_collected_items: list[MultimodalDataItem] = []
         input_ids = None
 
-        # Handle dict items (already processed)
-        for dict_item in dict_items:
-            all_collected_items.extend(
-                self.collect_mm_items_from_processor_output(dict_item)
-            )
-
         # Handle raw items (need processing)
         if raw_images or raw_audios or raw_videos:
             collected_items, input_ids, ret = self._process_and_collect_mm_items(
@@ -616,10 +620,16 @@
                 videos=raw_videos,
                 **kwargs,
             )
-            all_collected_items.extend(collected_items)
+            all_collected_items = collected_items
         else:
             ret = None
 
+        # Handle dict items (already processed)
+        for dict_item in dict_items:
+            all_collected_items.extend(
+                self.collect_mm_items_from_processor_output(dict_item)
+            )
+
         # Fallback tokenization if no raw items were processed
         if input_ids is None:
             input_ids = self._processor.tokenizer(
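A side note on the BaseMultiModalProcessorOutput hunk above: the images/videos/audios fields move from a None default to dataclasses.field(default_factory=list), presumably so downstream code can iterate or extend them without None checks. A bare list literal as the default would not work, since dataclasses reject mutable defaults; default_factory gives every instance its own fresh list, as this small self-contained example (not SGLang code) shows:

import dataclasses

@dataclasses.dataclass
class Output:
    images: list = dataclasses.field(default_factory=list)

a, b = Output(), Output()
a.images.append("img0")
print(a.images, b.images)  # ['img0'] [] -- each instance owns a separate list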