sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -1,30 +1,11 @@
  import logging
  from typing import Callable, List, Optional, Tuple

+ import einops
  import torch
+ from sgl_kernel import silu_and_mul
  from torch.nn import Module

- from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
- from sglang.srt.managers.expert_location import get_global_expert_location_metadata
- from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
- from sglang.srt.managers.schedule_batch import global_server_args_dict
-
- try:
- from deep_gemm import (
- get_col_major_tma_aligned_tensor,
- m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
- m_grouped_gemm_fp8_fp8_bf16_nt_masked,
- )
- from sgl_kernel import silu_and_mul
-
- from sglang.srt.layers.quantization.fp8_kernel import (
- sglang_per_token_group_quant_fp8,
- )
-
- use_deep_gemm = True
- except ImportError:
- use_deep_gemm = False
-
  from sglang.srt.custom_op import CustomOp
  from sglang.srt.distributed import (
  get_tensor_model_parallel_rank,
@@ -35,6 +16,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
  ep_scatter,
  gelu_and_mul_triton_kernel,
  grouped_gemm_triton,
+ moe_ep_deepgemm_preprocess,
  post_reorder_triton_kernel,
  pre_reorder_triton_kernel,
  run_moe_ep_preproess,
@@ -45,19 +27,33 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
  from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
  from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import (
  QuantizationConfig,
  QuantizeMethodBase,
  )
  from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
  from sglang.srt.layers.quantization.fp8_kernel import (
+ is_fp8_fnuz,
  scaled_fp8_quant,
+ sglang_per_token_group_quant_fp8,
  sglang_per_token_quant_fp8,
  )
+ from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+ from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
- from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
+ from sglang.srt.utils import (
+ DeepEPMode,
+ dispose_tensor,
+ get_bool_env_var,
+ is_hip,
+ set_weight_attrs,
+ )

  _is_hip = is_hip()
+ _is_fp8_fnuz = is_fp8_fnuz()

  if _is_hip:
  from vllm._custom_ops import scaled_fp8_quant
@@ -183,6 +179,7 @@ class EPMoE(torch.nn.Module):
  assert (
  num_fused_shared_experts == 0
  ), "num_fused_shared_experts is not supported in EP"
+ self.num_fused_shared_experts = num_fused_shared_experts
  self.num_experts_per_partition = self.num_experts // self.tp_size
  self.start_expert_id = self.tp_rank * self.num_experts_per_partition
  self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -232,13 +229,182 @@ class EPMoE(torch.nn.Module):

  self.grouped_gemm_runner = None

+ self.w13_weight_fp8 = (
+ self.w13_weight,
+ (
+ self.w13_weight_scale_inv
+ if self.use_block_quant
+ else self.w13_weight_scale
+ ),
+ )
+ self.w2_weight_fp8 = (
+ self.w2_weight,
+ self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+ )
+
  def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+ return self.forward_deepgemm(hidden_states, router_logits)
+ else:
+ return self.forward_normal(hidden_states, router_logits)
+
+ def forward_deepgemm(
+ self, hidden_states: torch.Tensor, router_logits: torch.Tensor
+ ):
+ assert self.quant_method is not None
+ assert self.activation == "silu"
  hidden_states_shape = hidden_states.shape
  hidden_states_dtype = hidden_states.dtype
  hidden_states_device = hidden_states.device
+ topk_weights, topk_ids = select_experts(
+ hidden_states=hidden_states,
+ router_logits=router_logits,
+ top_k=self.top_k,
+ use_grouped_topk=self.use_grouped_topk,
+ renormalize=self.renormalize,
+ topk_group=self.topk_group,
+ num_expert_group=self.num_expert_group,
+ num_fused_shared_experts=self.num_fused_shared_experts,
+ correction_bias=self.correction_bias,
+ custom_routing_function=self.custom_routing_function,
+ routed_scaling_factor=self.routed_scaling_factor,
+ )

- assert self.quant_method is not None
+ if not self.use_block_quant:
+ # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+ scale_block_size = 128
+ w13_weight_scale_n = 2 * (
+ (self.intermediate_size + scale_block_size - 1) // scale_block_size
+ )
+ w13_weight_scale_k = (
+ hidden_states_shape[-1] + scale_block_size - 1
+ ) // scale_block_size
+ w13_weight_scale = (
+ self.w13_weight_scale.unsqueeze(1)
+ .repeat_interleave(w13_weight_scale_n, dim=1)
+ .unsqueeze(2)
+ .repeat_interleave(w13_weight_scale_k, dim=2)
+ )
+ self.w13_weight_fp8 = (
+ self.w13_weight,
+ w13_weight_scale,
+ )
+ w2_weight_scale_n = (
+ hidden_states_shape[-1] + scale_block_size - 1
+ ) // scale_block_size
+ w2_weight_scale_k = (
+ self.intermediate_size + scale_block_size - 1
+ ) // scale_block_size
+ w2_weight_scale = (
+ self.w2_weight_scale.unsqueeze(1)
+ .repeat_interleave(w2_weight_scale_n, dim=1)
+ .unsqueeze(2)
+ .repeat_interleave(w2_weight_scale_k, dim=2)
+ )
+ self.w2_weight_fp8 = (
+ self.w2_weight,
+ w2_weight_scale,
+ )
+
+ # PreReorder
+ m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
+ moe_ep_deepgemm_preprocess(
+ topk_ids,
+ self.num_experts,
+ hidden_states,
+ self.top_k,
+ self.start_expert_id,
+ self.end_expert_id,
+ self.block_shape,
+ )
+ )
+
+ dispose_tensor(hidden_states)
+
+ # GroupGemm-0
+ gateup_input_fp8 = (
+ gateup_input,
+ deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+ )
+ num_groups, m, k = gateup_input_fp8[0].size()
+ n = self.w13_weight.size(1)
+ gateup_output = torch.empty(
+ (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+ )
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+ gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+ )
+ del gateup_input
+ del gateup_input_fp8

