sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their supported public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +18 -3
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +25 -2
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -5
 - sglang/srt/entrypoints/engine.py +13 -5
 - sglang/srt/entrypoints/http_server.py +22 -3
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +7 -0
 - sglang/srt/eplb/expert_distribution.py +34 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +7 -6
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
 - sglang/srt/layers/communicator.py +23 -1
 - sglang/srt/layers/layernorm.py +16 -2
 - sglang/srt/layers/logits_processor.py +4 -20
 - sglang/srt/layers/moe/ep_moe/layer.py +0 -18
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
 - sglang/srt/layers/moe/topk.py +31 -6
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +9 -78
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/rotary_embedding.py +117 -45
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +26 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +164 -129
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +154 -59
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/memory_pool.py +171 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +11 -11
 - sglang/srt/model_executor/model_runner.py +76 -21
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +149 -34
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +0 -1
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +1 -1
 - sglang/srt/models/qwen2.py +1 -1
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +1 -1
 - sglang/srt/models/qwen3_moe.py +16 -8
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +103 -22
 - sglang/srt/single_batch_overlap.py +4 -1
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +55 -32
 - sglang/srt/utils/hf_transformers_utils.py +38 -16
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
```diff
@@ -25,6 +25,13 @@ from sglang.srt.utils import (
     is_hip,
 )
 
+try:
+    from triton.tools.tensor_descriptor import TensorDescriptor
+
+    _support_tensor_descriptor = True
+except:
+    _support_tensor_descriptor = False
+
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -41,6 +48,10 @@ elif _is_hip:
 padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
 
 
+def support_tensor_descriptor():
+    return _support_tensor_descriptor
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
@@ -108,6 +119,7 @@ def fused_moe_kernel_gptq_awq(
     use_int4_w4a16: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
     even_Ks: tl.constexpr,
+    filter_expert: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -161,7 +173,7 @@ def fused_moe_kernel_gptq_awq(
     token_mask = offs_token < num_valid_tokens
 
     off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
-    if off_experts == -1:
+    if filter_expert and off_experts == -1:
         # -----------------------------------------------------------
         # Write back zeros to the output when the expert is not
         # in the current expert parallel rank.
@@ -296,7 +308,9 @@ def fused_moe_kernel_gptq_awq(
 def fused_moe_kernel(
     # Pointers to matrices
     a_ptr,
+    a_desc,
     b_ptr,
+    b_desc,
     bias_ptr,
     c_ptr,
     a_scale_ptr,
@@ -344,6 +358,8 @@ def fused_moe_kernel(
     use_int8_w8a16: tl.constexpr,
     per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
+    c_sorted: tl.constexpr,
+    filter_expert: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -399,9 +415,10 @@ def fused_moe_kernel(
     offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens
 
-
+    off_experts_i32 = tl.load(expert_ids_ptr + pid_m)
+    off_experts = off_experts_i32.to(tl.int64)
 
-    if off_experts == -1:
+    if filter_expert and off_experts == -1:
         # -----------------------------------------------------------
         # Write back zeros to the output when the expert is not
         # in the current expert parallel rank.
@@ -421,15 +438,23 @@ def fused_moe_kernel(
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-
-
-
+    if a_desc is not None:
+        assert use_fp8_w8a8 and group_n > 0 and group_k > 0
+        start_offs_m = pid_m * BLOCK_SIZE_M
+    else:
+        a_ptrs = a_ptr + (
+            offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+        )
+
+    if b_desc is not None:
+        start_offs_n = pid_n * BLOCK_SIZE_N
+    else:
+        b_ptrs = (
+            b_ptr
+            + off_experts * stride_be
+            + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+        )
 
-    b_ptrs = (
-        b_ptr
-        + off_experts * stride_be
-        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-    )
     if bias_ptr is not None:
         bias = tl.load(
             bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
@@ -443,8 +468,14 @@ def fused_moe_kernel(
     if use_fp8_w8a8 or use_int8_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
-
-
+            if a_desc is not None:
+                a_scale_ptrs = a_scale_ptr + offs_token_id * stride_asm
+            else:
+                a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            if BLOCK_SIZE_N > group_n:
+                offs_bsn = offs_bn // group_n
+            else:
+                offs_bsn = pid_n * BLOCK_SIZE_N // group_n
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
@@ -469,37 +500,49 @@ def fused_moe_kernel(
     # `accumulator` will be converted back to fp16 after the loop.
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
-    for
+    for k_start in range(0, K, BLOCK_SIZE_K):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        if
+        if a_desc is not None:
+            a = a_desc.load([start_offs_m, k_start])
+        elif even_Ks:
             a = tl.load(
                 a_ptrs,
                 mask=token_mask[:, None],
                 other=0.0,
             )
-            b = tl.load(b_ptrs)
         else:
             a = tl.load(
                 a_ptrs,
-                mask=token_mask[:, None] & (offs_k[None, :] < K -
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k_start),
                 other=0.0,
             )
-
+
+        if b_desc is not None:
+            b = (
+                b_desc.load([off_experts_i32, start_offs_n, k_start])
+                .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
+                .T
+            )
+        elif even_Ks:
+            b = tl.load(b_ptrs)
+        else:
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_start, other=0.0)
 
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
         elif use_fp8_w8a8 or use_int8_w8a8:
             if group_k > 0 and group_n > 0:
-                k_start = k * BLOCK_SIZE_K
                 offs_ks = k_start // group_k
                 a_scale = tl.load(
                     a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                 )
                 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
-
-
+                if BLOCK_SIZE_N > group_n:
+                    accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+                else:
+                    accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)
             else:
                 if use_fp8_w8a8:
                     accumulator = tl.dot(a, b, acc=accumulator)
@@ -508,8 +551,10 @@ def fused_moe_kernel(
         else:
             accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
-
-
+        if a_desc is None:
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+        if b_desc is None:
+            b_ptrs += BLOCK_SIZE_K * stride_bk
 
     if use_int8_w8a16:
         accumulator *= b_scale
@@ -528,7 +573,12 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-
+    if c_sorted:
+        c_ptrs = (
+            c_ptr + stride_cm * offs_token_id[:, None] + stride_cn * offs_cn[None, :]
+        )
+    else:
+        c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
@@ -557,6 +607,10 @@ def invoke_fused_moe_kernel(
     per_channel_quant: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    a_use_tma: bool = False,
+    b_use_tma: bool = False,
+    c_sorted: bool = False,
+    filter_expert: bool = True,
 ) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -662,14 +716,38 @@ def invoke_fused_moe_kernel(
             use_int4_w4a16=use_int4_w4a16,
             use_int8_w8a16=use_int8_w8a16,
             even_Ks=even_Ks,
+            filter_expert=filter_expert,
             **config,
         )
 
     else:
+        if a_use_tma or b_use_tma:
+            # TMA descriptors require a global memory allocation
+            def alloc_fn(size: int, alignment: int, stream: Optional[int]):
+                return torch.empty(size, device="cuda", dtype=torch.int8)
+
+            triton.set_allocator(alloc_fn)
+        if a_use_tma:
+            a_desc = TensorDescriptor(
+                A, A.shape, A.stride(), [config["BLOCK_SIZE_M"], config["BLOCK_SIZE_K"]]
+            )
+        else:
+            a_desc = None
+        if b_use_tma:
+            b_desc = TensorDescriptor(
+                B,
+                B.shape,
+                B.stride(),
+                [1, config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"]],
+            )
+        else:
+            b_desc = None
 
         fused_moe_kernel[grid](
            A,
+            a_desc,
            B,
+            b_desc,
            bias,
            C,
            A_scale,
@@ -689,8 +767,8 @@ def invoke_fused_moe_kernel(
            B.stride(1),
            bias.stride(0) if bias is not None else 0,
            bias.stride(1) if bias is not None else 0,
-            C.stride(
-            C.stride(
+            C.stride(-2),
+            C.stride(-1),
            A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
            A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
            B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
@@ -706,6 +784,8 @@ def invoke_fused_moe_kernel(
            use_int8_w8a16=use_int8_w8a16,
            per_channel_quant=per_channel_quant,
            even_Ks=even_Ks,
+            c_sorted=c_sorted,
+            filter_expert=filter_expert,
            **config,
        )
```
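The hunks above (from the fused MoE Triton kernel path) add an optional TMA route: when `a_use_tma`/`b_use_tma` are set, host-side `TensorDescriptor` objects replace raw pointer arithmetic inside `fused_moe_kernel`, and expert filtering plus the sorted output layout become `filter_expert`/`c_sorted` constexpr switches. The TMA path only exists when the installed Triton ships `triton.tools.tensor_descriptor`, which is why the module probes for it at import time and exposes `support_tensor_descriptor()`. Below is a minimal sketch of that gating pattern, assuming a Triton build new enough to provide the descriptor API; the `make_a_descriptor` helper is illustrative and not part of the package.

```python
# Illustrative sketch, not sglang code: probe for host-side TMA support the same
# way the module above does, and build a descriptor only when it is available.
import torch
import triton

try:
    from triton.tools.tensor_descriptor import TensorDescriptor

    _support_tensor_descriptor = True
except ImportError:
    TensorDescriptor = None
    _support_tensor_descriptor = False


def support_tensor_descriptor() -> bool:
    return _support_tensor_descriptor


def make_a_descriptor(A: torch.Tensor, block_m: int, block_k: int):
    """Hypothetical helper: return a TMA descriptor for A, or None so the caller
    falls back to plain pointers (the kernel branches on `a_desc is not None`)."""
    if not support_tensor_descriptor():
        return None
    # TMA descriptors need scratch memory; Triton requests it through a user allocator.
    triton.set_allocator(
        lambda size, alignment, stream: torch.empty(size, device="cuda", dtype=torch.int8)
    )
    return TensorDescriptor(A, A.shape, A.stride(), [block_m, block_k])
```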
```diff
@@ -39,6 +39,9 @@ if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
 
+_MASKED_GEMM_FAST_ACT = get_bool_env_var("SGLANG_MASKED_GEMM_FAST_ACT")
+
+
 # TODO(kaixih@nvidia): ideally we should merge this logic into
 # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
 @torch.compile
@@ -214,6 +217,9 @@ class DeepGemmRunnerCore(MoeRunnerCore):
         from sglang.srt.layers.moe.ep_moe.kernels import (
             silu_and_mul_masked_post_quant_fwd,
         )
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_8bit,
+        )
 
         hidden_states = runner_input.hidden_states
         hidden_states_scale = runner_input.hidden_states_scale
@@ -227,15 +233,16 @@ class DeepGemmRunnerCore(MoeRunnerCore):
 
         hidden_states_device = running_state["hidden_states_device"]
 
-        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            b, s_mn, s_k = hidden_states_scale.shape
-            assert (
-                s_mn % 4 == 0 and s_k % 4 == 0
-            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
-
         # GroupGemm-0
         if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            hidden_states_scale
+            if hidden_states_scale.dtype != torch.int:
+                b, s_mn, s_k = hidden_states_scale.shape
+                assert (
+                    s_mn % 4 == 0 and s_k % 4 == 0
+                ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+                hidden_states_scale = _cast_to_e8m0_with_rounding_up(
+                    hidden_states_scale
+                )
         else:
             hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                 hidden_states_scale
@@ -257,33 +264,46 @@ class DeepGemmRunnerCore(MoeRunnerCore):
         dispose_tensor(hidden_states_scale)
 
         # Act
-        down_input = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float8_e4m3fn,
-        )
         scale_block_size = 128
-
-            (
-                gateup_output
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if _MASKED_GEMM_FAST_ACT:
+            down_input, down_input_scale = sglang_per_token_group_quant_8bit(
+                x=gateup_output,
+                dst_dtype=torch.float8_e4m3fn,
+                group_size=scale_block_size,
+                masked_m=masked_m,
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                fuse_silu_and_mul=True,
+                enable_v2=True,
+            )
+        else:
+            down_input = torch.empty(
+                (
+                    gateup_output.shape[0],
+                    gateup_output.shape[1],
+                    gateup_output.shape[2] // 2,
+                ),
+                device=hidden_states_device,
+                dtype=torch.float8_e4m3fn,
+            )
+            down_input_scale = torch.empty(
+                (
+                    gateup_output.shape[0],
+                    gateup_output.shape[1],
+                    gateup_output.shape[2] // 2 // scale_block_size,
+                ),
+                device=hidden_states_device,
+                dtype=torch.float32,
+            )
+            silu_and_mul_masked_post_quant_fwd(
+                gateup_output,
+                down_input,
+                down_input_scale,
+                scale_block_size,
+                masked_m,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         del gateup_output
 
         # GroupGemm-1
```
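In the runner hunks above, the activation-plus-quantization step is now selected by the `SGLANG_MASKED_GEMM_FAST_ACT` environment variable: when it is set, `sglang_per_token_group_quant_8bit` performs SiLU-and-mul fused with the per-group FP8 quantization; otherwise the existing `silu_and_mul_masked_post_quant_fwd` path runs against explicitly allocated `down_input` / `down_input_scale` buffers. A small sketch of the shape bookkeeping implied by those allocations follows; the tensor sizes are invented for illustration, and since the flag is read at module import, it has to be set before sglang is imported.

```python
# Illustrative shape check for the activation/quant step above (sizes are made up).
# gateup_output stacks gate and up projections, so SiLU(gate) * up halves the last
# dim, and one float32 scale is kept per 128-wide group of the halved dimension.
import os

os.environ["SGLANG_MASKED_GEMM_FAST_ACT"] = "1"  # must be set before importing sglang

num_experts, max_tokens_per_expert, double_intermediate = 8, 256, 4096
scale_block_size = 128

down_input_shape = (num_experts, max_tokens_per_expert, double_intermediate // 2)
down_input_scale_shape = (
    num_experts,
    max_tokens_per_expert,
    double_intermediate // 2 // scale_block_size,
)
assert down_input_shape == (8, 256, 2048)
assert down_input_scale_shape == (8, 256, 16)
```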
```diff
@@ -97,7 +97,6 @@ class DeepEPNormalCombineInput(NamedTuple):
     hidden_states: torch.Tensor
     topk_ids: torch.Tensor
     topk_weights: torch.Tensor
-    overlap_args: Optional[CombineOverlapArgs] = None
 
     @property
     def format(self) -> CombineInputFormat:
@@ -110,7 +109,6 @@ class DeepEPLLCombineInput(NamedTuple):
     hidden_states: torch.Tensor
     topk_ids: torch.Tensor
     topk_weights: torch.Tensor
-    overlap_args: Optional[CombineOverlapArgs] = None
 
     @property
     def format(self) -> CombineInputFormat:
@@ -333,7 +331,7 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional[
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         raise NotImplementedError
 
@@ -463,7 +461,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional[
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
 
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
@@ -619,7 +617,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional[
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
@@ -645,7 +643,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional[
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         buffer = self._get_buffer()
 
@@ -762,16 +760,21 @@ class DeepEPDispatcher(BaseDispatcher):
         del self._dispatch_intermediate_state
         return self._get_impl().dispatch_b(*inner_state)
 
-    def combine(
-        self
+    def combine(
+        self,
+        combine_input: CombineInput,
+        overlap_args: Optional[CombineOverlapArgs] = None,
+    ) -> Tuple:
+        self.combine_a(combine_input, overlap_args)
         ret = self.combine_b()
         return ret
 
     def combine_a(
         self,
         combine_input: CombineInput,
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
-        hidden_states, topk_ids, topk_weights
+        hidden_states, topk_ids, topk_weights = combine_input
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,
```
         
     | 
    
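Functionally, these hunks tighten the combine path: every `_DeepEPDispatcherImpl*` now declares `overlap_args: Optional[CombineOverlapArgs] = None` on a single line, and `DeepEPDispatcher.combine` gains an explicit `combine_input: CombineInput` plus the same optional `overlap_args`, which it forwards to `combine_a` before finishing with `combine_b`. The sketch below only illustrates that split-phase calling convention; `OverlapArgs`, `CombineInput` and `TwoPhaseDispatcher` are stand-ins for illustration, not the real sglang types.

```python
from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple

import torch


@dataclass
class OverlapArgs:
    """Stand-in for CombineOverlapArgs: knobs for overlapping comm with compute."""
    stream_priority: int = 0


class CombineInput(NamedTuple):
    """Stand-in: the diff unpacks combine_input into exactly these three tensors."""
    hidden_states: torch.Tensor
    topk_ids: torch.Tensor
    topk_weights: torch.Tensor


class TwoPhaseDispatcher:
    """Illustrative dispatcher showing the combine == combine_a + combine_b split."""

    def combine(
        self,
        combine_input: CombineInput,
        overlap_args: Optional[OverlapArgs] = None,
    ) -> Tuple[torch.Tensor, ...]:
        # Phase A issues the (hypothetical) communication; phase B waits on it.
        self.combine_a(combine_input, overlap_args)
        return self.combine_b()

    def combine_a(
        self,
        combine_input: CombineInput,
        overlap_args: Optional[OverlapArgs] = None,
    ) -> None:
        hidden_states, topk_ids, topk_weights = combine_input
        # A real implementation would launch an all-to-all here and keep the
        # in-flight handle; this sketch simply stashes the tensors.
        self._inner_state = (hidden_states, topk_ids, topk_weights, overlap_args)

    def combine_b(self) -> Tuple[torch.Tensor, ...]:
        hidden_states, *_ = self._inner_state
        del self._inner_state
        return (hidden_states,)


if __name__ == "__main__":
    dispatcher = TwoPhaseDispatcher()
    inputs = CombineInput(
        hidden_states=torch.randn(4, 8),
        topk_ids=torch.zeros(4, 2, dtype=torch.long),
        topk_weights=torch.ones(4, 2),
    )
    (out,) = dispatcher.combine(inputs, overlap_args=OverlapArgs())
    print(out.shape)  # torch.Size([4, 8])
```

Splitting combine into an A/B pair lets a caller overlap the communication started in phase A with other work before synchronizing in phase B, which is presumably what the optional overlap-args object parameterizes in the real dispatcher.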
sglang/srt/layers/moe/topk.py CHANGED
@@ -314,16 +314,41 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        global_num_experts = router_logits.shape[-1]
 
-
-
+        use_grouped_topk = self.topk_config.use_grouped_topk
+        torch_native = self.topk_config.torch_native
+        renormalize = self.topk_config.renormalize
 
+        if not use_grouped_topk and not torch_native:
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
+                router_logits,
+                k=self.topk_config.top_k,
+            )
+            topk_weights = topk_weights.to(torch.float32)
+
+            if renormalize:
+                topk_weights_sum = (
+                    topk_weights.sum(dim=-1, keepdim=True)
+                    if self.topk_config.num_fused_shared_experts == 0
+                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+                )
+                topk_weights = topk_weights / topk_weights_sum
+
+            if expert_location_dispatch_info is not None:
+                topk_ids = topk_ids_logical_to_physical(
+                    topk_ids, expert_location_dispatch_info
+                )
+            get_global_expert_distribution_recorder().on_select_experts(
+                topk_ids=topk_ids
+            )
+
+            return StandardTopKOutput(topk_weights, topk_ids, _)
+        if use_grouped_topk and not torch_native and router_logits.shape[-1] == 256:
+            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
-            router_logits = router_logits.to(torch.float32)
 
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
+                router_logits.to(torch.float32),
                 k=self.topk_config.top_k,
                 bias=self.topk_config.correction_bias.to(torch.float32),
                 k_group=self.topk_config.topk_group,
@@ -335,7 +360,7 @@ class TopK(CustomOp):
                 eps=float(1e-20),
             )
 
-            if
+            if renormalize:
                 topk_weights_sum = (
                     topk_weights.sum(dim=-1, keepdim=True)
                     if self.topk_config.num_fused_shared_experts == 0
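In `topk.py`, the NPU forward path gains a branch for plain (non-grouped, non-torch-native) routing: it calls `torch_npu.npu_moe_gating_top_k_softmax`, casts the weights to float32, optionally renormalizes them while excluding the fused-shared-expert slot from the sum, remaps logical to physical expert ids when dispatch info is present, records the selection, and returns a `StandardTopKOutput`. The grouped 256-expert path now also reuses the precomputed `renormalize` flag and casts the logits inline instead of mutating them. Below is a rough plain-PyTorch reference of the gating math; the softmax-then-top-k order is an assumption about the fused kernel's semantics, and the helper name is mine.

```python
import torch


def gating_topk_softmax_reference(
    router_logits: torch.Tensor,
    k: int,
    renormalize: bool = True,
    num_fused_shared_experts: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference (non-fused) version of softmax gating + top-k + renormalization."""
    probs = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, k=k, dim=-1)

    if renormalize:
        # Mirroring the diff: when a shared expert is fused in as the last slot,
        # it is excluded from the normalization sum.
        topk_weights_sum = (
            topk_weights.sum(dim=-1, keepdim=True)
            if num_fused_shared_experts == 0
            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
        )
        topk_weights = topk_weights / topk_weights_sum
    return topk_weights, topk_ids


if __name__ == "__main__":
    logits = torch.randn(3, 16)
    weights, ids = gating_topk_softmax_reference(logits, k=4)
    print(weights.sum(dim=-1))  # ~1.0 per row after renormalization
```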
sglang/srt/layers/pooler.py CHANGED
@@ -20,7 +20,9 @@ class PoolingType(IntEnum):
 
 @dataclass
 class EmbeddingPoolerOutput:
-
+    # Pooler can return list[tensor] instead of tensor if the dimension of each tensor in the batch is different
+    # due to different per-request matryoshka dim truncation
+    embeddings: torch.Tensor | list[torch.Tensor]
 
 
 class Pooler(nn.Module):
@@ -42,6 +44,7 @@ class Pooler(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> EmbeddingPoolerOutput:
+
         if self.pooling_type == PoolingType.LAST:
             last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_indices]
@@ -53,8 +56,24 @@ class Pooler(nn.Module):
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
 
+        if forward_batch.dimensions is not None:
+            all_same_dimensions = len(set(forward_batch.dimensions)) == 1
+            if all_same_dimensions:
+                pooled_data = pooled_data[..., : forward_batch.dimensions[0]]
+            else:
+                pooled_data = [
+                    tensor[..., :dim]
+                    for tensor, dim in zip(pooled_data, forward_batch.dimensions)
+                ]
+
         if self.normalize:
-
+            if isinstance(pooled_data, list):
+                pooled_data = [
+                    nn.functional.normalize(tensor, p=2, dim=-1)
+                    for tensor in pooled_data
+                ]
+            else:
+                pooled_data = nn.functional.normalize(pooled_data, p=2, dim=-1)
 
         return EmbeddingPoolerOutput(embeddings=pooled_data)
 
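The pooler change lets `EmbeddingPoolerOutput.embeddings` be either one batched tensor or a list of per-request tensors: when `forward_batch.dimensions` is set, embeddings are truncated to each request's matryoshka dimension (sliced once if every request agrees, per row otherwise), and L2 normalization is then applied per tensor. A self-contained sketch of that logic, with a hypothetical helper name (`truncate_and_normalize`) and a plain `dimensions` argument standing in for `forward_batch.dimensions`:

```python
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn


def truncate_and_normalize(
    pooled_data: torch.Tensor,
    dimensions: Optional[Sequence[int]] = None,
    normalize: bool = True,
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Hypothetical helper mirroring the truncation/normalization in the hunk above."""
    if dimensions is not None:
        if len(set(dimensions)) == 1:
            # Every request asked for the same dim: keep a single batched tensor.
            pooled_data = pooled_data[..., : dimensions[0]]
        else:
            # Mixed dims: fall back to a list of per-request tensors.
            pooled_data = [
                tensor[..., :dim] for tensor, dim in zip(pooled_data, dimensions)
            ]

    if normalize:
        if isinstance(pooled_data, list):
            pooled_data = [
                nn.functional.normalize(tensor, p=2, dim=-1) for tensor in pooled_data
            ]
        else:
            pooled_data = nn.functional.normalize(pooled_data, p=2, dim=-1)
    return pooled_data


if __name__ == "__main__":
    batch = torch.randn(3, 8)
    out = truncate_and_normalize(batch, dimensions=[8, 4, 2])
    print([t.shape for t in out])  # [torch.Size([8]), torch.Size([4]), torch.Size([2])]
```

Keeping the single-tensor fast path when all requests share a dimension avoids breaking the batch apart in the common case; only mixed-dimension batches pay for the list of per-request tensors.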