sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -11,6 +11,8 @@ import torch
  from compressed_tensors import CompressionFormat
  from compressed_tensors.quantization import QuantizationStrategy

+ from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+ from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
  from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
  from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
  from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
@@ -30,8 +32,10 @@ from sglang.srt.utils import (

  if TYPE_CHECKING:
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
- from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
- from sglang.srt.layers.moe.topk import TopKOutput
+ from sglang.srt.layers.moe.token_dispatcher import (
+ CombineInput,
+ StandardDispatchOutput,
+ )
  from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
  CompressedTensorsConfig,
  )
@@ -293,14 +297,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  )
  torch.cuda.empty_cache()

+ def create_moe_runner(
+ self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+ ):
+ self.moe_runner_config = moe_runner_config
+ self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
  def apply(
  self,
  layer: torch.nn.Module,
- x: torch.Tensor,
- topk_output: TopKOutput,
- moe_runner_config: MoeRunnerConfig,
- ) -> torch.Tensor:
- from sglang.srt.layers.moe.fused_moe_triton import fused_experts
+ dispatch_output: StandardDispatchOutput,
+ ) -> CombineInput:
+
+ from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+ x = dispatch_output.hidden_states
+ topk_output = dispatch_output.topk_output
+
+ moe_runner_config = self.moe_runner_config

  if (
  _use_aiter
@@ -308,7 +322,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  and moe_runner_config.apply_router_weight_on_input
  ):
  topk_weights, topk_ids, _ = topk_output
- return rocm_fused_experts_tkw1(
+ output = rocm_fused_experts_tkw1(
  hidden_states=x,
  w1=layer.w13_weight,
  w2=layer.w2_weight,
@@ -324,21 +338,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
  a1_scale=layer.w13_input_scale,
  a2_scale=layer.w2_input_scale,
  )
+ return StandardCombineInput(hidden_states=output)
  else:
- return fused_experts(
- x,
- layer.w13_weight,
- layer.w2_weight,
- topk_output=topk_output,
- moe_runner_config=moe_runner_config,
+ quant_info = TritonMoeQuantInfo(
+ w13_weight=layer.w13_weight,
+ w2_weight=layer.w2_weight,
  use_fp8_w8a8=True,
  per_channel_quant=self.weight_quant.strategy
  == QuantizationStrategy.CHANNEL,
- w1_scale=layer.w13_weight_scale,
+ w13_scale=layer.w13_weight_scale,
  w2_scale=layer.w2_weight_scale,
- a1_scale=layer.w13_input_scale,
+ a13_scale=layer.w13_input_scale,
  a2_scale=layer.w2_input_scale,
  )
+ return self.runner.run(dispatch_output, quant_info)


  class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
@@ -380,8 +393,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
  params_dtype == torch.float16
  ), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501

- intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
-
  # Will transpose the loaded weight along the
  # intermediate and hidden dim sizes. Will
  # shard for TP along the transposed dims
@@ -415,13 +426,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
  # In the case where we have actorder/g_idx,
  # we do not partition the w2 scales
  load_full_w2 = self.actorder and self.group_size != -1
- w2_scales_size = (
- intermediate_size_full if load_full_w2 else intermediate_size_per_partition
- )

- self.is_k_full = (not self.actorder) or (
- intermediate_size_per_partition == intermediate_size_full
- )
+ if load_full_w2:
+ w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
+ else:
+ w2_scales_size = intermediate_size_per_partition
+
+ self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1

  if self.strategy == "channel":
  num_groups_w2 = num_groups_w13 = 1
@@ -640,21 +651,29 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
  )
  replace_tensor("w2_weight_scale", marlin_w2_scales)

+ def create_moe_runner(
+ self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+ ):
+ self.moe_runner_config = moe_runner_config
+
  def apply(
  self,
  layer: torch.nn.Module,
- x: torch.Tensor,
- topk_output: TopKOutput,
- moe_runner_config: MoeRunnerConfig,
- ) -> torch.Tensor:
+ dispatch_output: StandardDispatchOutput,
+ ) -> CombineInput:
+
+ from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

  assert (
- moe_runner_config.activation == "silu"
+ self.moe_runner_config.activation == "silu"
  ), "Only SiLU activation is supported."

+ x = dispatch_output.hidden_states
+ topk_output = dispatch_output.topk_output
+
  topk_weights, topk_ids, router_logits = topk_output

- return torch.ops.vllm.fused_marlin_moe(
+ output = torch.ops.vllm.fused_marlin_moe(
  x,
  layer.w13_weight_packed,
  layer.w2_weight_packed,
@@ -670,3 +689,4 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
  num_bits=self.num_bits,
  is_k_full=self.is_k_full,
  )
+ return StandardCombineInput(hidden_states=output)
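
Taken together, the compressed_tensors_moe.py hunks above move both MoE methods off the old apply(x, topk_output, moe_runner_config) signature and onto the dispatcher-based interface: the layer calls create_moe_runner() once, and apply() then receives a StandardDispatchOutput and returns a CombineInput. The condensed sketch below restates that contract in one place; the imports, method signatures, and TritonMoeQuantInfo field names (w13_scale/a13_scale instead of w1_scale/a1_scale) are taken from the hunks, while the class name, docstring, and surrounding wiring are illustrative only.

import torch

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher import (
    CombineInput,
    StandardDispatchOutput,
)


class SketchW8A8Fp8MoEMethod:
    """Illustrative sketch of the quant-method contract introduced in 0.5.2."""

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ) -> None:
        # Called once by the MoE layer; the method keeps the config and a Triton runner.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(
        self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput
    ) -> CombineInput:
        # Hidden states and top-k routing now arrive pre-dispatched; the result
        # is handed back to the runner/combine step instead of being returned
        # as a bare tensor.
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            use_fp8_w8a8=True,
            w13_scale=layer.w13_weight_scale,  # renamed from w1_scale
            w2_scale=layer.w2_weight_scale,
            a13_scale=layer.w13_input_scale,   # renamed from a1_scale
            a2_scale=layer.w2_input_scale,
        )
        return self.runner.run(dispatch_output, quant_info)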
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
  normalize_e4m3fn_to_e4m3fnuz,
  )
  from sglang.srt.layers.quantization.utils import requantize_with_max_scale
+ from sglang.srt.utils import get_bool_env_var, is_hip

  __all__ = ["CompressedTensorsW8A8Fp8"]

+ _is_hip = is_hip()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+ if _use_aiter:
+ from aiter.ops.shuffle import shuffle_weight
+

  class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):

@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
  else:
  weight_scale = layer.weight_scale.data

- layer.weight = Parameter(weight.t(), requires_grad=False)
+ if _use_aiter:
+ layer.weight = Parameter(
+ shuffle_weight(weight, (16, 16)), requires_grad=False
+ )
+ else:
+ layer.weight = Parameter(weight.t(), requires_grad=False)
+
  # required by torch.compile to be torch.nn.Parameter
  layer.weight_scale = Parameter(weight_scale, requires_grad=False)

sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
@@ -93,7 +93,7 @@ def _maybe_compile_deep_gemm_one_type_all(
  if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
  logger.warning(
  "Entering DeepGEMM JIT Pre-Compile session. "
- "It may takes a long time (typically 10-20 mins) "
+ "It may take a long time (typically 10-20 mins) "
  "if you have not run `sglang.compile_deep_gemm`. "
  "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
  " for pre-compilation to reduce the overhead if you have not run it before. "
@@ -132,9 +132,17 @@ def _compile_deep_gemm_one_type_all(
  kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
  )

+ old_compile_mode = deep_gemm.get_compile_mode()
+ deep_gemm.set_compile_mode(1)
  # TODO can use multi thread
  for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
  executor.execute(m=m)
+ deep_gemm.set_compile_mode(old_compile_mode)
+
+ # clean up input buffers
+ torch.cuda.current_stream().synchronize()
+ del executor
+ torch.cuda.empty_cache()


  class _BaseWarmupExecutor:
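
The compile_utils.py hunk above switches DeepGEMM into compile mode 1 for the duration of the warmup loop, restores the previous mode afterwards, and then synchronizes and empties the CUDA cache so the warmup input buffers are released. A minimal sketch of that pattern follows, assuming only the deep_gemm calls visible in the hunk (get_compile_mode/set_compile_mode); the warmup argument is a hypothetical stand-in for the executor.execute() loop.

import torch
import deep_gemm


def run_warmup_in_compile_mode(warmup) -> None:
    # Save and restore the compile mode around the warmup, as the diff does.
    old_compile_mode = deep_gemm.get_compile_mode()
    deep_gemm.set_compile_mode(1)
    try:
        warmup()  # hypothetical callable standing in for the per-m execute() loop
    finally:
        deep_gemm.set_compile_mode(old_compile_mode)
    # Release the warmup input buffers once the kernels have been compiled.
    torch.cuda.current_stream().synchronize()
    torch.cuda.empty_cache()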
sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
@@ -11,9 +11,6 @@ def _compute_enable_deep_gemm():
  sm_version = get_device_sm()
  if sm_version < 90:
  return False
- # TODO fix deepgemm cu129 fp8 issue
- if torch.version.cuda == "12.9":
- return False

  try:
  import deep_gemm
sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
  ENABLE_JIT_DEEPGEMM,
  )
  from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import get_bool_env_var

  logger = logging.getLogger(__name__)

@@ -18,6 +19,8 @@ if ENABLE_JIT_DEEPGEMM:
  import deep_gemm
  from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor

+ _SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
+

  # TODO maybe rename these functions
  def grouped_gemm_nt_f8f8bf16_masked(
@@ -31,6 +34,9 @@ def grouped_gemm_nt_f8f8bf16_masked(
  _, n, _ = rhs[0].shape
  kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED

+ _sanity_check_input(lhs)
+ _sanity_check_input(rhs)
+
  with compile_utils.deep_gemm_execution_hook(
  expected_m, n, k, num_groups, kernel_type
  ):
@@ -53,6 +59,9 @@ def grouped_gemm_nt_f8f8bf16_contig(
  num_groups, n, _ = rhs[0].shape
  kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG

+ _sanity_check_input(lhs)
+ _sanity_check_input(rhs)
+
  with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
  deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)

@@ -67,6 +76,9 @@ def gemm_nt_f8f8bf16(
  num_groups = 1
  kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16

+ _sanity_check_input(lhs)
+ _sanity_check_input(rhs)
+
  with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
  deep_gemm.fp8_gemm_nt(
  lhs,
@@ -90,3 +102,18 @@ def configure_deep_gemm_num_sms(num_sms):
  yield
  finally:
  deep_gemm.set_num_sms(original_num_sms)
+
+
+ def _sanity_check_input(x_fp8: Tuple[torch.Tensor, torch.Tensor]):
+ if not _SANITY_CHECK:
+ return
+
+ x, x_scale = x_fp8
+
+ if x_scale.dtype == torch.int:
+ return
+
+ from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0
+
+ x_scale_ceil = ceil_to_ue8m0(x_scale)
+ assert torch.all(x_scale == x_scale_ceil), f"{x_scale=} {x_scale_ceil=}"
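
The entrypoint.py hunks above wire an optional input check into all three DeepGEMM wrappers: when the SGLANG_DEEPGEMM_SANITY_CHECK environment variable is set, _sanity_check_input asserts that every floating-point FP8 scale already equals ceil_to_ue8m0 of itself, i.e. that scales are exact UE8M0 (power-of-two) values. The snippet below only illustrates that condition; ceil_to_ue8m0_reference is a hypothetical stand-in written here for illustration, not the helper from sglang.srt.layers.quantization.fp8_utils.

import torch


def ceil_to_ue8m0_reference(x: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour: round each scale up to the nearest power of two,
    # the only values an 8-bit-exponent (UE8M0) scale format can represent.
    return torch.exp2(torch.ceil(torch.log2(x)))


# Scales that are exact powers of two satisfy x == ceil_to_ue8m0(x) and pass the check.
good_scales = torch.tensor([0.25, 0.5, 1.0, 4.0])
assert torch.all(good_scales == ceil_to_ue8m0_reference(good_scales))

# Anything else would trip the new assert inside _sanity_check_input.
bad_scales = torch.tensor([0.3, 1.5])
assert not torch.all(bad_scales == ceil_to_ue8m0_reference(bad_scales))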
sglang/srt/layers/quantization/fp8.py
@@ -30,6 +30,9 @@ except ImportError:

  from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+ from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+ from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
+ from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
  from sglang.srt.layers.parameter import (
  BlockQuantScaleParameter,
  ModelWeightParameter,
@@ -81,7 +84,11 @@ from sglang.srt.utils import (
  )

  if TYPE_CHECKING:
- from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+ from sglang.srt.layers.moe.token_dispatcher import (
+ CombineInput,
+ DispatchOutput,
+ StandardDispatchOutput,
+ )
  from sglang.srt.layers.moe.topk import TopKOutput
  from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config

@@ -345,6 +352,9 @@ class Fp8LinearMethod(LinearMethodBase):
  _is_cpu_amx_available
  ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
  _amx_process_weight_after_loading(layer, ["weight"])
+ layer.weight_scale_inv = torch.nn.Parameter(
+ layer.weight_scale_inv.data, requires_grad=False
+ )
  return
  else:
  weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
@@ -527,7 +537,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  layer: Module,
  num_experts: int,
  hidden_size: int,
- intermediate_size: int,
+ intermediate_size_per_partition: int,
  params_dtype: torch.dtype,
  **extra_weight_attrs,
  ):
@@ -543,18 +553,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  )
  # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
  # Required by column parallel or enabling merged weights
- if intermediate_size % block_n != 0:
+ if intermediate_size_per_partition % block_n != 0:
  raise ValueError(
  f"The output_size of gate's and up's weight = "
- f"{intermediate_size} is not divisible by "
+ f"{intermediate_size_per_partition} is not divisible by "
  f"weight quantization block_n = {block_n}."
  )
  if tp_size > 1:
  # Required by row parallel
- if intermediate_size % block_k != 0:
+ if intermediate_size_per_partition % block_k != 0:
  raise ValueError(
  f"The input_size of down's weight = "
- f"{intermediate_size} is not divisible by "
+ f"{intermediate_size_per_partition} is not divisible by "
  f"weight quantization block_k = {block_k}."
  )

@@ -564,7 +574,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  w13_weight = torch.nn.Parameter(
  torch.empty(
  num_experts,
- 2 * intermediate_size,
+ 2 * intermediate_size_per_partition,
  hidden_size // 8,
  dtype=params_dtype,
  ),
@@ -572,20 +582,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  )
  w2_weight = torch.nn.Parameter(
  torch.empty(
- num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype
+ num_experts,
+ hidden_size,
+ intermediate_size_per_partition // 8,
+ dtype=params_dtype,
  ),
  requires_grad=False,
  )
  else:
  w13_weight = torch.nn.Parameter(
  torch.empty(
- num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ hidden_size,
+ dtype=params_dtype,
  ),
  requires_grad=False,
  )
  w2_weight = torch.nn.Parameter(
  torch.empty(
- num_experts, hidden_size, intermediate_size, dtype=params_dtype
+ num_experts,
+ hidden_size,
+ intermediate_size_per_partition,
+ dtype=params_dtype,
  ),
  requires_grad=False,
  )
@@ -601,7 +620,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  w13_weight_scale = torch.nn.Parameter(
  torch.ones(
  num_experts,
- 2 * ((intermediate_size + block_n - 1) // block_n),
+ 2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
  (hidden_size + block_k - 1) // block_k,
  dtype=torch.float32,
  ),
@@ -611,7 +630,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  torch.ones(
  num_experts,
  (hidden_size + block_n - 1) // block_n,
- (intermediate_size + block_k - 1) // block_k,
+ (intermediate_size_per_partition + block_k - 1) // block_k,
  dtype=torch.float32,
  ),
  requires_grad=False,
@@ -619,11 +638,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
  layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
  assert self.quant_config.activation_scheme == "dynamic"
- if (
- get_bool_env_var("SGLANG_CUTLASS_MOE")
- and self.cutlass_fp8_supported
- and (is_sm100_supported() or is_sm90_supported())
- ):
+ if self.use_cutlass_fused_experts_fp8:
  self.ab_strides1 = torch.full(
  (num_experts,),
  hidden_size,
@@ -632,13 +647,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  )
  self.c_strides1 = torch.full(
  (num_experts,),
- 2 * intermediate_size,
+ 2 * intermediate_size_per_partition,
  device=w13_weight.device,
  dtype=torch.int64,
  )
  self.ab_strides2 = torch.full(
  (num_experts,),
- intermediate_size,
+ intermediate_size_per_partition,
  device=w2_weight.device,
  dtype=torch.int64,
  )
@@ -691,7 +706,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  if _is_hip: # _use_aiter: TODO: add check back after triton kernel
  # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
  w13_weight_scale1 = torch.nn.Parameter(
- torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
+ torch.ones(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ dtype=torch.float32,
+ ),
  requires_grad=False,
  )
  w2_weight_scale1 = torch.nn.Parameter(
@@ -984,14 +1003,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  )
  torch.cuda.empty_cache()

+ def create_moe_runner(
+ self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+ ):
+ self.moe_runner_config = moe_runner_config
+ self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
  def apply(
  self,
  layer: torch.nn.Module,
- x: torch.Tensor,
- topk_output: TopKOutput,
- moe_runner_config: MoeRunnerConfig,
- ) -> torch.Tensor:
- from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+ dispatch_output: DispatchOutput,
+ ) -> CombineInput:
+
+ from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+ x = dispatch_output.hidden_states
+ topk_output = dispatch_output.topk_output
+ moe_runner_config = self.moe_runner_config

  if use_intel_amx_backend(layer):
  from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
@@ -1001,7 +1029,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  moe_runner_config.apply_router_weight_on_input, topk_weights, x
  )

- return torch.ops.sgl_kernel.fused_experts_cpu(
+ output = torch.ops.sgl_kernel.fused_experts_cpu(
  x,
  layer.w13_weight,
  layer.w2_weight,
@@ -1017,6 +1045,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  None, # a2_scale
  True, # is_vnni
  )
+ return StandardCombineInput(hidden_states=output)

  if _is_hip:
  ret = self.maybe_apply_hip_fused_experts(
@@ -1027,7 +1056,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  moe_runner_config.no_combine,
  )
  if ret is not None:
- return ret
+ return StandardCombineInput(hidden_states=ret)

  if self.use_cutlass_fused_experts_fp8:
  from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
@@ -1056,17 +1085,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  self.problem_sizes2,
  use_fp8_blockscale=True,
  )
- # Scale by routed_scaling_factor is fused into select_experts.
- return output
- # Expert fusion with FP8 quantization
- return fused_experts(
- x,
- layer.w13_weight,
- layer.w2_weight,
- topk_output=topk_output,
- moe_runner_config=moe_runner_config,
+ return StandardCombineInput(hidden_states=output)
+
+ quant_info = TritonMoeQuantInfo(
+ w13_weight=layer.w13_weight,
+ w2_weight=layer.w2_weight,
  use_fp8_w8a8=True,
- w1_scale=(
+ w13_scale=(
  layer.w13_weight_scale_inv
  if self.block_quant
  else layer.w13_weight_scale
@@ -1074,20 +1099,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  w2_scale=(
  layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
  ),
- a1_scale=layer.w13_input_scale,
+ a13_scale=layer.w13_input_scale,
  a2_scale=layer.w2_input_scale,
  block_shape=self.quant_config.weight_block_size,
  )
+ return self.runner.run(dispatch_output, quant_info)

  def apply_with_router_logits(
  self,
  layer: torch.nn.Module,
- x: torch.Tensor,
- topk_output: TopKOutput,
- moe_runner_config: MoeRunnerConfig,
+ dispatch_output: StandardDispatchOutput,
  ) -> torch.Tensor:
- activation = moe_runner_config.activation
- routed_scaling_factor = moe_runner_config.routed_scaling_factor
+ x = dispatch_output.hidden_states
+ topk_output = dispatch_output.topk_output
+
+ activation = self.moe_runner_config.activation
+ routed_scaling_factor = self.moe_runner_config.routed_scaling_factor


  from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
@@ -1108,10 +1135,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  and topk_config.topk_group is not None
  ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"

- if topk_config.correction_bias is None:
- correction_bias = topk_config.correction_bias.to(x.dtype)
- else:
- correction_bias = None
+ correction_bias = (
+ None
+ if topk_config.correction_bias is None
+ else topk_config.correction_bias.to(x.dtype)
+ )
+
  return trtllm_fp8_block_scale_moe(
  routing_logits=router_logits.to(torch.float32),
  routing_bias=correction_bias,