sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +73 -14
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/launch_server.py +2 -0
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/compilation/backend.py +1 -1
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +30 -7
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -12
 - sglang/srt/entrypoints/engine.py +31 -20
 - sglang/srt/entrypoints/grpc_server.py +0 -1
 - sglang/srt/entrypoints/http_server.py +94 -94
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +23 -2
 - sglang/srt/eplb/expert_distribution.py +64 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/function_call/function_call_parser.py +2 -0
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/activation.py +6 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +19 -8
 - sglang/srt/layers/attention/flashinfer_backend.py +10 -1
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
 - sglang/srt/layers/attention/utils.py +78 -0
 - sglang/srt/layers/communicator.py +24 -1
 - sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
 - sglang/srt/layers/layernorm.py +35 -6
 - sglang/srt/layers/logits_processor.py +9 -20
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
 - sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
 - sglang/srt/layers/moe/ep_moe/layer.py +78 -289
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
 - sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
 - sglang/srt/layers/moe/topk.py +35 -10
 - sglang/srt/layers/moe/utils.py +3 -4
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +13 -84
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/awq.py +0 -3
 - sglang/srt/layers/quantization/base_config.py +7 -0
 - sglang/srt/layers/quantization/fp8.py +68 -63
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/quantization/mxfp4.py +30 -38
 - sglang/srt/layers/quantization/unquant.py +23 -45
 - sglang/srt/layers/quantization/w4afp8.py +38 -2
 - sglang/srt/layers/radix_attention.py +5 -2
 - sglang/srt/layers/rotary_embedding.py +130 -46
 - sglang/srt/layers/sampler.py +12 -1
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +29 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +185 -144
 - sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +165 -78
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/hicache_storage.py +7 -1
 - sglang/srt/mem_cache/memory_pool.py +253 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +55 -14
 - sglang/srt/model_executor/model_runner.py +77 -170
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +296 -78
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4_moe.py +322 -354
 - sglang/srt/models/glm4_moe_nextn.py +4 -14
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +29 -197
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/qwen2.py +23 -2
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +35 -5
 - sglang/srt/models/qwen3_moe.py +18 -12
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multimodal/processors/base_processor.py +1 -0
 - sglang/srt/multimodal/processors/glm4v.py +1 -1
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/parser/reasoning_parser.py +28 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +459 -199
 - sglang/srt/single_batch_overlap.py +2 -4
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +142 -74
 - sglang/srt/utils/hf_transformers_utils.py +38 -12
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_deterministic.py +235 -12
 - sglang/test/test_deterministic_utils.py +2 -1
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
 - sglang/srt/models/vila.py +0 -306
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
sglang/srt/layers/attention/utils.py CHANGED

@@ -1,3 +1,4 @@
+import torch
 import triton
 import triton.language as tl
 
@@ -101,3 +102,80 @@ def create_flashmla_kv_indices_triton(
         data // PAGED_SIZE,
         mask=mask_out,
     )
+
+
+@triton.jit
+def concat_and_cast_mha_k_kernel(
+    k_ptr,
+    k_nope_ptr,
+    k_rope_ptr,
+    head_cnt: tl.constexpr,
+    k_stride0: tl.constexpr,
+    k_stride1: tl.constexpr,
+    nope_stride0: tl.constexpr,
+    nope_stride1: tl.constexpr,
+    rope_stride0: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    head_range = tl.arange(0, head_cnt)
+
+    k_head_ptr = k_ptr + pid_loc * k_stride0 + head_range[:, None] * k_stride1
+
+    nope_offs = tl.arange(0, nope_dim)
+
+    src_nope_ptr = (
+        k_nope_ptr
+        + pid_loc * nope_stride0
+        + head_range[:, None] * nope_stride1
+        + nope_offs[None, :]
+    )
+    dst_nope_ptr = k_head_ptr + nope_offs[None, :]
+
+    src_nope = tl.load(src_nope_ptr)
+    tl.store(dst_nope_ptr, src_nope)
+
+    rope_offs = tl.arange(0, rope_dim)
+    src_rope_ptr = k_rope_ptr + pid_loc * rope_stride0 + rope_offs[None, :]
+    dst_rope_ptr = k_head_ptr + nope_dim + rope_offs[None, :]
+    src_rope = tl.load(src_rope_ptr)
+    tl.store(dst_rope_ptr, src_rope)
+
+
+def concat_and_cast_mha_k_triton(
+    k: torch.Tensor,
+    k_nope: torch.Tensor,
+    k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    assert (
+        len(k.shape) == 3 and len(k_nope.shape) == 3 and len(k_rope.shape) == 3
+    ), f"shape should be 3d, but got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[0] == k_nope.shape[0] and k.shape[0] == k_rope.shape[0]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[1] == k_nope.shape[1] and 1 == k_rope.shape[1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[-1] == k_nope.shape[-1] + k_rope.shape[-1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+
+    nope_dim = k_nope.shape[-1]
+    rope_dim = k_rope.shape[-1]
+    grid = (k.shape[0],)
+
+    concat_and_cast_mha_k_kernel[grid](
+        k,
+        k_nope,
+        k_rope,
+        k.shape[1],
+        k.stride(0),
+        k.stride(1),
+        k_nope.stride(0),
+        k_nope.stride(1),
+        k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )
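Note: the new kernel writes the per-head k_nope part and the single shared k_rope head (broadcast across all heads) side by side into k, casting to k's dtype on store. A minimal eager-mode sketch of the same semantics, for illustration only (concat_and_cast_mha_k_reference is not part of the package; it assumes k_nope and k_rope share a dtype):

import torch

def concat_and_cast_mha_k_reference(k, k_nope, k_rope):
    # k:      [num_tokens, head_cnt, nope_dim + rope_dim], target dtype
    # k_nope: [num_tokens, head_cnt, nope_dim]
    # k_rope: [num_tokens, 1, rope_dim], shared across heads
    rope = k_rope.expand(-1, k_nope.shape[1], -1)
    # copy_ performs the implicit dtype cast that the kernel's tl.store does
    k.copy_(torch.cat([k_nope, rope], dim=-1))
    return k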
sglang/srt/layers/communicator.py CHANGED

@@ -15,7 +15,7 @@
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 
 import torch
 
@@ -216,6 +216,28 @@ class LayerCommunicator:
             get_global_server_args().speculative_algorithm
         )
 
+    def prepare_attn_and_capture_last_layer_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
+    ):
+        hidden_states, residual = self.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+        if captured_last_layer_outputs is not None:
+            gathered_last_layer_output = self._communicate_simple_fn(
+                hidden_states=residual,
+                forward_batch=forward_batch,
+                context=self._context,
+            )
+            if gathered_last_layer_output is residual:
+                # Clone to avoid modifying the original residual by Custom RMSNorm inplace operation
+                gathered_last_layer_output = residual.clone()
+            captured_last_layer_outputs.append(gathered_last_layer_output)
+        return hidden_states, residual
+
     def prepare_attn(
         self,
         hidden_states: torch.Tensor,
@@ -337,6 +359,7 @@ class LayerCommunicator:
         static_conditions_met = (
             (not self.is_last_layer)
             and (self._context.tp_size > 1)
+            and not is_dp_attention_enabled()
             and get_global_server_args().enable_flashinfer_allreduce_fusion
             and _is_flashinfer_available
         )
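The identity check in the new method guards against aliasing: when the communication function is a pass-through it returns the very tensor it was given, and a later in-place RMSNorm on residual would silently corrupt the captured copy. A generic sketch of the pattern (function and variable names here are illustrative, not from the diff):

import torch

def capture(communicate_fn, residual: torch.Tensor, captured: list):
    out = communicate_fn(residual)
    if out is residual:
        # The pass-through returned the same storage; snapshot it before
        # any in-place op (e.g. a fused add + RMSNorm) mutates residual.
        out = residual.clone()
    captured.append(out)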
sglang/srt/layers/deep_gemm_wrapper/compile_utils.py CHANGED

@@ -26,7 +26,7 @@ _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "fal
 
 # Force redirect deep_gemm cache_dir
 os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
-    "
+    "SGLANG_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )
 
 # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
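The hunk above renames the environment variable that redirects the DeepGEMM JIT cache, with ~/.cache/deep_gemm as the fallback. Because the assignment to DG_JIT_CACHE_DIR runs at module import time, the variable must be set first. A usage sketch (the path is an example):

import os

# Must be set before the wrapper module is imported; the assignment to
# DG_JIT_CACHE_DIR happens when the module body executes.
os.environ["SGLANG_DG_CACHE_DIR"] = "/data/cache/deep_gemm"

from sglang.srt.layers.deep_gemm_wrapper import compile_utils  # noqa: E402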
    
sglang/srt/layers/layernorm.py CHANGED

@@ -20,7 +20,12 @@ import torch
 import torch.nn as nn
 from packaging.version import Version
 
+from sglang.srt.batch_invariant_ops import (
+    is_batch_invariant_mode_enabled,
+    rms_norm_batch_invariant,
+)
 from sglang.srt.custom_op import CustomOp
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -73,9 +78,16 @@ class RMSNorm(CustomOp):
         hidden_size: int,
         eps: float = 1e-6,
         var_hidden_size: Optional[int] = None,
+        cast_x_before_out_mul: bool = False,
+        fp32_residual: bool = False,
+        weight_dtype: Optional = None,
+        override_orig_dtype: Optional = None,
     ) -> None:
         super().__init__()
-        self.
+        self.cast_x_before_out_mul = cast_x_before_out_mul
+        self.fp32_residual = fp32_residual
+        self.override_orig_dtype = override_orig_dtype
+        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))
         self.variance_epsilon = eps
         self.hidden_size = hidden_size
         self.variance_size_override = (
@@ -83,8 +95,6 @@ class RMSNorm(CustomOp):
         )
         if _use_aiter:
             self._forward_method = self.forward_aiter
-        if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
-            self._forward_method = self.forward_native
 
     def forward_cuda(
         self,
@@ -93,6 +103,17 @@ class RMSNorm(CustomOp):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)
+        if is_batch_invariant_mode_enabled():
+            if (
+                residual is not None
+                or get_global_server_args().rl_on_policy_target == "fsdp"
+            ):
+                return self.forward_native(x, residual)
+            return rms_norm_batch_invariant(
+                x,
+                self.weight.data,
+                self.variance_epsilon,
+            )
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
@@ -165,11 +186,14 @@ class RMSNorm(CustomOp):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        orig_dtype = x.dtype
+        orig_dtype = self.override_orig_dtype or x.dtype
         x = x.to(torch.float32)
         if residual is not None:
             x = x + residual.to(torch.float32)
-
+            if self.fp32_residual:
+                residual = x.clone()
+            else:
+                residual = x.to(orig_dtype)
 
         hidden_size = x.shape[-1]
         if hidden_size != self.hidden_size:
@@ -191,7 +215,12 @@ class RMSNorm(CustomOp):
 
         variance = x_var.pow(2).mean(dim=-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-
+
+        if self.cast_x_before_out_mul:
+            x = self.weight * x.to(orig_dtype)
+        else:
+            x = (x * self.weight).to(orig_dtype)
+
         if residual is None:
             return x
         else:
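In forward_native, the new cast_x_before_out_mul flag only changes the rounding order of the final downcast: the normalized activations are cast back to orig_dtype before the weight multiply instead of after. A small numeric sketch, assuming plain PyTorch:

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
weight = torch.ones(8, dtype=torch.bfloat16)

# Normalize in fp32, as forward_native does.
xf = x.float()
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + 1e-6)

cast_first = weight * xf.to(torch.bfloat16)   # cast_x_before_out_mul=True
cast_last = (xf * weight).to(torch.bfloat16)  # default behavior
# The two agree up to bf16 rounding of the intermediate product.
print((cast_first.float() - cast_last.float()).abs().max())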
sglang/srt/layers/logits_processor.py CHANGED

@@ -38,7 +38,6 @@ from sglang.srt.layers.dp_attention import (
     get_dp_device,
     get_dp_dtype,
     get_dp_hidden_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import (
@@ -47,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import 
+from sglang.srt.utils import is_npu, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
 
@@ -135,10 +134,7 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
-            (
-                forward_batch.forward_mode.is_extend()
-                or forward_batch.forward_mode.is_split_prefill()
-            )
+            forward_batch.forward_mode.is_extend()
             and forward_batch.return_logprob
             and not forward_batch.forward_mode.is_target_verify()
         ):
@@ -252,10 +248,6 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None
 
-        self.debug_tensor_dump_output_folder = (
-            get_global_server_args().debug_tensor_dump_output_folder
-        )
-
     def compute_logprobs_for_multi_item_scoring(
         self,
         input_ids,
@@ -389,8 +381,8 @@ class LogitsProcessor(nn.Module):
             input_logprob_indices = None
         elif (
             logits_metadata.forward_mode.is_extend()
-
-        )
+            and not logits_metadata.extend_return_logprob
+        ):
             # Prefill without input logprobs.
             if logits_metadata.padded_static_len < 0:
                 last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
@@ -463,14 +455,6 @@ class LogitsProcessor(nn.Module):
             logits[sample_indices] if sample_indices is not None else logits
         )
 
-        if self.debug_tensor_dump_output_folder:
-            assert (
-                not self.do_tensor_parallel_all_gather
-                or get_local_attention_dp_size() == 1
-            ), "dp attention + sharded lm_head doesn't support full logits"
-            full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
-            dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
-
         hidden_states_to_store: Optional[torch.Tensor] = None
         if logits_metadata.capture_hidden_mode.need_capture():
             if logits_metadata.capture_hidden_mode.is_full():
@@ -593,6 +577,11 @@ class LogitsProcessor(nn.Module):
                 None,  # bias
                 True,  # is_vnni
             )
+        elif get_global_server_args().rl_on_policy_target == "fsdp":
+            # Due to tie-weight, we may not be able to change lm_head's weight dtype
+            logits = torch.matmul(
+                hidden_states.bfloat16(), lm_head.weight.T.bfloat16()
+            )
         else:
             logits = torch.matmul(
                 hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
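The new fsdp branch above avoids converting a tied lm_head weight in place and instead casts both operands to bfloat16 at matmul time. A dtype-only sketch with illustrative shapes (the sizes are examples, not from the diff):

import torch

hidden_states = torch.randn(2, 1024, dtype=torch.bfloat16)
lm_head_weight = torch.randn(4096, 1024, dtype=torch.float32)  # tied, dtype fixed

# rl_on_policy_target == "fsdp": cast at matmul time, the weight stays fp32
logits_fsdp = torch.matmul(hidden_states.bfloat16(), lm_head_weight.T.bfloat16())

# default path: move the activations to the weight's dtype instead
logits_default = torch.matmul(hidden_states.to(lm_head_weight.dtype), lm_head_weight.T)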
sglang/srt/layers/moe/cutlass_w4a8_moe.py CHANGED

@@ -11,12 +11,14 @@ from sgl_kernel import (
 )
 
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    deepep_ll_get_cutlass_w4a8_moe_mm_data,
     deepep_permute_triton_kernel,
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
     post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
     run_moe_ep_preproess,
+    silu_and_mul_masked_post_per_tensor_quant_fwd,
 )
 
 
@@ -396,3 +398,139 @@ def cutlass_w4a8_moe_deepep_normal(
     )
 
     return output
+
+
+def cutlass_w4a8_moe_deepep_ll(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    masked_m: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, K]
+    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
+        Shape: [num_experts, N * 2, K // 2]
+        (the weights are passed transposed and int4-packed)
+    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
+        Shape: [num_experts, K, N // 2]
+        (the weights are passed transposed and int4-packed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts, K // 512, N * 8]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts, N // 512, K * 4]
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
+    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
+    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
+    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [1, K]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [1, N]
+    - apply_router_weight_on_input (bool): When true, the topk weights are
+        applied directly on the inputs. This is only applicable when topk is 1.
+
+    Returns:
+    - torch.Tensor: The fp8 output tensor after applying the MoE layer.
+    """
+    assert w1_q.dtype == torch.int8
+    assert w2_q.dtype == torch.int8
+    assert a.shape[2] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
+    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+
+    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
+    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
+    num_experts = w1_q.size(0)
+    m = a.size(1)
+    k = w1_q.size(2) * 2  # w1_q is transposed and packed
+    n = w2_q.size(2) * 2  # w2_q is transposed and packed
+    topk = topk_ids_.size(1)
+
+    device = a.device
+
+    problem_sizes1, problem_sizes2 = deepep_ll_get_cutlass_w4a8_moe_mm_data(
+        masked_m,
+        problem_sizes1,
+        problem_sizes2,
+        num_experts,
+        n,
+        k,
+    )
+
+    gateup_input = torch.empty(a.shape, dtype=torch.float8_e4m3fn, device=device)
+    sgl_per_tensor_quant_fp8(a, gateup_input, a1_scale.float(), True)
+    c1 = torch.empty((num_experts, m, n * 2), device=device, dtype=torch.bfloat16)
+    c2 = torch.empty((num_experts, m, k), device=device, dtype=torch.bfloat16)
+
+    cutlass_w4a8_moe_mm(
+        c1,
+        gateup_input,
+        w1_q,
+        a1_scale.float(),
+        w1_scale,
+        expert_offsets[:-1],
+        problem_sizes1,
         
            +
                    a_strides1,
         
     | 
| 
      
 507 
     | 
    
         
            +
                    b_strides1,
         
     | 
| 
      
 508 
     | 
    
         
            +
                    c_strides1,
         
     | 
| 
      
 509 
     | 
    
         
            +
                    s_strides13,
         
     | 
| 
      
 510 
     | 
    
         
            +
                    128,
         
     | 
| 
      
 511 
     | 
    
         
            +
                    topk,
         
     | 
| 
      
 512 
     | 
    
         
            +
                )
         
     | 
| 
      
 513 
     | 
    
         
            +
             
     | 
| 
      
 514 
     | 
    
         
            +
                intermediate_q = torch.empty(
         
     | 
| 
      
 515 
     | 
    
         
            +
                    (num_experts, m, n), device=a.device, dtype=torch.float8_e4m3fn
         
     | 
| 
      
 516 
     | 
    
         
            +
                )
         
     | 
| 
      
 517 
     | 
    
         
            +
                silu_and_mul_masked_post_per_tensor_quant_fwd(
         
     | 
| 
      
 518 
     | 
    
         
            +
                    c1, intermediate_q, masked_m, a2_scale
         
     | 
| 
      
 519 
     | 
    
         
            +
                )
         
     | 
| 
      
 520 
     | 
    
         
            +
                cutlass_w4a8_moe_mm(
         
     | 
| 
      
 521 
     | 
    
         
            +
                    c2,
         
     | 
| 
      
 522 
     | 
    
         
            +
                    intermediate_q,
         
     | 
| 
      
 523 
     | 
    
         
            +
                    w2_q,
         
     | 
| 
      
 524 
     | 
    
         
            +
                    a2_scale.float(),
         
     | 
| 
      
 525 
     | 
    
         
            +
                    w2_scale,
         
     | 
| 
      
 526 
     | 
    
         
            +
                    expert_offsets[:-1],
         
     | 
| 
      
 527 
     | 
    
         
            +
                    problem_sizes2,
         
     | 
| 
      
 528 
     | 
    
         
            +
                    a_strides2,
         
     | 
| 
      
 529 
     | 
    
         
            +
                    b_strides2,
         
     | 
| 
      
 530 
     | 
    
         
            +
                    c_strides2,
         
     | 
| 
      
 531 
     | 
    
         
            +
                    s_strides2,
         
     | 
| 
      
 532 
     | 
    
         
            +
                    128,
         
     | 
| 
      
 533 
     | 
    
         
            +
                    topk,
         
     | 
| 
      
 534 
     | 
    
         
            +
                )
         
     | 
| 
      
 535 
     | 
    
         
            +
             
     | 
| 
      
 536 
     | 
    
         
            +
                return c2
         
     | 
| 
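Note: for orientation, the hunk above quantizes the activations to FP8, runs the first grouped GEMM, applies a fused SiLU-and-mul plus requantization, and runs the second grouped GEMM per expert. Below is a minimal eager-mode sketch of that dataflow, not the CUTLASS path: `fp8_quant` and `w4a8_moe_reference` are hypothetical helper names, the dequantized `w1`/`w2` stand in for the packed int4 `w1_q`/`w2_q`, and torch >= 2.1 is assumed for `float8_e4m3fn`.

import torch
import torch.nn.functional as F

def fp8_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for sgl_per_tensor_quant_fp8: divide by the fp32 scale,
    # clamp to the e4m3 range, and cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x.float() / scale.float()).clamp(-finfo.max, finfo.max).to(
        torch.float8_e4m3fn
    )

def w4a8_moe_reference(a, w1, w2, masked_m, a1_scale, a2_scale):
    # a: [E, m, K] activations; w1: [E, 2N, K] and w2: [E, K, N] are
    # dequantized stand-ins for the packed int4 weights.
    E, m, _ = a.shape
    out = torch.zeros(E, m, w2.shape[1], dtype=torch.bfloat16)
    for e in range(E):
        t = int(masked_m[e])  # only the first t token slots are valid
        x = fp8_quant(a[e, :t], a1_scale).to(torch.bfloat16)
        c1 = x @ w1[e].t().to(torch.bfloat16)  # first grouped gemm -> [t, 2N]
        gate, up = c1.chunk(2, dim=-1)
        h = F.silu(gate) * up                  # fused epilogue in the real kernel
        hq = fp8_quant(h, a2_scale).to(torch.bfloat16)
        out[e, :t] = hq @ w2[e].t().to(torch.bfloat16)  # second gemm -> [t, K]
    return out

Only the first `masked_m[e]` rows of each expert slot are meaningful, which is why the helpers added in the hunk below take `masked_m` rather than a dense token count.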
@@ -1014,3 +1014,197 @@ def zero_experts_compute_triton(
     )
 
     return output
+
+
+@triton.jit
+def compute_problem_sizes_w4a8_kernel(
+    masked_m_ptr,
+    problem_sizes1_ptr,
+    problem_sizes2_ptr,
+    n,
+    k,
+    num_experts,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = pid < num_experts
+    final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)
+
+    ps1_idx_0 = pid * 3
+    ps1_idx_1 = ps1_idx_0 + 1
+    ps1_idx_2 = ps1_idx_0 + 2
+
+    ps2_idx_0 = pid * 3
+    ps2_idx_1 = ps2_idx_0 + 1
+    ps2_idx_2 = ps2_idx_0 + 2
+
+    ps1_mask_0 = ps1_idx_0 < num_experts * 3
+    ps1_mask_1 = ps1_idx_1 < num_experts * 3
+    ps1_mask_2 = ps1_idx_2 < num_experts * 3
+    ps2_mask_0 = ps2_idx_0 < num_experts * 3
+    ps2_mask_1 = ps2_idx_1 < num_experts * 3
+    ps2_mask_2 = ps2_idx_2 < num_experts * 3
+
+    tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)
+    tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)
+    tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)
+
+    tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)
+    tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)
+    tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)
+
+
+def compute_problem_sizes_w4a8(
+    masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
+):
+    BLOCK_SIZE = 256
+    grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
+    compute_problem_sizes_w4a8_kernel[grid](
+        masked_m,
+        problem_sizes1,
+        problem_sizes2,
+        n,
+        k,
+        num_experts,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return problem_sizes1, problem_sizes2
+
+
+def deepep_ll_get_cutlass_w4a8_moe_mm_data(
+    masked_m,
+    problem_sizes1,
+    problem_sizes2,
+    num_experts,
+    n,
+    k,
+):
+    problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
+        masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
+    )
+    return (
+        problem_sizes1.to(torch.int32),
+        problem_sizes2.to(torch.int32),
+    )
+
+
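The stores in `compute_problem_sizes_w4a8_kernel` above write one flat `(dim0, dim1, dim2)` triple per expert: `(2n, masked_m[e], k)` for the gate/up GEMM and `(k, masked_m[e], n)` for the down GEMM. A small sketch of the same layout in plain PyTorch (`problem_sizes_reference` is a hypothetical name, useful only for checking the kernel output):

import torch

def problem_sizes_reference(masked_m: torch.Tensor, n: int, k: int):
    # Row-major int32 triples, one per expert, matching the kernel's stores.
    E = masked_m.numel()
    ps1 = torch.empty((E, 3), dtype=torch.int32, device=masked_m.device)
    ps2 = torch.empty((E, 3), dtype=torch.int32, device=masked_m.device)
    ps1[:, 0] = 2 * n                     # first gemm: rows of w1 (gate + up)
    ps1[:, 1] = masked_m.to(torch.int32)  # valid tokens for this expert
    ps1[:, 2] = k                         # hidden size
    ps2[:, 0] = k                         # second gemm: rows of w2
    ps2[:, 1] = masked_m.to(torch.int32)
    ps2[:, 2] = n                         # intermediate size
    return ps1, ps2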
+@triton.jit
+def _silu_and_mul_post_per_tensor_quant_kernel(
+    input_ptr,
+    stride_input_expert,
+    stride_input_token,
+    stride_input_dim,
+    output_ptr,
+    stride_output_expert,
+    stride_output_token,
+    stride_output_dim,
+    scale_ptr,
+    masked_m_ptr,
+    inner_dim,
+    fp8_max,
+    fp8_min,
+    BLOCK_N: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    """
+    Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.
+
+    Shape:
+        input:  [E, T_padded, 2*D]  -> gate: [:,:,D], up: [:,:,D]
+        output: [E, T_padded, D], dtype=float8_e4m3fn
+    """
+    expert_id = tl.program_id(2)
+    block_id_token = tl.program_id(1)
+    block_id_dim = tl.program_id(0)
+
+    num_token_blocks = tl.num_programs(1)
+
+    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
+
+    scale = 1.0 / tl.load(scale_ptr).to(tl.float32)
+
+    stride_input_expert = tl.cast(stride_input_expert, tl.int32)
+    stride_output_expert = tl.cast(stride_output_expert, tl.int32)
+    stride_input_token = tl.cast(stride_input_token, tl.int32)
+    stride_output_token = tl.cast(stride_output_token, tl.int32)
+
+    offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask_d = offset_d < inner_dim
+
+    # base pointers for current expert and dim block
+    input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
+    output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d
+
+    for token_idx in tl.range(
+        block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
+    ):
+        gate_ptr = input_base_offs + token_idx * stride_input_token
+        up_ptr = gate_ptr + inner_dim
+        gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
+        up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)
+
+        # SiLU: x * sigmoid(x)
+        gate = gate / (1 + tl.exp(-gate))
+        gate = gate.to(input_ptr.dtype.element_ty)
+        gate_up = up * gate
+
+        scaled = gate_up * scale
+        output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)
+        out_ptr = output_base_offs + token_idx * stride_output_token
+        tl.store(out_ptr, output_q, mask=mask_d)
+
+
+def silu_and_mul_masked_post_per_tensor_quant_fwd(
+    input: torch.Tensor,
+    output: torch.Tensor,
+    masked_m: torch.Tensor,
+    scale: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Fused SiLU + Mul + Per-Tensor Quantization to FP8.
+
+    Args:
+        input: [expert_num, token_num_padded, 2 * inner_dim]
+        output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
+        masked_m: [expert_num], actual token count for each expert
+        scale: [1] or [expert_num], quantization scale (per-tensor or per-expert)
+
+    Returns:
+        output tensor
+    """
+    assert input.is_contiguous()
+    assert output.is_contiguous()
+    assert output.dtype == torch.float8_e4m3fn
+    assert input.ndim == 3
+    assert input.shape[0] == masked_m.shape[0]
+    assert input.shape[-1] % 2 == 0
+    assert scale.numel() == 1 or scale.shape[0] == input.shape[0]
+
+    expert_num = input.shape[0]
+    inner_dim = input.shape[-1] // 2
+
+    BLOCK_N = 256
+    BLOCK_M = 64 if expert_num < 4 else 32
+    NUM_STAGES = 3
+    hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)
+
+    grid = (hidden_dim_split_block_num, BLOCK_M, expert_num)
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = finfo.max
+    fp8_min = -fp8_max
+
+    _silu_and_mul_post_per_tensor_quant_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        scale,
+        masked_m,
+        inner_dim,
+        fp8_max,
+        fp8_min,
+        BLOCK_N=BLOCK_N,
+        NUM_STAGE=NUM_STAGES,
+    )
+    return output
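For validation, a straightforward eager-mode reference of the fused SiLU-mul-quantize path added above (`silu_and_mul_masked_quant_reference` is a hypothetical name; torch >= 2.1 is assumed for `float8_e4m3fn`). Note that the Triton kernel loads a single scalar from `scale_ptr` and rounds `SiLU(gate)` back to the input dtype before the multiply, so agreement with this fp32 reference is close but not bit-exact:

import torch
import torch.nn.functional as F

def silu_and_mul_masked_quant_reference(
    x: torch.Tensor, masked_m: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    # x: [E, T_padded, 2*D] -> [E, T_padded, D] in float8_e4m3fn.
    # Mirrors the scalar-scale case, which is what the kernel loads.
    E, T, two_d = x.shape
    d = two_d // 2
    gate, up = x[..., :d].float(), x[..., d:].float()
    y = F.silu(gate) * up / scale.float()
    # Zero the padded tail so only the first masked_m[e] tokens survive.
    valid = torch.arange(T, device=x.device)[None, :] < masked_m[:, None]
    y = y * valid.unsqueeze(-1)
    finfo = torch.finfo(torch.float8_e4m3fn)
    return y.clamp(-finfo.max, finfo.max).to(torch.float8_e4m3fn)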