sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -31,13 +31,13 @@ from sglang.srt.layers.communicator import enable_moe_dense_fully_dp
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP
@@ -361,7 +361,7 @@ class Ernie4_5_ForCausalLM(nn.Module):
 
 class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -24,6 +24,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
@@ -39,7 +40,7 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -50,9 +51,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
@@ -75,10 +77,7 @@ from sglang.srt.models.deepseek_v2 import (
     DeepseekV2Model,
     DeepseekV2MoE,
 )
-from sglang.srt.two_batch_overlap import (
-    MaybeTboDeepEPDispatcher,
-    model_forward_maybe_tbo,
-)
+from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
@@ -413,19 +412,15 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
-        self.topk = (
-            TopK(
-                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-                renormalize=config.norm_topk_prob,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
 
         self.experts = get_moe_impl_class()(
@@ -440,31 +435,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=global_server_args_dict["deepep_mode"])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
-            **(
-                dict(
-                    renormalize=config.norm_topk_prob,
-                    use_grouped_topk=True,
-                    num_expert_group=config.n_group,
-                    num_fused_shared_experts=self.num_fused_shared_experts,
-                    topk_group=config.topk_group,
-                    correction_bias=self.gate.e_score_correction_bias,
-                )
-                if should_use_flashinfer_trtllm_moe()
-                else {}
-            ),
         )
 
         self.shared_experts_is_int8 = False
@@ -495,7 +465,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
 
         self.top_k = config.num_experts_per_tok
 
-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
@@ -519,12 +489,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=global_server_args_dict["deepep_mode"],
+                deepep_mode=get_deepep_mode(),
                 async_finish=True,
                 return_recv_hook=True,
             )
 
-        self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
+        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
 
     def forward_normal_dual_stream(
         self,
@@ -540,12 +510,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            kwargs = {"hidden_states": hidden_states}
-            if self.topk is not None:
-                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-            else:
-                kwargs["router_logits"] = router_logits
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda:
             final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
@@ -586,12 +552,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        else:
-            kwargs["router_logits"] = router_logits
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
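
Both forward paths above now share the same two-step dispatch: run TopK on the router logits, then pass the selection to the experts module. A minimal sketch of that call pattern, using placeholder names rather than the exact sglang signatures:

import torch

def moe_dispatch(gate, topk, experts, hidden_states: torch.Tensor) -> torch.Tensor:
    # router_logits: (num_tokens, n_experts)
    router_logits = gate(hidden_states)
    # TopK picks expert ids/weights per token; the experts module consumes them
    # directly, so the old branching on `self.topk is None` is no longer needed.
    topk_output = topk(hidden_states, router_logits)
    return experts(hidden_states, topk_output)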
@@ -634,7 +596,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
         self.self_attn = Glm4MoeAttention(
             hidden_size=self.hidden_size,
@@ -744,7 +705,7 @@ class Glm4MoeModel(DeepseekV2Model):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
@@ -759,10 +720,11 @@ class Glm4MoeModel(DeepseekV2Model):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
+        self.pp_group = get_pp_group()
+        self.start_layer = 0
+        self.end_layer = config.num_hidden_layers
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.dp_size = get_local_attention_dp_size()
-
 
 
 class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
@@ -777,6 +739,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.model = Glm4MoeModel(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -789,7 +752,6 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_local_attention_dp_size()
 
         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {
@@ -953,7 +915,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
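
As the comment above notes, the mapping is a list of (param_name, weight_name, expert_id, shard_id) tuples. A rough, simplified sketch of how a load_weights loop walks such a mapping; the matching rule and the weight_loader signature here are approximations for illustration, not the exact sglang code:

def route_expert_weight(expert_params_mapping, params_dict, name, loaded_weight):
    # Try to match a checkpoint tensor name against the per-expert mapping.
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        param = params_dict[name.replace(weight_name, param_name)]
        # The parameter's weight_loader places this expert's shard into the fused tensor.
        param.weight_loader(param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id)
        return True
    return False  # fall through to the non-expert loading path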
@@ -22,6 +22,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -56,7 +57,7 @@ class Glm4MoeModelNextN(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
 
@@ -9,6 +9,7 @@ from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisi
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -91,6 +92,7 @@ class Glm4vVisionBlock(Qwen2_5_VisionBlock):
             norm_layer=norm_layer,
             quant_config=quant_config,
             prefix=prefix,
+            num_dummy_heads=config.num_dummy_heads,
         )
         self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -469,7 +471,7 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         nn.Module.__init__(self)
 
         self.config = config
-
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.model = Glm4Model(
             config,
             quant_config,
@@ -537,6 +539,51 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         video_embeds = torch.split(video_embeds, split_sizes)
         return torch.cat(video_embeds)
 
+    def _update_hf_config(self):
+        """update hf config to ensure vision attention num_attention_heads is divisible by tp_size"""
+        tp_size = get_attention_tp_size()
+        num_heads = self.config.vision_config.num_heads
+        head_dim = self.config.vision_config.hidden_size // num_heads
+        num_dummy_heads = 0
+
+        if num_heads % tp_size != 0:
+            num_dummy_heads = (
+                (num_heads + tp_size - 1) // tp_size
+            ) * tp_size - num_heads
+
+        setattr(self.config.vision_config, "head_dim", head_dim)
+        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
+
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        elif "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        elif "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
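
For reference, _update_hf_config rounds the head count up to the next multiple of tp_size and records the difference as num_dummy_heads; the padded heads are zero weights that keep the per-rank split even. A worked example with hypothetical values:

# Hypothetical values: 16 vision heads sharded across tp_size = 6 ranks.
num_heads, tp_size = 16, 6
num_dummy_heads = ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads
assert num_dummy_heads == 2  # padded to 18 heads -> 3 per rank, the last 2 are zero-filled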
@@ -583,6 +630,10 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
                 raise
 
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            if "visual" in name:
+                loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                    self.config, name, loaded_weight
+                )
             weight_loader(param, loaded_weight)
 
 
@@ -8,19 +8,12 @@ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
 
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    parallel_state,
-    tensor_model_parallel_all_reduce,
 )
 from sglang.srt.hf_transformers_utils import get_processor
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    get_local_attention_dp_size,
-)
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
@@ -48,8 +41,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
 
         config.moe_layer_freq = 1
         self.config = config
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.dp_size = get_local_attention_dp_size()
         self.quant_config = quant_config
         self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.num_fused_shared_experts = (
@@ -232,7 +225,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -394,6 +387,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
                    )
+                    if "visual" in name:
+                        loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                            self.config, name, loaded_weight
+                        )
                    weight_loader(param, loaded_weight)
 
 
@@ -16,6 +16,7 @@
 """Inference-only GptOss model compatible with HuggingFace weights."""
 
 import logging
+import math
 from collections.abc import Iterable
 from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -40,7 +41,7 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -49,9 +50,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -109,16 +111,13 @@ class GptOssSparseMoeBlock(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
         self.activation = config.hidden_act
-        self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
-        self.swiglu_limit = config.swiglu_limit
+        self.gemm1_alpha = getattr(config, "hidden_act_alpha", 1.702)
+        self.gemm1_clamp_limit = config.swiglu_limit
 
-        if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
-            self.topk = None
-        else:
-            self.topk = TopK(
-                top_k=config.num_experts_per_tok,
-                renormalize=True,
-            )
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok,
+            renormalize=True,
+        )
 
         self.top_k = config.num_experts_per_tok
         experts_type = get_moe_impl_class()
@@ -128,11 +127,9 @@ class GptOssSparseMoeBlock(nn.Module):
             quant_config.get_name() if quant_config is not None else None
         )
         extra_kwargs = {
-            "enable_flashinfer_cutlass_moe": global_server_args_dict[
-                "enable_flashinfer_cutlass_moe"
-            ],
             # for moe gate_up_proj and down_proj and their bias loading
-            "use_weight_loader_fused": quant_config_name != "mxfp4",
+            "use_weight_loader_fused": quant_config_name
+            != "mxfp4"
         }
         self.experts = experts_type(
             num_experts=config.num_local_experts
@@ -143,15 +140,10 @@ class GptOssSparseMoeBlock(nn.Module):
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
             activation=self.activation,
-            activation_alpha=self.activation_alpha,
-            swiglu_limit=self.swiglu_limit,
+            gemm1_alpha=self.gemm1_alpha,
+            gemm1_clamp_limit=self.gemm1_clamp_limit,
             with_bias=True,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
             **extra_kwargs,
         )
 
@@ -170,7 +162,7 @@ class GptOssSparseMoeBlock(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
-        if not global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if not get_moe_a2a_backend().is_deepep():
             return self.forward_normal(hidden_states, should_allreduce_fusion)
         else:
             raise Exception("forward_deepep branch not implemented yet")
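
The same substitution recurs throughout this release: raw global_server_args_dict lookups are replaced by the accessors exported from sglang.srt.layers.moe (imported in the hunks above). A minimal sketch of a resulting call site, assuming only the accessors shown in these hunks:

from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend

def select_moe_a2a_path():
    # Branch on the configured all-to-all backend object instead of
    # indexing into global_server_args_dict.
    if get_moe_a2a_backend().is_deepep():
        return "deepep", get_deepep_mode()
    return "default", None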
@@ -188,17 +180,10 @@ class GptOssSparseMoeBlock(nn.Module):
         should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
 
-        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.router(hidden_states)
-
-        kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        else:
-            kwargs["topk_output"] = (self.top_k, router_logits)
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
 
         if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
@@ -293,8 +278,12 @@ class GptOssAttention(nn.Module):
             prefix=add_prefix("qkv_proj", prefix),
         )
 
+        # Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
+        # others can use bfloat16
+        attn_backend = global_server_args_dict.get("attention_backend")
+        sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
+            torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
         )
 
         self.o_proj = RowParallelLinear(
@@ -431,7 +420,6 @@ class GptOssDecoderLayer(nn.Module):
 
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()
 
         # GptOss all layers are sparse and have no nextn now
         self.is_layer_sparse = True
@@ -466,44 +454,11 @@ class GptOssDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            is_last_layer=(
+                self.is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )
 
-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -527,8 +482,9 @@ class GptOssDecoderLayer(nn.Module):
         )
 
         should_allreduce_fusion = (
-            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not self.is_nextn
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
         )
 
         hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
@@ -561,7 +517,7 @@ class GptOssModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
@@ -833,18 +789,27 @@ class GptOssForCausalLM(nn.Module):
         moe_ep_size = get_moe_expert_parallel_world_size()
 
         intermediate_size = self.config.intermediate_size
+        assert (
+            intermediate_size % mxfp4_block == 0
+        ), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
         intermediate_size_block = intermediate_size // mxfp4_block
-        per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
+
+        per_rank_intermediate_size_block = math.ceil(
+            intermediate_size_block / moe_tp_size
+        )
+
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
 
         # Calculate common slicing bounds for current rank
         assert self.config.num_local_experts % moe_ep_size == 0
         moe_num_global_experts = self.config.num_local_experts
         moe_num_local_experts = self.config.num_local_experts // moe_ep_size
+
         moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
         moe_tp_rank_end = min(
             (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
         )
+
         moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
         moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
 
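Switching the per-rank block count to math.ceil guarantees every column is assigned to some rank when the block count does not divide evenly; the min() clamp above then trims the final rank. A worked example with hypothetical sizes:

import math

# Hypothetical sizes: 90 mxfp4 blocks of width 32 split across 7 TP ranks.
intermediate_size, mxfp4_block, moe_tp_size = 2880, 32, 7
blocks = intermediate_size // mxfp4_block                   # 90
per_rank = math.ceil(blocks / moe_tp_size) * mxfp4_block    # 13 blocks -> 416 columns per rank
ends = [min((r + 1) * per_rank, intermediate_size) for r in range(moe_tp_size)]
assert ends[-1] == intermediate_size  # floor division (12 blocks/rank) would leave 6 blocks unassigned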
@@ -1055,7 +1020,7 @@ class GptOssForCausalLM(nn.Module):
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping_fused(
             ckpt_gate_up_proj_name="gate_up_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
@@ -1136,7 +1101,7 @@ class GptOssForCausalLM(nn.Module):
                 if name in params_dict.keys():
                     param = params_dict[name]
                     if "sinks" in name:
-                        start = tp_rank * param.numel()
+                        start = get_attention_tp_rank() * param.numel()
                         param.data.copy_(
                             loaded_weight[start : start + param.numel()]
                         )
@@ -76,7 +76,6 @@ class GraniteMoeMoE(nn.Module):
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=f"{prefix}.experts",
         )