+ # Act
+ down_input = torch.empty(
+ (
+ gateup_output.shape[0],
+ gateup_output.shape[1],
+ gateup_output.shape[2] // 2,
+ ),
+ device=hidden_states_device,
+ dtype=self.fp8_dtype,
+ )
+ scale_block_size = 128
+ down_input_scale = torch.empty(
+ (
+ gateup_output.shape[0],
+ gateup_output.shape[1],
+ gateup_output.shape[2] // 2 // scale_block_size,
+ ),
+ device=hidden_states_device,
+ dtype=torch.float32,
+ )
+ silu_and_mul_masked_post_quant_fwd(
+ gateup_output,
+ down_input,
+ down_input_scale,
+ scale_block_size,
+ masked_m,
+ )
+ del gateup_output
+
+ # GroupGemm-1
+ n = self.w2_weight.size(1)
+ down_input_fp8 = (
+ down_input,
+ deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+ )
+ down_output = torch.empty(
+ (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+ )
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+ down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+ )
+ del down_input
+ del down_input_fp8
+
+ # PostReorder
+ output = torch.empty(
+ hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+ )
+ post_reorder_triton_kernel[(hidden_states_shape[0],)](
+ down_output,
+ output,
+ src2dst,
+ topk_ids,
+ topk_weights,
+ self.start_expert_id,
+ self.end_expert_id,
+ self.top_k,
+ hidden_states_shape[1],
+ m_max * self.start_expert_id,
+ BLOCK_SIZE=512,
+ )
+ return output
+
+ def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+ assert self.quant_method is not None
+ hidden_states_shape = hidden_states.shape
+ hidden_states_dtype = hidden_states.dtype
+ hidden_states_device = hidden_states.device
  if self.grouped_gemm_runner is None:
  self.grouped_gemm_runner = GroupedGemmRunner(
  hidden_states.device,
@@ -254,6 +420,7 @@ class EPMoE(torch.nn.Module):
  renormalize=self.renormalize,
  topk_group=self.topk_group,
  num_expert_group=self.num_expert_group,
+ num_fused_shared_experts=self.num_fused_shared_experts,
  correction_bias=self.correction_bias,
  custom_routing_function=self.custom_routing_function,
  routed_scaling_factor=self.routed_scaling_factor,
@@ -445,6 +612,7 @@ class EPMoE(torch.nn.Module):
  self.end_expert_id,
  self.top_k,
  hidden_states_shape[1],
+ 0,
  BLOCK_SIZE=512,
  )
  return output
