sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +49 -7
  2. sglang/lang/chat_template.py +24 -0
  3. sglang/srt/_custom_ops.py +59 -92
  4. sglang/srt/configs/model_config.py +5 -0
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/conversation.py +29 -4
  7. sglang/srt/custom_op.py +5 -0
  8. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  9. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/layers/attention/flashattention_backend.py +678 -83
  12. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  14. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  16. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  17. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  18. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
  30. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  31. sglang/srt/layers/moe/topk.py +49 -3
  32. sglang/srt/layers/quantization/__init__.py +5 -1
  33. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  35. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  36. sglang/srt/layers/quantization/fp8.py +3 -1
  37. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  38. sglang/srt/layers/quantization/moe_wna16.py +503 -0
  39. sglang/srt/layers/quantization/utils.py +1 -1
  40. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  41. sglang/srt/layers/radix_attention.py +2 -0
  42. sglang/srt/layers/rotary_embedding.py +63 -12
  43. sglang/srt/managers/cache_controller.py +34 -11
  44. sglang/srt/managers/mm_utils.py +202 -156
  45. sglang/srt/managers/multimodal_processor.py +0 -2
  46. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  47. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  48. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  49. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  50. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  51. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  52. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  53. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  54. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  55. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  56. sglang/srt/managers/schedule_batch.py +185 -128
  57. sglang/srt/managers/scheduler.py +4 -4
  58. sglang/srt/managers/tokenizer_manager.py +1 -1
  59. sglang/srt/managers/utils.py +1 -6
  60. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  61. sglang/srt/mem_cache/memory_pool.py +72 -6
  62. sglang/srt/mem_cache/paged_allocator.py +39 -0
  63. sglang/srt/metrics/collector.py +23 -53
  64. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  65. sglang/srt/model_executor/forward_batch_info.py +10 -10
  66. sglang/srt/model_executor/model_runner.py +60 -57
  67. sglang/srt/model_loader/loader.py +8 -0
  68. sglang/srt/models/clip.py +12 -7
  69. sglang/srt/models/deepseek_janus_pro.py +10 -15
  70. sglang/srt/models/deepseek_v2.py +212 -121
  71. sglang/srt/models/deepseek_vl2.py +105 -104
  72. sglang/srt/models/gemma3_mm.py +14 -80
  73. sglang/srt/models/llama.py +16 -5
  74. sglang/srt/models/llama4.py +420 -0
  75. sglang/srt/models/llava.py +31 -19
  76. sglang/srt/models/llavavid.py +16 -7
  77. sglang/srt/models/minicpmo.py +63 -147
  78. sglang/srt/models/minicpmv.py +17 -27
  79. sglang/srt/models/mllama.py +29 -14
  80. sglang/srt/models/mllama4.py +154 -0
  81. sglang/srt/models/qwen2.py +9 -6
  82. sglang/srt/models/qwen2_5_vl.py +21 -31
  83. sglang/srt/models/qwen2_vl.py +20 -21
  84. sglang/srt/openai_api/adapter.py +18 -6
  85. sglang/srt/platforms/interface.py +371 -0
  86. sglang/srt/server_args.py +99 -14
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  88. sglang/srt/speculative/eagle_utils.py +140 -28
  89. sglang/srt/speculative/eagle_worker.py +93 -24
  90. sglang/srt/utils.py +104 -51
  91. sglang/test/test_custom_ops.py +55 -0
  92. sglang/test/test_utils.py +13 -26
  93. sglang/utils.py +2 -2
  94. sglang/version.py +1 -1
  95. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
  96. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
  97. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,6 @@ from functools import partial
  from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
  import torch
- import triton
 
  from sglang.global_config import global_config
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -22,7 +21,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
  from sglang.srt.layers.dp_attention import get_attention_tp_size
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
- from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+ from sglang.srt.utils import is_flashinfer_available, next_power_of_2
 
  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -932,6 +931,7 @@ class FlashInferMultiStepDraftBackend:
  self.topk = topk
  self.speculative_num_steps = speculative_num_steps
  self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+ self.page_size = model_runner.page_size
 
  max_bs = model_runner.req_to_token_pool.size * self.topk
  self.kv_indptr = torch.zeros(
@@ -985,9 +985,9 @@ class FlashInferMultiStepDraftBackend:
  self.pool_len,
  kv_indices_buffer.shape[1],
  self.kv_indptr.shape[1],
- triton.next_power_of_2(num_seqs),
- triton.next_power_of_2(self.speculative_num_steps),
- triton.next_power_of_2(bs),
+ next_power_of_2(num_seqs),
+ next_power_of_2(self.speculative_num_steps),
+ next_power_of_2(bs),
  )
 
  assert forward_batch.spec_info is not None
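Note: the hunks above swap triton.next_power_of_2 for a helper in sglang.srt.utils, so this backend no longer needs a top-level triton import. The in-tree implementation is not shown in this diff; a minimal sketch of such a helper (an assumption, not the packaged code) is:

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n for positive n, matching how the call sites
    # above use it to size kernel launch parameters.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()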
@@ -1018,8 +1018,6 @@ class FlashInferMultiStepDraftBackend:
  )
 
  def call_fn(i, forward_batch):
- assert forward_batch.spec_info is not None
- assert isinstance(forward_batch.spec_info, EagleDraftInput)
  forward_batch.spec_info.kv_indptr = (
  forward_batch.spec_info.kv_indptr.clone()
  )
@@ -71,8 +71,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
  self.device = model_runner.device
  self.skip_prefill = skip_prefill
 
- global_config.enable_flashinfer_mla = True
-
  # Allocate buffers
  global global_workspace_buffer
  if global_workspace_buffer is None:
@@ -797,7 +795,7 @@ class FlashInferMLAMultiStepDraftBackend:
  encoder_lens=None,
  forward_mode=ForwardMode.DECODE,
  spec_info=forward_batch.spec_info,
- seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
+ seq_lens_cpu=forward_batch.seq_lens_cpu,
  )
 
  self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
@@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
  if forward_batch.forward_mode.is_decode_or_idle():
  if spec_info is None:
  max_seqlen_pad = triton.cdiv(
- forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
+ forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
  )
  block_kv_indices = torch.full(
  (bs, max_seqlen_pad),
@@ -244,6 +244,148 @@ def silu_and_mul_triton_kernel(
  tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
 
 
+ # copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
+ @triton.jit
+ def _silu_and_mul_post_quant_kernel(
+ input_ptr,
+ stride_input_0,
+ stride_input_1,
+ stride_input_2,
+ output_ptr,
+ stride_output_0,
+ stride_output_1,
+ stride_output_2,
+ output_scale_ptr,
+ stride_output_scale_0,
+ stride_output_scale_1,
+ stride_output_scale_2,
+ masked_m_ptr,
+ size_n,
+ fp8_max,
+ fp8_min,
+ BLOCK_N: tl.constexpr,
+ NUM_STAGE: tl.constexpr,
+ ):
+ expert_id = tl.program_id(2)
+ token_id = tl.program_id(1)
+ hidden_dim_block_index = tl.program_id(0)
+
+ block_num_per_expert = tl.num_programs(1)
+
+ token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
+
+ stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
+ stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
+ stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
+ stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
+
+ offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
+ input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
+ output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
+ output_scale_offs = (
+ output_scale_ptr
+ + expert_id * stride_output_scale_0
+ + hidden_dim_block_index * stride_output_scale_2
+ )
+
+ for token_index in tl.range(
+ token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
+ ):
+ gate = tl.load(
+ input_ptr_offs + token_index * stride_input_1,
+ mask=offs_in_d < size_n,
+ other=0.0,
+ ).to(tl.float32)
+ up = tl.load(
+ input_ptr_offs + token_index * stride_input_1 + size_n,
+ mask=offs_in_d < size_n,
+ other=0.0,
+ )
+ gate = gate / (1 + tl.exp(-gate))
+ gate = gate.to(input_ptr.dtype.element_ty)
+ gate_up = up * gate
+ _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
+ output_s = _absmax / fp8_max
+ output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
+ output_ptr.dtype.element_ty
+ )
+ tl.store(
+ output_ptr_offs + token_index * stride_output_1,
+ output_q,
+ mask=offs_in_d < size_n,
+ )
+ tl.store(
+ output_scale_offs + token_index * stride_output_scale_1,
+ output_s,
+ )
+
+
+ def silu_and_mul_masked_post_quant_fwd(
+ input: torch.Tensor,
+ output: torch.Tensor,
+ output_scale: torch.Tensor,
+ quant_group_size: int,
+ masked_m: torch.Tensor,
+ ):
+ """
+ input shape [expert_num, token_num_padded, hidden_dim]
+ output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
+ output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32
+ quant_group_size int,
+ masked_m shape [expert_num],
+ """
+
+ assert input.is_contiguous()
+ assert output.dtype == torch.float8_e4m3fn
+ assert output.is_contiguous()
+ assert len(input.shape) == 3
+ assert input.shape[0] == masked_m.shape[0]
+ assert input.shape[-1] % 2 == 0
+
+ size_n = input.shape[-1] // 2
+ assert size_n % quant_group_size == 0
+
+ expert_num = len(masked_m)
+
+ if expert_num < 4:
+ BLOCK_NUM_PER_EXPERT = 64
+ else:
+ BLOCK_NUM_PER_EXPERT = 32
+
+ BLOCK_N = quant_group_size
+ num_warps = 1
+ NUM_STAGES = 6
+ hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
+ assert BLOCK_N % quant_group_size == 0
+
+ grid = (
+ hidden_dim_split_block_num,
+ BLOCK_NUM_PER_EXPERT,
+ expert_num,
+ )
+
+ finfo = torch.finfo(torch.float8_e4m3fn)
+ fp8_max = finfo.max
+ fp8_min = -fp8_max
+
+ _silu_and_mul_post_quant_kernel[grid](
+ input,
+ *input.stride(),
+ output,
+ *output.stride(),
+ output_scale,
+ *output_scale.stride(),
+ masked_m,
+ size_n,
+ fp8_max,
+ fp8_min,
+ BLOCK_N=BLOCK_N,
+ NUM_STAGE=NUM_STAGES,
+ num_warps=num_warps,
+ )
+ return
+
+
  @triton.jit
  def tanh(x):
  return 2 * tl.sigmoid(2 * x) - 1
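The added kernel fuses the SwiGLU-style activation (SiLU(gate) * up) with dynamic per-128-element fp8 quantization, writing both the quantized values and the per-group scales. As a reading aid, a plain-PyTorch reference of the same math (an illustrative sketch, not part of the package; the helper name is made up and it ignores the masked_m handling, which the Triton kernel uses to skip padded rows per expert) could look like:

import torch
import torch.nn.functional as F

def silu_and_mul_post_quant_ref(x: torch.Tensor, group_size: int = 128):
    # x: [expert_num, token_num_padded, 2 * hidden]; first half is gate, second is up.
    # Returns fp8 values [E, T, hidden] and per-group float32 scales [E, T, hidden // group_size].
    finfo = torch.finfo(torch.float8_e4m3fn)
    e, t, two_h = x.shape
    h = two_h // 2
    gate, up = x[..., :h].float(), x[..., h:]
    act = F.silu(gate).to(x.dtype) * up                       # SiLU(gate) * up
    grouped = act.view(e, t, h // group_size, group_size)
    amax = grouped.abs().amax(dim=-1).float().clamp(min=1e-10)
    scale = amax / finfo.max                                   # dynamic per-group scale
    q = (grouped / scale.unsqueeze(-1)).clamp(finfo.min, finfo.max)
    return q.view(e, t, h).to(torch.float8_e4m3fn), scale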
@@ -3,12 +3,16 @@ from typing import Callable, List, Optional, Tuple
 
  import torch
 
- # TODO: use deep_gemm masked kernel after low latency dispatch
- # import deep_gemm
- # from deep_gemm import (
- # get_col_major_tma_aligned_tensor,
- # m_grouped_gemm_fp8_fp8_bf16_nt_masked,
- # )
+ try:
+ from deep_gemm import (
+ get_col_major_tma_aligned_tensor,
+ m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+ )
+
+ use_deep_gemm = True
+ except ImportError:
+ use_deep_gemm = False
+
  from torch.nn import Module
 
  from sglang.srt.custom_op import CustomOp
@@ -22,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
  post_reorder_triton_kernel,
  pre_reorder_triton_kernel,
  run_moe_ep_preproess,
+ silu_and_mul_masked_post_quant_fwd,
  silu_and_mul_triton_kernel,
  )
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
@@ -33,7 +38,7 @@ from sglang.srt.layers.quantization.base_config import (
  )
  from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
- from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
+ from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
 
  _is_cuda = is_cuda()
 
@@ -42,7 +47,6 @@ if _is_cuda:
  else:
  from vllm import _custom_ops as vllm_ops
 
-
  logger = logging.getLogger(__name__)
 
  _is_hip = is_hip()
@@ -809,6 +813,7 @@ class DeepEPMoE(EPMoE):
  correction_bias: Optional[torch.Tensor] = None,
  custom_routing_function: Optional[Callable] = None,
  activation: str = "silu",
+ deepep_mode: DeepEPMode = DeepEPMode.auto,
  ):
  super().__init__(
  num_experts,
@@ -827,21 +832,38 @@ class DeepEPMoE(EPMoE):
  custom_routing_function,
  activation,
  )
+ self.deepep_mode = deepep_mode
+ if self.deepep_mode.enable_low_latency():
+ assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+ self.w13_weight_fp8 = (
+ self.w13_weight,
+ (
+ self.w13_weight_scale_inv
+ if self.use_block_quant
+ else self.w13_weight_scale
+ ),
+ )
+ self.w2_weight_fp8 = (
+ self.w2_weight,
+ self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+ )
 
  def forward(
  self,
  hidden_states: torch.Tensor,
  reorder_topk_ids: torch.Tensor,
  seg_indptr: torch.Tensor,
+ masked_m: torch.Tensor,
+ expected_m: int,
  forward_mode: ForwardMode,
  ):
- # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
- if True: # not forward_mode.is_decode():
+ resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+ if resolved_deepep_mode == DeepEPMode.normal:
  return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+ elif resolved_deepep_mode == DeepEPMode.low_latency:
+ return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
  else:
- return self.forward_deepgemm_masked(
- hidden_states, reorder_topk_ids, seg_indptr
- )
+ raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
 
  def forward_normal(
  self,
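DeepEPMode itself lives in sglang/srt/utils.py and is not part of this diff. Judging only from the call sites above (auto / normal / low_latency members, enable_low_latency(), and resolve(forward_mode)), a purely hypothetical sketch of its shape, for orientation only, might be:

import enum

class DeepEPModeSketch(enum.Enum):
    # Hypothetical stand-in for sglang.srt.utils.DeepEPMode, only to illustrate
    # the dispatch in the hunk above; not the packaged implementation.
    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def enable_low_latency(self) -> bool:
        return self in (DeepEPModeSketch.low_latency, DeepEPModeSketch.auto)

    def resolve(self, forward_mode) -> "DeepEPModeSketch":
        if self != DeepEPModeSketch.auto:
            return self
        # Assumption: auto picks the low-latency (masked deep_gemm) path for
        # decode batches and the normal path otherwise.
        return (
            DeepEPModeSketch.low_latency
            if forward_mode.is_decode()
            else DeepEPModeSketch.normal
        )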
@@ -958,89 +980,66 @@ class DeepEPMoE(EPMoE):
 
  def forward_deepgemm_masked(
  self,
- hidden_states: torch.Tensor,
- reorder_topk_ids: torch.Tensor,
- seg_indptr: torch.Tensor,
+ hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
+ masked_m: torch.Tensor,
+ expected_m: int,
  ):
  assert self.quant_method is not None
  assert self.activation == "silu"
-
- if self.activation_scheme == "dynamic" and not self.use_block_quant:
- max_value = (
- torch.max(hidden_states)
- .repeat(self.num_experts_per_partition)
- .to(torch.float32)
- )
- self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+ assert (
+ hidden_states_fp8[0].size(0) % 4 == 0
+ ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"
 
  # GroupGemm-0
+ num_groups, m, k = hidden_states_fp8[0].size()
+ n = self.w13_weight.size(1)
+ expected_m = min(expected_m, m)
  gateup_output = torch.empty(
- hidden_states.shape[0],
- self.w13_weight.shape[1],
- device=hidden_states.device,
- dtype=hidden_states.dtype,
+ (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
+ )
+ m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+ hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
  )
- if hidden_states.shape[0] > 0:
- # Transpose earlier so that the testing will not trigger transposing kernels
- hidden_states = (
- hidden_states[0],
- get_col_major_tma_aligned_tensor(hidden_states[1]),
- )
- """
- gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
- hidden_states, self.w13_weight, out, masked_m, expected_m
- )
- """
 
  # Act
  down_input = torch.empty(
- gateup_output.shape[0],
- gateup_output.shape[1] // 2,
- device=gateup_output.device,
- dtype=(
- self.fp8_dtype
- if (self.use_fp8_w8a8 and not self.use_block_quant)
- else hidden_states.dtype
+ (
+ gateup_output.shape[0],
+ gateup_output.shape[1],
+ gateup_output.shape[2] // 2,
  ),
+ device=gateup_output.device,
+ dtype=self.fp8_dtype,
  )
- if self.w2_input_scale is None and not self.use_block_quant:
- self.w2_input_scale = torch.ones(
- self.num_experts_per_partition,
- dtype=torch.float32,
- device=hidden_states.device,
- )
-
- if self.activation == "silu":
- silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
- gateup_output,
- down_input,
+ scale_block_size = 128
+ down_input_scale = torch.empty(
+ (
+ gateup_output.shape[0],
  gateup_output.shape[1],
- reorder_topk_ids,
- self.w2_input_scale,
- 0,
- self.num_experts_per_partition - 1,
- BLOCK_SIZE=512,
- )
- else:
- raise ValueError(f"Unsupported activation: {self.activation=}")
+ gateup_output.shape[2] // 2 // scale_block_size,
+ ),
+ device=gateup_output.device,
+ dtype=torch.float32,
+ )
+ silu_and_mul_masked_post_quant_fwd(
+ gateup_output,
+ down_input,
+ down_input_scale,
+ scale_block_size,
+ masked_m,
+ )
 
  # GroupGemm-1
+ n = self.w2_weight.size(1)
+ down_input_fp8 = (
+ down_input,
+ get_col_major_tma_aligned_tensor(down_input_scale),
+ )
  down_output = torch.empty(
- down_input.shape[0],
- self.w2_weight.shape[1],
- device=hidden_states.device,
- dtype=hidden_states.dtype,
+ (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
+ )
+ m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+ down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
  )
- if down_input.shape[0] > 0:
- # Transpose earlier so that the testing will not trigger transposing kernels
- down_input = (
- down_input[0],
- get_col_major_tma_aligned_tensor(down_input[1]),
- )
- """
- down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
- down_input, self.w2_weight, out, masked_m, expected_m
- )
- """
 
  return down_output
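For orientation, the intermediate buffers allocated by the rewritten forward_deepgemm_masked above can be written out as a standalone allocation sketch (sizes and device are illustrative, not taken from the package; a CUDA device and a PyTorch build with float8 support are assumed):

import torch

# n_gateup stands in for w13_weight.size(1), n_down for w2_weight.size(1).
num_groups, m, n_gateup, n_down = 8, 128, 1024, 2048
dev = "cuda"

# hidden_states_fp8 arrives from the low-latency dispatcher as an (fp8 values, scales)
# tuple; its leading dimension must satisfy the TMA-alignment assert (size(0) % 4 == 0).
gateup_output = torch.empty(num_groups, m, n_gateup, dtype=torch.bfloat16, device=dev)
down_input = torch.empty(num_groups, m, n_gateup // 2, dtype=torch.float8_e4m3fn, device=dev)
down_input_scale = torch.empty(
    num_groups, m, n_gateup // 2 // 128, dtype=torch.float32, device=dev  # scale_block_size = 128
)
down_output = torch.empty(num_groups, m, n_down, dtype=torch.bfloat16, device=dev)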