sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_serving.py +56 -12
 - sglang/launch_server.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
 - sglang/srt/compilation/backend.py +1 -1
 - sglang/srt/configs/model_config.py +5 -5
 - sglang/srt/distributed/parallel_state.py +0 -7
 - sglang/srt/entrypoints/engine.py +18 -15
 - sglang/srt/entrypoints/grpc_server.py +0 -1
 - sglang/srt/entrypoints/http_server.py +75 -94
 - sglang/srt/environ.py +16 -2
 - sglang/srt/eplb/expert_distribution.py +30 -0
 - sglang/srt/function_call/function_call_parser.py +2 -0
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/layers/activation.py +6 -0
 - sglang/srt/layers/attention/flashattention_backend.py +12 -2
 - sglang/srt/layers/attention/flashinfer_backend.py +10 -1
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
 - sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
 - sglang/srt/layers/attention/utils.py +78 -0
 - sglang/srt/layers/communicator.py +1 -0
 - sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
 - sglang/srt/layers/layernorm.py +19 -4
 - sglang/srt/layers/logits_processor.py +5 -0
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
 - sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
 - sglang/srt/layers/moe/ep_moe/layer.py +79 -272
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
 - sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
 - sglang/srt/layers/moe/topk.py +4 -4
 - sglang/srt/layers/moe/utils.py +3 -4
 - sglang/srt/layers/quantization/__init__.py +3 -5
 - sglang/srt/layers/quantization/awq.py +0 -3
 - sglang/srt/layers/quantization/base_config.py +7 -0
 - sglang/srt/layers/quantization/fp8.py +68 -63
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/mxfp4.py +30 -38
 - sglang/srt/layers/quantization/unquant.py +23 -45
 - sglang/srt/layers/quantization/w4afp8.py +38 -2
 - sglang/srt/layers/radix_attention.py +5 -2
 - sglang/srt/layers/rotary_embedding.py +13 -1
 - sglang/srt/layers/sampler.py +12 -1
 - sglang/srt/managers/io_struct.py +3 -0
 - sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
 - sglang/srt/managers/scheduler.py +21 -15
 - sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
 - sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
 - sglang/srt/managers/tokenizer_manager.py +11 -19
 - sglang/srt/mem_cache/hicache_storage.py +7 -1
 - sglang/srt/mem_cache/memory_pool.py +82 -0
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/model_executor/forward_batch_info.py +44 -3
 - sglang/srt/model_executor/model_runner.py +1 -149
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
 - sglang/srt/models/deepseek_v2.py +147 -44
 - sglang/srt/models/glm4_moe.py +322 -354
 - sglang/srt/models/glm4_moe_nextn.py +4 -14
 - sglang/srt/models/glm4v_moe.py +29 -196
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/qwen2.py +22 -1
 - sglang/srt/models/qwen3.py +34 -4
 - sglang/srt/models/qwen3_moe.py +2 -4
 - sglang/srt/multimodal/processors/base_processor.py +1 -0
 - sglang/srt/multimodal/processors/glm4v.py +1 -1
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
 - sglang/srt/parser/reasoning_parser.py +28 -1
 - sglang/srt/server_args.py +365 -186
 - sglang/srt/single_batch_overlap.py +2 -7
 - sglang/srt/utils/common.py +87 -42
 - sglang/srt/utils/hf_transformers_utils.py +7 -3
 - sglang/test/test_deterministic.py +235 -12
 - sglang/test/test_deterministic_utils.py +2 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
 - sglang/srt/models/vila.py +0 -306
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
 
    
sglang/srt/layers/layernorm.py
CHANGED

@@ -73,9 +73,16 @@ class RMSNorm(CustomOp):
         hidden_size: int,
         eps: float = 1e-6,
         var_hidden_size: Optional[int] = None,
+        cast_x_before_out_mul: bool = False,
+        fp32_residual: bool = False,
+        weight_dtype: Optional = None,
+        override_orig_dtype: Optional = None,
     ) -> None:
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.cast_x_before_out_mul = cast_x_before_out_mul
+        self.fp32_residual = fp32_residual
+        self.override_orig_dtype = override_orig_dtype
+        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))
         self.variance_epsilon = eps
         self.hidden_size = hidden_size
         self.variance_size_override = (

@@ -165,11 +172,14 @@ class RMSNorm(CustomOp):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        orig_dtype = x.dtype
+        orig_dtype = self.override_orig_dtype or x.dtype
         x = x.to(torch.float32)
         if residual is not None:
             x = x + residual.to(torch.float32)
-            residual = x.to(orig_dtype)
+            if self.fp32_residual:
+                residual = x.clone()
+            else:
+                residual = x.to(orig_dtype)

         hidden_size = x.shape[-1]
         if hidden_size != self.hidden_size:

@@ -191,7 +201,12 @@ class RMSNorm(CustomOp):

         variance = x_var.pow(2).mean(dim=-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = (x * self.weight).to(orig_dtype)
+
+        if self.cast_x_before_out_mul:
+            x = self.weight * x.to(orig_dtype)
+        else:
+            x = (x * self.weight).to(orig_dtype)
+
         if residual is None:
             return x
         else:
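
For reference, the net effect of the new RMSNorm flags, as a minimal pure-PyTorch sketch (an illustration of the diff above, not the sglang implementation):

import torch

def rms_norm_ref(x, weight, eps=1e-6, residual=None,
                 cast_x_before_out_mul=False, fp32_residual=False,
                 override_orig_dtype=None):
    # All math runs in fp32; the new knobs control how and when values
    # are cast back to the original dtype.
    orig_dtype = override_orig_dtype or x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        # fp32_residual keeps the running residual in fp32 instead of
        # rounding it back to orig_dtype at every layer.
        residual = x.clone() if fp32_residual else x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    if cast_x_before_out_mul:
        # Cast first, then multiply by the weight (relevant when the
        # weight is kept in a different dtype, cf. the new weight_dtype
        # constructor argument).
        x = weight * x.to(orig_dtype)
    else:
        x = (x * weight).to(orig_dtype)
    return x if residual is None else (x, residual)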

sglang/srt/layers/logits_processor.py
CHANGED

@@ -593,6 +593,11 @@ class LogitsProcessor(nn.Module):
                 None,  # bias
                 True,  # is_vnni
             )
+        elif get_global_server_args().rl_on_policy_target == "fsdp":
+            # Due to tie-weight, we may not be able to change lm_head's weight dtype
+            logits = torch.matmul(
+                hidden_states.bfloat16(), lm_head.weight.T.bfloat16()
+            )
         else:
             logits = torch.matmul(
                 hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
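
The tie-weight comment refers to lm_head sharing its weight tensor with the input embedding: the weight cannot be cast in place without also casting the embedding, so the new branch casts both operands per matmul instead. A hypothetical minimal setup showing the constraint (a sketch, not sglang code):

import torch
import torch.nn as nn

emb = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)
lm_head.weight = emb.weight  # weight tying: one shared Parameter

hidden = torch.randn(2, 16)
# Cast both operands for this matmul only, as the diff does; the shared
# parameter (and hence the embedding) keeps its original dtype.
logits = torch.matmul(hidden.bfloat16(), lm_head.weight.T.bfloat16())
print(logits.dtype, emb.weight.dtype)  # torch.bfloat16 torch.float32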

sglang/srt/layers/moe/cutlass_w4a8_moe.py
CHANGED

@@ -11,12 +11,14 @@ from sgl_kernel import (
 )
 
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    deepep_ll_get_cutlass_w4a8_moe_mm_data,
     deepep_permute_triton_kernel,
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
     post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
     run_moe_ep_preproess,
+    silu_and_mul_masked_post_per_tensor_quant_fwd,
 )
 
 

@@ -396,3 +398,139 @@ def cutlass_w4a8_moe_deepep_normal(
     )
 
     return output
+
+
+def cutlass_w4a8_moe_deepep_ll(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    masked_m: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, K]
+    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
+        Shape: [num_experts, N * 2, K // 2]
+        (the weights are passed transposed and int4-packed)
+    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
+        Shape: [num_experts, K, N // 2]
+        (the weights are passed transposed and int4-packed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts, K // 512, N * 8]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts, N // 512, K * 4]
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
+    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
+    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
+    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [1, K]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [1, N]
+    - apply_router_weight_on_input (bool): When true, the topk weights are
+        applied directly on the inputs. This is only applicable when topk is 1.
+
+    Returns:
+    - torch.Tensor: The fp8 output tensor after applying the MoE layer.
+    """
+    assert w1_q.dtype == torch.int8
+    assert w2_q.dtype == torch.int8
+    assert a.shape[2] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
+    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+
+    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
+    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
+    num_experts = w1_q.size(0)
+    m = a.size(1)
+    k = w1_q.size(2) * 2  # w1_q is transposed and packed
+    n = w2_q.size(2) * 2  # w2_q is transposed and packed
+    topk = topk_ids_.size(1)
+
+    device = a.device
+
+    problem_sizes1, problem_sizes2 = deepep_ll_get_cutlass_w4a8_moe_mm_data(
+        masked_m,
+        problem_sizes1,
+        problem_sizes2,
+        num_experts,
+        n,
+        k,
+    )
+
+    gateup_input = torch.empty(a.shape, dtype=torch.float8_e4m3fn, device=device)
+    sgl_per_tensor_quant_fp8(a, gateup_input, a1_scale.float(), True)
+    c1 = torch.empty((num_experts, m, n * 2), device=device, dtype=torch.bfloat16)
+    c2 = torch.empty((num_experts, m, k), device=device, dtype=torch.bfloat16)
+
+    cutlass_w4a8_moe_mm(
+        c1,
+        gateup_input,
+        w1_q,
+        a1_scale.float(),
+        w1_scale,
+        expert_offsets[:-1],
+        problem_sizes1,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        s_strides13,
+        128,
+        topk,
+    )
+
+    intermediate_q = torch.empty(
+        (num_experts, m, n), device=a.device, dtype=torch.float8_e4m3fn
+    )
+    silu_and_mul_masked_post_per_tensor_quant_fwd(
+        c1, intermediate_q, masked_m, a2_scale
+    )
+    cutlass_w4a8_moe_mm(
+        c2,
+        intermediate_q,
+        w2_q,
+        a2_scale.float(),
+        w2_scale,
+        expert_offsets[:-1],
+        problem_sizes2,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides2,
+        128,
+        topk,
+    )
+
+    return c2
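
As a sanity check on the docstring's shape contract, a small sketch with made-up sizes (dummy tensors only, no kernel calls; num_experts, m, N, and K are arbitrary illustration values):

import torch

num_experts, m, N, K = 4, 8, 512, 1024

# int4-packed weights as described above: two int4 values per int8 byte,
# stored transposed.
w1_q = torch.zeros(num_experts, N * 2, K // 2, dtype=torch.int8)
w2_q = torch.zeros(num_experts, K, N // 2, dtype=torch.int8)
a = torch.zeros(num_experts, m, K, dtype=torch.bfloat16)

# The same arithmetic cutlass_w4a8_moe_deepep_ll uses to recover the
# logical dimensions from the packed layouts:
k = w1_q.size(2) * 2  # -> 1024 (== K)
n = w2_q.size(2) * 2  # -> 512  (== N)
assert a.shape[2] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"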

sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED

@@ -1014,3 +1014,197 @@ def zero_experts_compute_triton(
     )
 
     return output
+
+
+@triton.jit
+def compute_problem_sizes_w4a8_kernel(
+    masked_m_ptr,
+    problem_sizes1_ptr,
+    problem_sizes2_ptr,
+    n,
+    k,
+    num_experts,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = pid < num_experts
+    final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)
+
+    ps1_idx_0 = pid * 3
+    ps1_idx_1 = ps1_idx_0 + 1
+    ps1_idx_2 = ps1_idx_0 + 2
+
+    ps2_idx_0 = pid * 3
+    ps2_idx_1 = ps2_idx_0 + 1
+    ps2_idx_2 = ps2_idx_0 + 2
+
+    ps1_mask_0 = ps1_idx_0 < num_experts * 3
+    ps1_mask_1 = ps1_idx_1 < num_experts * 3
+    ps1_mask_2 = ps1_idx_2 < num_experts * 3
+    ps2_mask_0 = ps2_idx_0 < num_experts * 3
+    ps2_mask_1 = ps2_idx_1 < num_experts * 3
+    ps2_mask_2 = ps2_idx_2 < num_experts * 3
+
+    tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)
+    tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)
+    tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)
+
+    tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)
+    tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)
+    tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)
+
+
+def compute_problem_sizes_w4a8(
+    masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
+):
+    BLOCK_SIZE = 256
+    grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
+    compute_problem_sizes_w4a8_kernel[grid](
+        masked_m,
+        problem_sizes1,
+        problem_sizes2,
+        n,
+        k,
+        num_experts,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return problem_sizes1, problem_sizes2
+
+
+def deepep_ll_get_cutlass_w4a8_moe_mm_data(
+    masked_m,
+    problem_sizes1,
+    problem_sizes2,
+    num_experts,
+    n,
+    k,
+):
+    problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
+        masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
+    )
+    return (
+        problem_sizes1.to(torch.int32),
+        problem_sizes2.to(torch.int32),
+    )
+
+
+@triton.jit
+def _silu_and_mul_post_per_tensor_quant_kernel(
+    input_ptr,
+    stride_input_expert,
+    stride_input_token,
+    stride_input_dim,
+    output_ptr,
+    stride_output_expert,
+    stride_output_token,
+    stride_output_dim,
+    scale_ptr,
+    masked_m_ptr,
+    inner_dim,
+    fp8_max,
+    fp8_min,
+    BLOCK_N: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    """
+    Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.
+
+    Shape:
+        input:  [E, T_padded, 2*D]  -> gate: [:,:,D], up: [:,:,D]
+        output: [E, T_padded, D], dtype=float8_e4m3fn
+    """
+    expert_id = tl.program_id(2)
+    block_id_token = tl.program_id(1)
+    block_id_dim = tl.program_id(0)
+
+    num_token_blocks = tl.num_programs(1)
+
+    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
+
+    scale = 1.0 / tl.load(scale_ptr).to(tl.float32)
+
+    stride_input_expert = tl.cast(stride_input_expert, tl.int32)
+    stride_output_expert = tl.cast(stride_output_expert, tl.int32)
+    stride_input_token = tl.cast(stride_input_token, tl.int32)
+    stride_output_token = tl.cast(stride_output_token, tl.int32)
+
+    offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask_d = offset_d < inner_dim
+
+    # base pointers for current expert and dim block
+    input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
+    output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d
+
+    for token_idx in tl.range(
+        block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
+    ):
+        gate_ptr = input_base_offs + token_idx * stride_input_token
+        up_ptr = gate_ptr + inner_dim
+        gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
+        up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)
+
+        # SiLU: x * sigmoid(x)
+        gate = gate / (1 + tl.exp(-gate))
+        gate = gate.to(input_ptr.dtype.element_ty)
+        gate_up = up * gate
+
+        scaled = gate_up * scale
+        output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)
+        out_ptr = output_base_offs + token_idx * stride_output_token
+        tl.store(out_ptr, output_q, mask=mask_d)
+
+
+def silu_and_mul_masked_post_per_tensor_quant_fwd(
+    input: torch.Tensor,
+    output: torch.Tensor,
+    masked_m: torch.Tensor,
+    scale: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Fused SiLU + Mul + Per-Tensor Quantization to FP8.
+
+    Args:
+        input: [expert_num, token_num_padded, 2 * inner_dim]
+        output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
+        masked_m: [expert_num], actual token count for each expert
+        scale: [1] or [expert_num], quantization scale (per-tensor or per-expert)
+
+    Returns:
+        output tensor
+    """
+    assert input.is_contiguous()
+    assert output.is_contiguous()
+    assert output.dtype == torch.float8_e4m3fn
+    assert input.ndim == 3
+    assert input.shape[0] == masked_m.shape[0]
+    assert input.shape[-1] % 2 == 0
+    assert scale.numel() == 1 or scale.shape[0] == input.shape[0]
+
+    expert_num = input.shape[0]
+    #  3584
+    inner_dim = input.shape[-1] // 2
+
+    BLOCK_N = 256
+    BLOCK_M = 64 if expert_num < 4 else 32
+    NUM_STAGES = 3
+    hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)
+
+    grid = (hidden_dim_split_block_num, BLOCK_M, expert_num)
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = finfo.max
+    fp8_min = -fp8_max
+
+    _silu_and_mul_post_per_tensor_quant_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        scale,
+        masked_m,
+        inner_dim,
+        fp8_max,
+        fp8_min,
+        BLOCK_N=BLOCK_N,
+        NUM_STAGE=NUM_STAGES,
+    )
+    return output
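
For testing the fused kernel, a plain-PyTorch reference of what silu_and_mul_masked_post_per_tensor_quant_fwd computes could look like the sketch below. Two details are mirrored from the kernel: SiLU(gate) is rounded back to the input dtype before the multiply, and rows at or beyond masked_m[e] are never written by the kernel (the reference zeroes them so comparisons cover only valid tokens). The kernel multiplies by the reciprocal of the scale rather than dividing, so last-ulp differences are expected. For compute_problem_sizes_w4a8 the expected output is simply problem_sizes1[e] = (2*n, masked_m[e], k) and problem_sizes2[e] = (k, masked_m[e], n) per expert.

import torch
import torch.nn.functional as F

def silu_and_mul_masked_quant_ref(inp, masked_m, scale):
    # inp: [E, T_pad, 2*D]; masked_m: [E] int; scale: scalar tensor.
    E, T, D2 = inp.shape
    D = D2 // 2
    gate, up = inp[..., :D], inp[..., D:]
    # Mirror the kernel: SiLU in fp32, then round the result back to the
    # input dtype before multiplying with `up`.
    gate = F.silu(gate.float()).to(inp.dtype).float()
    out = up.float() * gate
    # Per-tensor quantization: divide by scale, clamp to the fp8 range.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    out = (out / scale.float()).clamp(-fp8_max, fp8_max)
    # The Triton kernel leaves padding rows uninitialized; zero them here
    # so only valid tokens differ from the kernel output.
    for e in range(E):
        out[e, masked_m[e]:] = 0
    return out.to(torch.float8_e4m3fn)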