sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff covers publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (141)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +48 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +34 -0
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/nixl/conn.py +6 -6
  10. sglang/srt/disaggregation/prefill.py +2 -2
  11. sglang/srt/disaggregation/utils.py +1 -1
  12. sglang/srt/distributed/parallel_state.py +44 -17
  13. sglang/srt/entrypoints/EngineBase.py +8 -0
  14. sglang/srt/entrypoints/engine.py +40 -6
  15. sglang/srt/entrypoints/http_server.py +111 -24
  16. sglang/srt/entrypoints/openai/protocol.py +4 -2
  17. sglang/srt/eplb/__init__.py +0 -0
  18. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  19. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  20. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  21. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  22. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  24. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  25. sglang/srt/hf_transformers_utils.py +2 -1
  26. sglang/srt/layers/activation.py +2 -2
  27. sglang/srt/layers/amx_utils.py +86 -0
  28. sglang/srt/layers/attention/ascend_backend.py +219 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  30. sglang/srt/layers/attention/tbo_backend.py +37 -9
  31. sglang/srt/layers/communicator.py +18 -2
  32. sglang/srt/layers/dp_attention.py +9 -3
  33. sglang/srt/layers/elementwise.py +76 -12
  34. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  35. sglang/srt/layers/layernorm.py +26 -0
  36. sglang/srt/layers/linear.py +84 -14
  37. sglang/srt/layers/logits_processor.py +4 -4
  38. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +36 -13
  40. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  41. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
  42. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
  43. sglang/srt/layers/moe/router.py +60 -22
  44. sglang/srt/layers/moe/topk.py +10 -28
  45. sglang/srt/layers/parameter.py +67 -7
  46. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  47. sglang/srt/layers/quantization/fp8.py +44 -0
  48. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  49. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  50. sglang/srt/layers/quantization/gptq.py +5 -1
  51. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  52. sglang/srt/layers/quantization/quant_utils.py +166 -0
  53. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  54. sglang/srt/layers/rotary_embedding.py +2 -2
  55. sglang/srt/layers/vocab_parallel_embedding.py +11 -7
  56. sglang/srt/lora/lora.py +4 -5
  57. sglang/srt/lora/lora_manager.py +73 -20
  58. sglang/srt/managers/configure_logging.py +1 -1
  59. sglang/srt/managers/io_struct.py +50 -13
  60. sglang/srt/managers/mm_utils.py +73 -59
  61. sglang/srt/managers/multimodal_processor.py +2 -6
  62. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  63. sglang/srt/managers/schedule_batch.py +77 -84
  64. sglang/srt/managers/scheduler.py +113 -59
  65. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  66. sglang/srt/managers/session_controller.py +12 -3
  67. sglang/srt/managers/tokenizer_manager.py +314 -103
  68. sglang/srt/managers/tp_worker.py +13 -1
  69. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  70. sglang/srt/mem_cache/allocator.py +290 -0
  71. sglang/srt/mem_cache/chunk_cache.py +34 -2
  72. sglang/srt/mem_cache/memory_pool.py +289 -3
  73. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  74. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  75. sglang/srt/model_executor/forward_batch_info.py +17 -4
  76. sglang/srt/model_executor/model_runner.py +297 -56
  77. sglang/srt/model_loader/loader.py +41 -0
  78. sglang/srt/model_loader/weight_utils.py +72 -4
  79. sglang/srt/models/deepseek_nextn.py +1 -3
  80. sglang/srt/models/deepseek_v2.py +181 -45
  81. sglang/srt/models/deepseek_vl2.py +3 -5
  82. sglang/srt/models/gemma3_causal.py +1 -2
  83. sglang/srt/models/gemma3n_causal.py +4 -3
  84. sglang/srt/models/gemma3n_mm.py +4 -20
  85. sglang/srt/models/hunyuan.py +1 -1
  86. sglang/srt/models/kimi_vl.py +1 -2
  87. sglang/srt/models/llama.py +10 -4
  88. sglang/srt/models/llama4.py +32 -45
  89. sglang/srt/models/llama_eagle3.py +61 -11
  90. sglang/srt/models/llava.py +5 -5
  91. sglang/srt/models/minicpmo.py +2 -2
  92. sglang/srt/models/mistral.py +1 -1
  93. sglang/srt/models/mllama4.py +43 -11
  94. sglang/srt/models/phi4mm.py +1 -3
  95. sglang/srt/models/pixtral.py +3 -7
  96. sglang/srt/models/qwen2.py +31 -3
  97. sglang/srt/models/qwen2_5_vl.py +1 -3
  98. sglang/srt/models/qwen2_audio.py +200 -0
  99. sglang/srt/models/qwen2_moe.py +32 -6
  100. sglang/srt/models/qwen2_vl.py +1 -4
  101. sglang/srt/models/qwen3.py +94 -25
  102. sglang/srt/models/qwen3_moe.py +68 -21
  103. sglang/srt/models/vila.py +3 -8
  104. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  105. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  106. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  107. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  108. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  109. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  110. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  111. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  112. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  117. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  120. sglang/srt/operations_strategy.py +6 -2
  121. sglang/srt/reasoning_parser.py +26 -0
  122. sglang/srt/sampling/sampling_batch_info.py +39 -1
  123. sglang/srt/server_args.py +69 -22
  124. sglang/srt/speculative/build_eagle_tree.py +57 -18
  125. sglang/srt/speculative/eagle_worker.py +6 -4
  126. sglang/srt/two_batch_overlap.py +200 -27
  127. sglang/srt/utils.py +306 -146
  128. sglang/srt/warmup.py +12 -3
  129. sglang/test/runners.py +10 -1
  130. sglang/test/test_utils.py +15 -3
  131. sglang/version.py +1 -1
  132. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  133. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
  134. sglang/math_utils.py +0 -8
  135. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  136. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  137. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  138. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  139. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  140. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  141. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py
@@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Tuple
 
 import einops
 import torch
-from sgl_kernel import silu_and_mul
 from torch.nn import Module
 
 from sglang.srt.custom_op import CustomOp
@@ -11,6 +10,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -40,22 +41,26 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     sglang_per_token_quant_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
     DeepEPMode,
+    ceil_div,
     dispose_tensor,
     get_bool_env_var,
     is_hip,
+    is_npu,
     set_weight_attrs,
 )
 
 _is_hip = is_hip()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
+if not _is_npu:
+    from sgl_kernel import silu_and_mul
+
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
 
@@ -1173,12 +1178,14 @@ class DeepEPMoE(EPMoE):
         masked_m: torch.Tensor,
         expected_m: int,
         num_recv_tokens_per_expert: List[int],
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
@@ -1370,10 +1377,19 @@ class DeepEPMoE(EPMoE):
                 device=hidden_states_fp8.device,
                 dtype=hidden_states_fp8.dtype,
             ),
-            torch.empty(
-                (all_tokens, K // 128),
-                device=hidden_states_fp8.device,
-                dtype=torch.float32,
+            (
+                # TODO check whether need `zeros`
+                torch.zeros(
+                    (ceil_div(K // 128, 4), all_tokens),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.int,
+                ).transpose(0, 1)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else torch.empty(
+                    (all_tokens, K // 128),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.float32,
+                )
             ),
         ]
         m_indices = torch.empty(
@@ -1399,6 +1415,7 @@ class DeepEPMoE(EPMoE):
             input_tensor[1],
             m_indices,
             output_index,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         dispose_tensor(hidden_states_fp8)
 
@@ -1407,7 +1424,8 @@ class DeepEPMoE(EPMoE):
             device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
-        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
@@ -1428,10 +1446,15 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input, scale_block_size
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
         del down_input
-        down_input_scale = tma_align_input_scale(down_input_scale)
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
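
For orientation, the UE8M0 branch above swaps the per-group float32 scale buffer for a packed integer layout sized with ceil_div(K // 128, 4). A minimal sketch of the shape arithmetic, assuming four 8-bit UE8M0 exponents are packed into each int32 word (an assumption consistent with the allocation above, not something the diff states):

    import torch

    # Made-up example sizes; only the shapes matter here.
    all_tokens, K = 1024, 7168
    fp32_scales = torch.empty((all_tokens, K // 128), dtype=torch.float32)
    # Assumed packing: 4 UE8M0 exponents per int32, hence ceil_div(K // 128, 4) words per token.
    ue8m0_scales = torch.zeros(((K // 128 + 3) // 4, all_tokens), dtype=torch.int).transpose(0, 1)
    print(fp32_scales.shape, ue8m0_scales.shape)  # torch.Size([1024, 56]) torch.Size([1024, 14])
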
sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -1,10 +1,8 @@
 import logging
 from dataclasses import dataclass
 
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     DeepEPMode,
@@ -36,7 +34,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 
@@ -246,7 +244,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx = topk_idx.to(torch.int64)
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
+            hidden_states = sglang_per_token_group_quant_fp8(
+                hidden_states,
+                128,
+                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
 
@@ -682,21 +686,21 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode = None,
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-        inner_state = self._get_impl(forward_mode).dispatch_a(
+        inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._dispatch_intermediate_state = forward_mode, inner_state
+        self._dispatch_intermediate_state = forward_batch, inner_state
 
     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-        forward_mode, inner_state = self._dispatch_intermediate_state
+        forward_batch, inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(forward_mode).dispatch_b(*inner_state)
+        return self._get_impl(forward_batch).dispatch_b(*inner_state)
 
     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
@@ -708,24 +712,26 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(forward_mode).combine_a(
+        inner_state = self._get_impl(forward_batch).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._combine_intermediate_state = forward_mode, inner_state
+        self._combine_intermediate_state = forward_batch, inner_state
 
     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-        forward_mode, inner_state = self._combine_intermediate_state
+        forward_batch, inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(forward_mode).combine_b(*inner_state)
+        return self._get_impl(forward_batch).combine_b(*inner_state)
 
-    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+    def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
         elif resolved_deepep_mode == DeepEPMode.low_latency:
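
The dispatcher hunks above replace the bare ForwardMode argument with the full ForwardBatch, and the DeepEP mode is now resolved from forward_batch.is_extend_in_batch. A self-contained stand-in sketch of that resolution idea follows; this is not the actual sglang.srt.utils.DeepEPMode, and the extend-vs-decode mapping shown is an assumption based on the previous auto behaviour, not something this diff states:

    from enum import Enum

    class DeepEPModeSketch(Enum):
        normal = "normal"            # high-throughput path, used for prefill/extend
        low_latency = "low_latency"  # decode path
        auto = "auto"

        def resolve(self, is_extend_in_batch: bool) -> "DeepEPModeSketch":
            if self != DeepEPModeSketch.auto:
                return self
            # assumed mapping: any extend tokens in the batch -> normal, pure decode -> low_latency
            return (
                DeepEPModeSketch.normal
                if is_extend_in_batch
                else DeepEPModeSketch.low_latency
            )

    print(DeepEPModeSketch.auto.resolve(is_extend_in_batch=False))  # DeepEPModeSketch.low_latency
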
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -12,7 +12,6 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.math_utils import ceil_div
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -25,6 +24,7 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    ceil_div,
     cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
@@ -32,7 +32,6 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
-    log_info_on_rank0,
     next_power_of_2,
 )
 
sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -12,19 +12,21 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
     is_hip,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 if torch.cuda.is_available():
@@ -129,7 +131,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         # Pack weight for get better performance on CPU
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
 
         return
 
@@ -264,10 +266,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
         assert activation == "silu", f"activation = {activation} is not supported."
 
-        if (
-            getattr(layer, "use_intel_amx_backend", False)
-            and not apply_router_weight_on_input
-        ):
+        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
@@ -291,7 +290,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     torch.float
                 ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
                 topk_ids,
-                True,  # inplace
+                False,  # inplace # See [Note] inplace should be False in fused_experts.
                 False,  # use_int8_w8a8
                 False,  # use_fp8_w8a16
                 None,  # w1_scale
@@ -321,6 +320,44 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             routed_scaling_factor,
         )
 
+    def forward_npu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        return moe_forward_native(
+            layer,
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            num_fused_shared_experts,
+            custom_routing_function,
+            correction_bias,
+            activation,
+            apply_router_weight_on_input,
+            inplace,
+            no_combine,
+            routed_scaling_factor,
+        )
+
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
 
@@ -537,11 +574,6 @@ class FusedMoE(torch.nn.Module):
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
 
-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
-            )
-
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         # w3, up_proj: Load into second logical weight of w13.
@@ -552,7 +584,24 @@ class FusedMoE(torch.nn.Module):
             start = shard_size
         else:
             start = 0
-        expert_data = expert_data.narrow(shard_dim, start, shard_size)
+
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                start,
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
+            )
+        else:
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )
+
+            expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)
 
     def _load_w2(
@@ -569,10 +618,21 @@ class FusedMoE(torch.nn.Module):
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
 
-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
             )
+        else:
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )
 
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
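
The AMX check in apply() above is now a helper imported from sglang.srt.utils instead of an inline getattr. Read literally from the removed lines, the predicate amounts to the sketch below; the real helper may do more, so treat this as illustrative only:

    def use_intel_amx_backend(layer) -> bool:
        # Equivalent to the inline check this release removes:
        # getattr(layer, "use_intel_amx_backend", False)
        return getattr(layer, "use_intel_amx_backend", False)
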
sglang/srt/layers/moe/router.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import triton
@@ -16,6 +16,8 @@ def fused_moe_router_kernel(
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
     topk_ids_ptr,  # output (bs, topk)
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     num_experts: tl.constexpr,
     topk: tl.constexpr,
     moe_softcapping: tl.constexpr,
@@ -49,6 +51,11 @@ def fused_moe_router_kernel(
     bottom = exped + 1
     logits_softcapped = top / bottom * moe_softcapping
 
+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts))
+        logits_softcapped = logits_softcapped + bias
+
     # topk
     # assert 1 <= topk <= num_experts
 
@@ -109,6 +116,7 @@ def fused_moe_router_impl(
     router_weight: torch.Tensor,
     topk: int,
     moe_softcapping: float,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -117,23 +125,23 @@ def fused_moe_router_impl(
     # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None
 
-    grid = lambda meta: (bs,)
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
        ),
     }
 
-    fused_moe_router_kernel[grid](
+    fused_moe_router_kernel[(bs,)](
         x,
         router_weight,
         topk_weights,
         topk_ids,
+        correction_bias,
+        is_correction_bias=is_correction_bias,
         num_experts=num_experts,
         topk=topk,
         moe_softcapping=moe_softcapping,
@@ -153,7 +161,7 @@ def fused_moe_router_large_bs_kernel(
     topk_ids_ptr,  # output (bs, topk)
     bs,
     num_experts: tl.constexpr,
-    topk: tl.constexpr,  # only support topk == 1
+    topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
     K: tl.constexpr,
@@ -204,25 +212,53 @@ def fused_moe_router_large_bs_kernel(
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
 
     # 5. top1
-    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
-    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
+    arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
+    cond_top1 = arange_block_size_n < num_experts
+    top1 = tl.argmax(tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1)
     top1_v = tl.max(
-        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
+        tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1, keep_dims=True
     )
-    invsumexp = 1.0 / tl.sum(
-        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
+    top1_invsumexp = 1.0 / tl.sum(
+        tl.where(cond_top1, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
     )
 
-    # 6. store to output
-    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    topk_mask = offs_topk < bs
-    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
+    # 6. store top1 to output
+    offs_top1 = pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)
+    top1_mask = offs_top1 < bs * topk
+    tl.store(topk_ids_ptr + offs_top1, top1, mask=top1_mask)
     tl.store(
-        topk_weights_ptr + offs_topk,
-        invsumexp,
-        mask=topk_mask,
+        topk_weights_ptr + offs_top1,
+        top1_invsumexp,
+        mask=top1_mask,
     )
 
+    # 7. handle topk == 2
+    if topk == 2:
+        cond_top2 = (arange_block_size_n < num_experts) and (
+            arange_block_size_n != top1[:, None]
+        )
+        top2 = tl.argmax(
+            tl.where(cond_top2, logits_softcapped, float("-inf")),
+            axis=1,
+            keep_dims=True,
+        )
+        top2_v = tl.sum(
+            logits_softcapped * (arange_block_size_n == top2), axis=1, keep_dims=True
+        )
+        top2_invsumexp = tl.exp(top2_v - top1_v) * top1_invsumexp[:, None]
+
+        # store top2
+        offs_top2 = (
+            pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)[:, None] + 1
+        )
+        top2_mask = offs_top2 < bs * topk
+        tl.store(topk_ids_ptr + offs_top2, top2, mask=top2_mask)
+        tl.store(
+            topk_weights_ptr + offs_top2,
+            top2_invsumexp,
+            mask=top2_mask,
        )
+
 
 def fused_moe_router_large_bs_impl(
     x: torch.Tensor,
@@ -239,7 +275,7 @@ def fused_moe_router_large_bs_impl(
 
     assert num_experts <= BLOCK_SIZE_N
     assert hidden_dim % BLOCK_SIZE_K == 0
-    assert topk == 1
+    assert topk <= 2
 
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
@@ -273,6 +309,7 @@ def fused_moe_router_shim(
     gating_output,
     topk,
     renormalize,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert not renormalize
     assert (
@@ -286,7 +323,7 @@ def fused_moe_router_shim(
     BLOCK_SIZE_K = 256
     if (
         bs >= 512
-        and topk == 1
+        and topk <= 2
         and num_experts <= BLOCK_SIZE_N
         and hidden_dim % BLOCK_SIZE_K == 0
     ):
@@ -305,6 +342,7 @@ def fused_moe_router_shim(
         router_weight=gating_output,
         topk=topk,
         moe_softcapping=moe_softcapping,
+        correction_bias=correction_bias,
     )
 
 
sglang/srt/layers/moe/topk.py
@@ -18,12 +18,12 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
 
-from sglang.srt.managers import expert_location_dispatch
-from sglang.srt.managers.expert_distribution import (
+from sglang.srt.eplb import expert_location_dispatch
+from sglang.srt.eplb.expert_distribution import (
     ExpertDistributionRecorder,
     get_global_expert_distribution_recorder,
 )
-from sglang.srt.managers.expert_location_dispatch import (
+from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
+    is_npu,
 )
 
 _is_cuda = is_cuda()
@@ -42,6 +43,7 @@ _is_hip = is_hip()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_npu = is_npu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -106,37 +108,14 @@ def fused_topk(
         M, topk, dtype=torch.float32, device=hidden_states.device
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
 
     topk_softmax(
         topk_weights,
         topk_ids,
-        token_expert_indicies,
-        gating_output.float(),
-    )
-    del token_expert_indicies
-
-    return _fused_topk_postprocess(
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        renormalize=renormalize,
-        expert_location_dispatch_info=expert_location_dispatch_info,
-        num_token_non_padded=num_token_non_padded,
+        gating_output,
+        renormalize,
     )
 
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def _fused_topk_postprocess(
-    topk_weights,
-    topk_ids,
-    renormalize,
-    expert_location_dispatch_info,
-    num_token_non_padded,
-):
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
     return topk_weights, topk_ids
@@ -159,6 +138,9 @@ def grouped_topk_gpu(
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     scores = torch.softmax(gating_output, dim=-1)
+    # NPU compiler limitation
+    if _is_npu and scores.dtype == torch.bfloat16:
+        scores = scores.to(torch.float16)
     num_token = scores.shape[0]
     num_experts = scores.shape[1]
     group_scores = (
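
With token_expert_indicies and the compiled post-processing step removed, fused_topk now hands gating_output and renormalize straight to the topk_softmax kernel; only the logical-to-physical remap and padding mask remain in Python. A rough eager-mode equivalent of the fused call, assuming topk_softmax keeps its usual softmax-then-top-k semantics; illustrative only:

    import torch

    def fused_topk_reference(gating_output, topk, renormalize):
        probs = torch.softmax(gating_output.float(), dim=-1)
        topk_weights, topk_ids = probs.topk(topk, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids.to(torch.int32)

    print(fused_topk_reference(torch.randn(2, 8), topk=2, renormalize=True))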