sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/ernie4.py

@@ -31,13 +31,13 @@ from sglang.srt.layers.communicator import enable_moe_dense_fully_dp
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP

@@ -361,7 +361,7 @@ class Ernie4_5_ForCausalLM(nn.Module):
 
 class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
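
Several models in this diff (ernie4, glm4_moe, glm4v_moe, gpt_oss) stop routing the call through get_moe_impl_class() and call make_expert_params_mapping directly on FusedMoE. Per the comment retained in the glm4_moe and glm4v_moe hunks below, the mapping is a list of (param_name, weight_name, expert_id, shard_id) tuples. As a rough sketch of how a load_weights loop consumes such a mapping (the helper name and the exact weight_loader signature here are illustrative assumptions, not taken from this diff):

from typing import Dict, Iterable, List, Tuple

import torch


def load_expert_weights(
    params_dict: Dict[str, torch.nn.Parameter],
    expert_params_mapping: List[Tuple[str, str, int, str]],
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> None:
    # For each checkpoint tensor, find the mapping entry whose checkpoint-side
    # fragment (weight_name) occurs in the tensor name, rename it to the fused
    # parameter (param_name), and pass expert_id/shard_id to that parameter's
    # weight_loader so the tensor lands in the right expert slot.
    for name, loaded_weight in weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            param = params_dict[name.replace(weight_name, param_name)]
            weight_loader = getattr(param, "weight_loader")
            weight_loader(
                param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
            )
            break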
sglang/srt/models/glm4_moe.py

@@ -24,6 +24,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,

@@ -39,7 +40,6 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm

@@ -51,9 +51,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,

@@ -76,10 +77,7 @@ from sglang.srt.models.deepseek_v2 import (
     DeepseekV2Model,
     DeepseekV2MoE,
 )
-from sglang.srt.two_batch_overlap import (
-    MaybeTboDeepEPDispatcher,
-    model_forward_maybe_tbo,
-)
+from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,

@@ -414,19 +412,15 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
-        self.topk = (
-            TopK(
-                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-                renormalize=config.norm_topk_prob,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
 
         self.experts = get_moe_impl_class()(

@@ -441,31 +435,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=global_server_args_dict["deepep_mode"])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
-            **(
-                dict(
-                    renormalize=config.norm_topk_prob,
-                    use_grouped_topk=True,
-                    num_expert_group=config.n_group,
-                    num_fused_shared_experts=self.num_fused_shared_experts,
-                    topk_group=config.topk_group,
-                    correction_bias=self.gate.e_score_correction_bias,
-                )
-                if should_use_flashinfer_trtllm_moe()
-                else {}
-            ),
         )
 
         self.shared_experts_is_int8 = False

@@ -496,7 +465,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
 
         self.top_k = config.num_experts_per_tok
 
-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (

@@ -520,12 +489,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=global_server_args_dict["deepep_mode"],
+                deepep_mode=get_deepep_mode(),
                 async_finish=True,
                 return_recv_hook=True,
             )
 
-        self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
+        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
 
     def forward_normal_dual_stream(
         self,

@@ -541,12 +510,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            kwargs = {"hidden_states": hidden_states}
-            if self.topk is not None:
-                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-            else:
-                kwargs["router_logits"] = router_logits
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
             current_stream.wait_stream(self.alt_stream)

@@ -587,12 +552,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        else:
-            kwargs["router_logits"] = router_logits
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor

@@ -759,10 +720,11 @@ class Glm4MoeModel(DeepseekV2Model):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
+        self.pp_group = get_pp_group()
+        self.start_layer = 0
+        self.end_layer = config.num_hidden_layers
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.dp_size = get_local_attention_dp_size()
-
 
 class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
 

@@ -777,6 +739,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.model = Glm4MoeModel(
             config, quant_config, prefix=add_prefix("model", prefix)

@@ -789,7 +752,6 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_local_attention_dp_size()
 
         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {

@@ -953,7 +915,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
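
Both forward paths of Glm4MoeSparseMoeBlock now feed the TopK output directly to the experts, and the dual-stream variant keeps overlapping the routed-expert computation on a side CUDA stream with the shared experts on the default stream. A minimal sketch of that overlap pattern, assuming plain callables for the gate, TopK, and expert modules (a generic illustration, not the sglang implementation):

import torch


def dual_stream_moe(hidden_states, gate, topk, experts, shared_experts):
    alt_stream = torch.cuda.Stream()
    current_stream = torch.cuda.current_stream()
    # The side stream must observe all work already queued on the default
    # stream (e.g. the layernorm that produced hidden_states).
    alt_stream.wait_stream(current_stream)

    with torch.cuda.stream(alt_stream):
        router_logits = gate(hidden_states)
        topk_output = topk(hidden_states, router_logits)
        routed_out = experts(hidden_states, topk_output)

    # Shared experts run on the default stream and overlap with the side stream.
    shared_out = shared_experts(hidden_states)

    # Join before reading routed_out on the default stream.
    current_stream.wait_stream(alt_stream)
    return routed_out + shared_out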
sglang/srt/models/glm4v.py

@@ -9,6 +9,7 @@ from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisi
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,

@@ -91,6 +92,7 @@ class Glm4vVisionBlock(Qwen2_5_VisionBlock):
             norm_layer=norm_layer,
             quant_config=quant_config,
             prefix=prefix,
+            num_dummy_heads=config.num_dummy_heads,
         )
         self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -469,7 +471,7 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         nn.Module.__init__(self)
 
         self.config = config
-
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.model = Glm4Model(
             config,
             quant_config,

@@ -537,6 +539,51 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         video_embeds = torch.split(video_embeds, split_sizes)
         return torch.cat(video_embeds)
 
+    def _update_hf_config(self):
+        """update hf config to ensure vision attention num_attention_heads is divisible by tp_size"""
+        tp_size = get_attention_tp_size()
+        num_heads = self.config.vision_config.num_heads
+        head_dim = self.config.vision_config.hidden_size // num_heads
+        num_dummy_heads = 0
+
+        if num_heads % tp_size != 0:
+            num_dummy_heads = (
+                (num_heads + tp_size - 1) // tp_size
+            ) * tp_size - num_heads
+
+        setattr(self.config.vision_config, "head_dim", head_dim)
+        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
+
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        elif "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        elif "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)

@@ -583,6 +630,10 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
                     raise
 
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                if "visual" in name:
+                    loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                        self.config, name, loaded_weight
+                    )
                 weight_loader(param, loaded_weight)
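
The dummy-head changes above (and the matching ones in glm4v_moe.py below) pad the ViT attention weights so the head count divides evenly across tensor-parallel ranks. A compact restatement of the rounding performed by _update_hf_config / update_vit_attn_dummy_heads_config, with a worked example (the standalone helper name is illustrative):

def compute_num_dummy_heads(num_heads: int, tp_size: int) -> int:
    # Round the head count up to the next multiple of tp_size; the difference
    # is filled with zero-initialized dummy heads at weight-loading time.
    if num_heads % tp_size == 0:
        return 0
    return ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads


# e.g. a 16-head ViT with tp_size=3 is padded to 18 heads (2 dummy heads),
# while tp_size=4 needs no padding.
assert compute_num_dummy_heads(16, 3) == 2
assert compute_num_dummy_heads(16, 4) == 0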
sglang/srt/models/glm4v_moe.py

@@ -8,19 +8,12 @@ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
 
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    parallel_state,
-    tensor_model_parallel_all_reduce,
 )
 from sglang.srt.hf_transformers_utils import get_processor
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    get_local_attention_dp_size,
-)
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

@@ -48,8 +41,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
 
         config.moe_layer_freq = 1
         self.config = config
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.dp_size = get_local_attention_dp_size()
         self.quant_config = quant_config
         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.num_fused_shared_experts = (

@@ -232,7 +225,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",

@@ -394,6 +387,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
+                   if "visual" in name:
+                       loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                           self.config, name, loaded_weight
+                       )
                    weight_loader(param, loaded_weight)
sglang/srt/models/gpt_oss.py

@@ -16,6 +16,7 @@
 """Inference-only GptOss model compatible with HuggingFace weights."""
 
 import logging
+import math
 from collections.abc import Iterable
 from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -40,7 +41,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm

@@ -50,9 +50,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention

@@ -110,16 +111,13 @@ class GptOssSparseMoeBlock(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
         self.activation = config.hidden_act
-        self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
-        self.swiglu_limit = config.swiglu_limit
+        self.gemm1_alpha = getattr(config, "hidden_act_alpha", 1.702)
+        self.gemm1_clamp_limit = config.swiglu_limit
 
-        if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
-            self.topk = None
-        else:
-            self.topk = TopK(
-                top_k=config.num_experts_per_tok,
-                renormalize=True,
-            )
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok,
+            renormalize=True,
+        )
 
         self.top_k = config.num_experts_per_tok
         experts_type = get_moe_impl_class()

@@ -129,11 +127,9 @@ class GptOssSparseMoeBlock(nn.Module):
             quant_config.get_name() if quant_config is not None else None
         )
         extra_kwargs = {
-            "enable_flashinfer_cutlass_moe": global_server_args_dict[
-                "enable_flashinfer_cutlass_moe"
-            ],
             # for moe gate_up_proj and down_proj and their bias loading
-            "use_weight_loader_fused": quant_config_name != "mxfp4",
+            "use_weight_loader_fused": quant_config_name
+            != "mxfp4"
         }
         self.experts = experts_type(
             num_experts=config.num_local_experts

@@ -144,15 +140,10 @@ class GptOssSparseMoeBlock(nn.Module):
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
             activation=self.activation,
-            activation_alpha=self.activation_alpha,
-            swiglu_limit=self.swiglu_limit,
+            gemm1_alpha=self.gemm1_alpha,
+            gemm1_clamp_limit=self.gemm1_clamp_limit,
             with_bias=True,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
             **extra_kwargs,
         )
 

@@ -171,7 +162,7 @@ class GptOssSparseMoeBlock(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
-        if not global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if not get_moe_a2a_backend().is_deepep():
             return self.forward_normal(hidden_states, should_allreduce_fusion)
         else:
             raise Exception("forward_deepep branch not implemented yet")

@@ -189,17 +180,10 @@ class GptOssSparseMoeBlock(nn.Module):
         should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
 
-        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.router(hidden_states)
-
-        kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        else:
-            kwargs["topk_output"] = (self.top_k, router_logits)
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
 
         if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

@@ -436,7 +420,6 @@ class GptOssDecoderLayer(nn.Module):
 
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()
 
         # GptOss all layers are sparse and have no nextn now
         self.is_layer_sparse = True

@@ -471,44 +454,11 @@ class GptOssDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            is_last_layer=(
+                self.is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )
 
-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
     def forward(
         self,
         positions: torch.Tensor,

@@ -532,8 +482,9 @@ class GptOssDecoderLayer(nn.Module):
         )
 
         should_allreduce_fusion = (
-            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not self.is_nextn
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
         )
 
         hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)

@@ -838,18 +789,27 @@ class GptOssForCausalLM(nn.Module):
         moe_ep_size = get_moe_expert_parallel_world_size()
 
         intermediate_size = self.config.intermediate_size
+        assert (
+            intermediate_size % mxfp4_block == 0
+        ), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
         intermediate_size_block = intermediate_size // mxfp4_block
-        per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
+
+        per_rank_intermediate_size_block = math.ceil(
+            intermediate_size_block / moe_tp_size
+        )
+
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
 
         # Calculate common slicing bounds for current rank
         assert self.config.num_local_experts % moe_ep_size == 0
         moe_num_global_experts = self.config.num_local_experts
         moe_num_local_experts = self.config.num_local_experts // moe_ep_size
+
         moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
         moe_tp_rank_end = min(
             (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
         )
+
         moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
         moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
 

@@ -1060,7 +1020,7 @@ class GptOssForCausalLM(nn.Module):
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping_fused(
            ckpt_gate_up_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_gate_up_proj_bias_name="gate_up_proj_bias",

@@ -1141,7 +1101,7 @@ class GptOssForCausalLM(nn.Module):
                if name in params_dict.keys():
                    param = params_dict[name]
                    if "sinks" in name:
-                       start = tp_rank * param.numel()
+                       start = get_attention_tp_rank() * param.numel()
                        param.data.copy_(
                            loaded_weight[start : start + param.numel()]
                        )
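
The mxfp4 weight-loading hunk above replaces the truncating division of the per-rank intermediate size with a ceiling division over whole MXFP4 blocks, so tensor-parallel splits that do not divide evenly no longer drop blocks; the min(...) bound in the model code then clamps the last rank's slice. A small numeric sketch of that computation, assuming the 32-element MXFP4 block size (the helper name is illustrative):

import math


def per_rank_intermediate_size(
    intermediate_size: int, moe_tp_size: int, mxfp4_block: int = 32
) -> int:
    # Slice the intermediate dimension in whole MXFP4 blocks, rounding the
    # per-rank block count up instead of truncating it.
    assert intermediate_size % mxfp4_block == 0
    blocks = intermediate_size // mxfp4_block
    return math.ceil(blocks / moe_tp_size) * mxfp4_block


# e.g. intermediate_size=2880 -> 90 blocks; with moe_tp_size=4 each rank is
# sized for ceil(90 / 4) = 23 blocks = 736 columns, and the last rank's slice
# end is clamped back to 2880 by the min(...) bound in the model code.
assert per_rank_intermediate_size(2880, 4) == 736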
sglang/srt/models/granitemoe.py

@@ -76,7 +76,6 @@ class GraniteMoeMoE(nn.Module):
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=f"{prefix}.experts",
         )