sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_serving.py +49 -7
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/model_config.py +1 -0
  4. sglang/srt/constrained/base_grammar_backend.py +5 -1
  5. sglang/srt/custom_op.py +5 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  7. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  8. sglang/srt/entrypoints/engine.py +0 -5
  9. sglang/srt/layers/attention/flashattention_backend.py +394 -76
  10. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  11. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  12. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  13. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  14. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  15. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  20. sglang/srt/layers/moe/topk.py +49 -3
  21. sglang/srt/layers/quantization/__init__.py +4 -1
  22. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  23. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  24. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  25. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  26. sglang/srt/layers/quantization/utils.py +1 -1
  27. sglang/srt/layers/rotary_embedding.py +0 -12
  28. sglang/srt/managers/cache_controller.py +34 -11
  29. sglang/srt/managers/mm_utils.py +202 -156
  30. sglang/srt/managers/multimodal_processor.py +0 -2
  31. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  32. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  33. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  34. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  35. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  36. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  37. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  38. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  40. sglang/srt/managers/schedule_batch.py +185 -128
  41. sglang/srt/managers/scheduler.py +4 -4
  42. sglang/srt/managers/tokenizer_manager.py +1 -1
  43. sglang/srt/managers/utils.py +1 -6
  44. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  45. sglang/srt/mem_cache/memory_pool.py +72 -6
  46. sglang/srt/mem_cache/paged_allocator.py +39 -0
  47. sglang/srt/metrics/collector.py +23 -53
  48. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  49. sglang/srt/model_executor/forward_batch_info.py +10 -10
  50. sglang/srt/model_executor/model_runner.py +59 -57
  51. sglang/srt/model_loader/loader.py +8 -0
  52. sglang/srt/models/clip.py +12 -7
  53. sglang/srt/models/deepseek_janus_pro.py +10 -15
  54. sglang/srt/models/deepseek_v2.py +212 -121
  55. sglang/srt/models/deepseek_vl2.py +105 -104
  56. sglang/srt/models/gemma3_mm.py +14 -80
  57. sglang/srt/models/llama.py +4 -1
  58. sglang/srt/models/llava.py +31 -19
  59. sglang/srt/models/llavavid.py +16 -7
  60. sglang/srt/models/minicpmo.py +63 -147
  61. sglang/srt/models/minicpmv.py +17 -27
  62. sglang/srt/models/mllama.py +29 -14
  63. sglang/srt/models/qwen2.py +9 -6
  64. sglang/srt/models/qwen2_5_vl.py +21 -31
  65. sglang/srt/models/qwen2_vl.py +20 -21
  66. sglang/srt/openai_api/adapter.py +18 -6
  67. sglang/srt/platforms/interface.py +371 -0
  68. sglang/srt/server_args.py +99 -14
  69. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  70. sglang/srt/speculative/eagle_utils.py +140 -28
  71. sglang/srt/speculative/eagle_worker.py +93 -24
  72. sglang/srt/utils.py +104 -51
  73. sglang/test/test_custom_ops.py +55 -0
  74. sglang/test/test_utils.py +13 -26
  75. sglang/utils.py +2 -2
  76. sglang/version.py +1 -1
  77. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
  78. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
  79. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  80. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  81. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -244,6 +244,148 @@ def silu_and_mul_triton_kernel(
     tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
 
 
+# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
+@triton.jit
+def _silu_and_mul_post_quant_kernel(
+    input_ptr,
+    stride_input_0,
+    stride_input_1,
+    stride_input_2,
+    output_ptr,
+    stride_output_0,
+    stride_output_1,
+    stride_output_2,
+    output_scale_ptr,
+    stride_output_scale_0,
+    stride_output_scale_1,
+    stride_output_scale_2,
+    masked_m_ptr,
+    size_n,
+    fp8_max,
+    fp8_min,
+    BLOCK_N: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    expert_id = tl.program_id(2)
+    token_id = tl.program_id(1)
+    hidden_dim_block_index = tl.program_id(0)
+
+    block_num_per_expert = tl.num_programs(1)
+
+    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
+
+    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
+    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
+    stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
+    stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
+
+    offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
+    input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
+    output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
+    output_scale_offs = (
+        output_scale_ptr
+        + expert_id * stride_output_scale_0
+        + hidden_dim_block_index * stride_output_scale_2
+    )
+
+    for token_index in tl.range(
+        token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
+    ):
+        gate = tl.load(
+            input_ptr_offs + token_index * stride_input_1,
+            mask=offs_in_d < size_n,
+            other=0.0,
+        ).to(tl.float32)
+        up = tl.load(
+            input_ptr_offs + token_index * stride_input_1 + size_n,
+            mask=offs_in_d < size_n,
+            other=0.0,
+        )
+        gate = gate / (1 + tl.exp(-gate))
+        gate = gate.to(input_ptr.dtype.element_ty)
+        gate_up = up * gate
+        _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
+        output_s = _absmax / fp8_max
+        output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
+            output_ptr.dtype.element_ty
+        )
+        tl.store(
+            output_ptr_offs + token_index * stride_output_1,
+            output_q,
+            mask=offs_in_d < size_n,
+        )
+        tl.store(
+            output_scale_offs + token_index * stride_output_scale_1,
+            output_s,
+        )
+
+
+def silu_and_mul_masked_post_quant_fwd(
+    input: torch.Tensor,
+    output: torch.Tensor,
+    output_scale: torch.Tensor,
+    quant_group_size: int,
+    masked_m: torch.Tensor,
+):
+ """
331
+ input shape [expert_num, token_num_padded, hidden_dim]
332
+ output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
333
+ output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32
334
+ quant_group_size int,
335
+ masked_m shape [expert_num],
336
+ """
+
+    assert input.is_contiguous()
+    assert output.dtype == torch.float8_e4m3fn
+    assert output.is_contiguous()
+    assert len(input.shape) == 3
+    assert input.shape[0] == masked_m.shape[0]
+    assert input.shape[-1] % 2 == 0
+
+    size_n = input.shape[-1] // 2
+    assert size_n % quant_group_size == 0
+
+    expert_num = len(masked_m)
+
+    if expert_num < 4:
+        BLOCK_NUM_PER_EXPERT = 64
+    else:
+        BLOCK_NUM_PER_EXPERT = 32
+
+    BLOCK_N = quant_group_size
+    num_warps = 1
+    NUM_STAGES = 6
+    hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
+    assert BLOCK_N % quant_group_size == 0
+
+    grid = (
+        hidden_dim_split_block_num,
+        BLOCK_NUM_PER_EXPERT,
+        expert_num,
+    )
+
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = finfo.max
+    fp8_min = -fp8_max
+
+    _silu_and_mul_post_quant_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        output_scale,
+        *output_scale.stride(),
+        masked_m,
+        size_n,
+        fp8_max,
+        fp8_min,
+        BLOCK_N=BLOCK_N,
+        NUM_STAGE=NUM_STAGES,
+        num_warps=num_warps,
+    )
+    return
+
+
 @triton.jit
 def tanh(x):
     return 2 * tl.sigmoid(2 * x) - 1
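The hunk above (attributed to sglang/srt/layers/moe/ep_moe/kernels.py in the file list) adds a Triton kernel that fuses the SiLU-and-mul activation with per-128-group FP8 quantization of the result, plus the silu_and_mul_masked_post_quant_fwd launcher. Below is a minimal usage sketch, not part of the diff: the tensor sizes and masked_m values are made up for illustration, and it assumes a CUDA device with Triton available.

```python
# Usage sketch only (not part of the package diff); toy sizes, requires a CUDA GPU.
import torch

from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd

expert_num, token_num_padded, hidden_dim = 8, 128, 4096
group_size = 128  # FP8 scale granularity along the last dimension

# Gate/up activations for each expert, padded to a fixed per-expert token count.
gate_up = torch.randn(
    expert_num, token_num_padded, hidden_dim, device="cuda", dtype=torch.bfloat16
)
# FP8 output and per-group scales, shaped as described in the docstring above.
out = torch.empty(
    expert_num, token_num_padded, hidden_dim // 2,
    device="cuda", dtype=torch.float8_e4m3fn,
)
out_scale = torch.empty(
    expert_num, token_num_padded, hidden_dim // 2 // group_size,
    device="cuda", dtype=torch.float32,
)
# Number of valid (non-padded) tokens per expert.
masked_m = torch.randint(
    0, token_num_padded + 1, (expert_num,), device="cuda", dtype=torch.int32
)

silu_and_mul_masked_post_quant_fwd(gate_up, out, out_scale, group_size, masked_m)
```

Only the first masked_m[e] rows of each expert's slice are written; the padded tail is left untouched, which suits the fixed-size per-expert buffers used by the low-latency DeepEP path further down in this diff.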
sglang/srt/layers/moe/ep_moe/layer.py
@@ -3,12 +3,16 @@ from typing import Callable, List, Optional, Tuple
 
 import torch
 
-# TODO: use deep_gemm masked kernel after low latency dispatch
-# import deep_gemm
-# from deep_gemm import (
-#     get_col_major_tma_aligned_tensor,
-#     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-# )
+try:
+    from deep_gemm import (
+        get_col_major_tma_aligned_tensor,
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+    )
+
+    use_deep_gemm = True
+except ImportError:
+    use_deep_gemm = False
+
 from torch.nn import Module
 
 from sglang.srt.custom_op import CustomOp
@@ -22,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
+    silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
@@ -33,7 +38,7 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
 
 _is_cuda = is_cuda()
 
@@ -42,7 +47,6 @@ if _is_cuda:
 else:
     from vllm import _custom_ops as vllm_ops
 
-
 logger = logging.getLogger(__name__)
 
 _is_hip = is_hip()
@@ -809,6 +813,7 @@ class DeepEPMoE(EPMoE):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        deepep_mode: DeepEPMode = DeepEPMode.auto,
     ):
         super().__init__(
            num_experts,
@@ -827,21 +832,38 @@
             custom_routing_function,
             activation,
         )
+        self.deepep_mode = deepep_mode
+        if self.deepep_mode.enable_low_latency():
+            assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
+        masked_m: torch.Tensor,
+        expected_m: int,
         forward_mode: ForwardMode,
     ):
-        # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
-        if True:  # not forward_mode.is_decode():
+        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        if resolved_deepep_mode == DeepEPMode.normal:
             return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+        elif resolved_deepep_mode == DeepEPMode.low_latency:
+            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
         else:
-            return self.forward_deepgemm_masked(
-                hidden_states, reorder_topk_ids, seg_indptr
-            )
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
 
     def forward_normal(
         self,
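The new dispatch above keys off DeepEPMode (imported from sglang.srt.utils earlier in this diff) instead of a hard-coded branch. The sketch below illustrates the semantics this hunk relies on — normal, low_latency, and auto members, with auto resolving per batch — it is not the actual sglang implementation, and a plain is_decode flag stands in for the ForwardMode argument.

```python
# Illustrative sketch only; not sglang.srt.utils.DeepEPMode itself.
from enum import Enum


class DeepEPModeSketch(Enum):
    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def enable_low_latency(self) -> bool:
        # Anything other than plain "normal" may take the low-latency path,
        # which is why __init__ above asserts that deep_gemm is importable.
        return self is not DeepEPModeSketch.normal

    def resolve(self, is_decode: bool) -> "DeepEPModeSketch":
        # Explicit modes resolve to themselves; "auto" picks low_latency for
        # decode batches and normal otherwise.
        if self is not DeepEPModeSketch.auto:
            return self
        return DeepEPModeSketch.low_latency if is_decode else DeepEPModeSketch.normal
```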
@@ -958,89 +980,66 @@
 
     def forward_deepgemm_masked(
         self,
-        hidden_states: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
+        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
+        masked_m: torch.Tensor,
+        expected_m: int,
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+        assert (
+            hidden_states_fp8[0].size(0) % 4 == 0
+        ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"
 
         # GroupGemm-0
+        num_groups, m, k = hidden_states_fp8[0].size()
+        n = self.w13_weight.size(1)
+        expected_m = min(expected_m, m)
         gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
         )
-        if hidden_states.shape[0] > 0:
-            # Transpose earlier so that the testing will not trigger transposing kernels
-            hidden_states = (
-                hidden_states[0],
-                get_col_major_tma_aligned_tensor(hidden_states[1]),
-            )
-        """
-        gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            hidden_states, self.w13_weight, out, masked_m, expected_m
-        )
-        """
 
         # Act
         down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
             ),
+            device=gateup_output.device,
+            dtype=self.fp8_dtype,
         )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
-                dtype=torch.float32,
-                device=hidden_states.device,
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
                 gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                0,
-                self.num_experts_per_partition - 1,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=gateup_output.device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+        )
 
         # GroupGemm-1
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            get_col_major_tma_aligned_tensor(down_input_scale),
+        )
         down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
         )
-        if down_input.shape[0] > 0:
-            # Transpose earlier so that the testing will not trigger transposing kernels
-            down_input = (
-                down_input[0],
-                get_col_major_tma_aligned_tensor(down_input[1]),
-            )
-        """
-        down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            down_input, self.w2_weight, out, masked_m, expected_m
-        )
-        """
 
         return down_output
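To summarize the rewritten forward_deepgemm_masked: it now takes FP8 hidden states plus per-expert token counts (masked_m) and runs GroupGemm-0 → fused SiLU-and-mul with FP8 re-quantization → GroupGemm-1, all through deep_gemm's masked grouped kernels. The plain-PyTorch reference below restates the same math without quantization or masking tricks; the shapes and weights are toy values invented for illustration, not taken from the package.

```python
# Unfused, unquantized reference for the math computed by forward_deepgemm_masked.
import torch
import torch.nn.functional as F

num_experts, m, hidden, inter = 4, 16, 64, 32   # toy sizes
x = torch.randn(num_experts, m, hidden)             # dispatched tokens per expert (padded to m)
w13 = torch.randn(num_experts, 2 * inter, hidden)   # fused gate+up weights, stored [n, k]
w2 = torch.randn(num_experts, hidden, inter)        # down-projection weights, stored [n, k]
masked_m = torch.tensor([16, 9, 0, 5])               # valid tokens per expert

out = torch.zeros(num_experts, m, hidden)
for e in range(num_experts):
    rows = int(masked_m[e])
    gateup = x[e, :rows] @ w13[e].T        # GroupGemm-0 ("nt": weights are [n, k])
    gate, up = gateup.chunk(2, dim=-1)     # gate is the first half, as in the Triton kernel
    act = F.silu(gate) * up                # what _silu_and_mul_post_quant_kernel fuses (plus FP8 quant)
    out[e, :rows] = act @ w2[e].T          # GroupGemm-1
```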