@@ -680,7 +848,6 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
  params_dtype: torch.dtype,
  **extra_weight_attrs,
  ):
-
  if self.quant_config.is_checkpoint_fp8_serialized:
  params_dtype = torch.float8_e4m3fn

@@ -852,6 +1019,33 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
  torch.max(layer.w13_weight_scale, dim=1).values,
  requires_grad=False,
  )
+ if self.block_quant:
+ # If ROCm, normalize the weights and scales to e4m3fnuz
+ if _is_fp8_fnuz:
+ # activation_scheme: dynamic
+ w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+ weight=layer.w13_weight,
+ weight_scale=layer.w13_weight_scale_inv,
+ input_scale=None,
+ )
+ w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+ weight=layer.w2_weight,
+ weight_scale=layer.w2_weight_scale_inv,
+ input_scale=None,
+ )
+ # Reset the parameter
+ layer.w13_weight = torch.nn.Parameter(
+ w13_weight, requires_grad=False
+ )
+ layer.w13_weight_scale_inv = torch.nn.Parameter(
+ w13_weight_scale, requires_grad=False
+ )
+ layer.w13_input_scale = None
+ layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+ layer.w2_weight_scale_inv = torch.nn.Parameter(
+ w2_weight_scale, requires_grad=False
+ )
+ layer.w2_input_scale = None
  return

  def apply(
@@ -920,7 +1114,9 @@ class DeepEPMoE(EPMoE):
  )
  self.deepep_mode = deepep_mode
  if self.deepep_mode.enable_low_latency():
- assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+ assert (
+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+ ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
  self.w13_weight_fp8 = (
  self.w13_weight,
  (
@@ -948,7 +1144,7 @@ class DeepEPMoE(EPMoE):
  ):
  resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
  if resolved_deepep_mode == DeepEPMode.normal:
- if _ENABLE_JIT_DEEPGEMM:
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
  return self.forward_deepgemm_contiguous(
  hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
  )
@@ -1145,7 +1341,7 @@ class DeepEPMoE(EPMoE):
  dtype=torch.bfloat16,
  )
  input_tensor[1] = tma_align_input_scale(input_tensor[1])
- m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
  input_tensor, self.w13_weight_fp8, gateup_output, m_indices
  )
  del input_tensor
@@ -1169,7 +1365,7 @@ class DeepEPMoE(EPMoE):
  )
  del down_input
  down_input_scale = tma_align_input_scale(down_input_scale)
- m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
  (down_input_fp8, down_input_scale),
  self.w2_weight_fp8,
  down_output,
@@ -1202,8 +1398,13 @@ class DeepEPMoE(EPMoE):
  gateup_output = torch.empty(
  (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
  )
- m_grouped_gemm_fp8_fp8_bf16_nt_masked(
- hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+ hidden_states_fp8,
+ self.w13_weight_fp8,
+ gateup_output,
+ masked_m,
+ expected_m,
+ recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
  )
  dispose_tensor(hidden_states_fp8[0])

@@ -1233,6 +1434,7 @@ class DeepEPMoE(EPMoE):
  down_input_scale,
  scale_block_size,
  masked_m,
+ scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
  )
  del gateup_output

@@ -1240,13 +1442,24 @@ class DeepEPMoE(EPMoE):
  n = self.w2_weight.size(1)
  down_input_fp8 = (
  down_input,
- get_col_major_tma_aligned_tensor(down_input_scale),
+ (
+ down_input_scale
+ if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+ else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+ down_input_scale
+ )
+ ),
  )
  down_output = torch.empty(
  (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
  )
- m_grouped_gemm_fp8_fp8_bf16_nt_masked(
- down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+ down_input_fp8,
+ self.w2_weight_fp8,
+ down_output,
+ masked_m,
+ expected_m,
+ recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
  )

  return down_output
@@ -1255,6 +1468,9 @@ class DeepEPMoE(EPMoE):
  def get_moe_impl_class():
  if global_server_args_dict["enable_deepep_moe"]:
  return DeepEPMoE
+ if global_server_args_dict["enable_flashinfer_moe"]:
+ # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+ return FusedMoE
  if global_server_args_dict["enable_ep_moe"]:
  return EPMoE
  return FusedMoE
@@ -1,7 +1,7 @@
  import logging
  from dataclasses import dataclass

- from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+ from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.managers.expert_distribution import (
  get_global_expert_distribution_recorder,
  )
@@ -107,6 +107,8 @@ class DeepEPBuffer:
  num_rdma_bytes,
  low_latency_mode=deepep_mode.enable_low_latency(),
  num_qps_per_rank=num_qps_per_rank,
+ # TODO can be false when unneeded
+ allow_mnnvl=True,
  )
  return cls._buffer

@@ -234,14 +236,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  topk_weights: torch.Tensor,
  ):
  topk_idx = topk_idx.to(torch.int64)
- if _ENABLE_JIT_DEEPGEMM:
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
  # TODO hard code 128 block quant,use fp8 communication
  hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
  previous_event = Buffer.capture() if self.async_finish else None
  return hidden_states, topk_idx, topk_weights, previous_event

  def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
- if _ENABLE_JIT_DEEPGEMM:
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
  (
  hidden_states,
  topk_idx,
@@ -343,7 +345,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  previous_event=previous_event,
  async_finish=self.async_finish,
  allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
- expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
+ expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
  config=DeepEPConfig.get_instance().normal_dispatch_config,
  )

@@ -407,7 +409,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  topk_idx: torch.Tensor,
  topk_weights: torch.Tensor,
  ):
- if _ENABLE_JIT_DEEPGEMM:
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
  output = hidden_states
  else:
  if hidden_states.shape[0] > 0:
@@ -540,38 +542,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  topk_idx: torch.Tensor,
  use_fp8: bool = False,
  ):
- """
- # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'.
- # Please make sure to change DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall.
- # More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
-
- diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
- index 76ae2e2..8ecd08f 100644
- --- a/csrc/kernels/internode_ll.cu
- +++ b/csrc/kernels/internode_ll.cu
- @@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
- int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
- void* workspace, cudaStream_t stream, int phases) {
- constexpr int kNumMaxTopK = 9;
- - constexpr int kNumWarpsPerGroup = 10;
- - constexpr int kNumWarpGroups = 3;
- + constexpr int kNumWarpsPerGroup = 8;
- + constexpr int kNumWarpGroups = 4;
- EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
-
- const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
- @@ -501,8 +501,8 @@ void combine(void* combined_x,
- int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
- int num_topk, int num_experts, int rank, int num_ranks,
- void* workspace, cudaStream_t stream, int phases) {
- - constexpr int kNumWarpsPerGroup = 10;
- - constexpr int kNumWarpGroups = 3;
- + constexpr int kNumWarpsPerGroup = 8;
- + constexpr int kNumWarpGroups = 4;
- constexpr int kNumMaxTopk = 9;
-
- const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
- """
  buffer = self._get_buffer()
  packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
  buffer.low_latency_dispatch(
@@ -582,6 +552,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  use_fp8=use_fp8,
  async_finish=not self.return_recv_hook,
  return_recv_hook=self.return_recv_hook,
+ round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+ and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
+ use_ue8m0=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+ and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
  )
  )
  return packed_recv_hidden, packed_recv_count, event, hook
@@ -0,0 +1,146 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "256": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ }
+ }