sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_serving.py +56 -12
 - sglang/launch_server.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
 - sglang/srt/compilation/backend.py +1 -1
 - sglang/srt/configs/model_config.py +5 -5
 - sglang/srt/distributed/parallel_state.py +0 -7
 - sglang/srt/entrypoints/engine.py +18 -15
 - sglang/srt/entrypoints/grpc_server.py +0 -1
 - sglang/srt/entrypoints/http_server.py +75 -94
 - sglang/srt/environ.py +16 -2
 - sglang/srt/eplb/expert_distribution.py +30 -0
 - sglang/srt/function_call/function_call_parser.py +2 -0
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/layers/activation.py +6 -0
 - sglang/srt/layers/attention/flashattention_backend.py +12 -2
 - sglang/srt/layers/attention/flashinfer_backend.py +10 -1
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
 - sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
 - sglang/srt/layers/attention/utils.py +78 -0
 - sglang/srt/layers/communicator.py +1 -0
 - sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
 - sglang/srt/layers/layernorm.py +19 -4
 - sglang/srt/layers/logits_processor.py +5 -0
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
 - sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
 - sglang/srt/layers/moe/ep_moe/layer.py +79 -272
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
 - sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
 - sglang/srt/layers/moe/topk.py +4 -4
 - sglang/srt/layers/moe/utils.py +3 -4
 - sglang/srt/layers/quantization/__init__.py +3 -5
 - sglang/srt/layers/quantization/awq.py +0 -3
 - sglang/srt/layers/quantization/base_config.py +7 -0
 - sglang/srt/layers/quantization/fp8.py +68 -63
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/mxfp4.py +30 -38
 - sglang/srt/layers/quantization/unquant.py +23 -45
 - sglang/srt/layers/quantization/w4afp8.py +38 -2
 - sglang/srt/layers/radix_attention.py +5 -2
 - sglang/srt/layers/rotary_embedding.py +13 -1
 - sglang/srt/layers/sampler.py +12 -1
 - sglang/srt/managers/io_struct.py +3 -0
 - sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
 - sglang/srt/managers/scheduler.py +21 -15
 - sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
 - sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
 - sglang/srt/managers/tokenizer_manager.py +11 -19
 - sglang/srt/mem_cache/hicache_storage.py +7 -1
 - sglang/srt/mem_cache/memory_pool.py +82 -0
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/model_executor/forward_batch_info.py +44 -3
 - sglang/srt/model_executor/model_runner.py +1 -149
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
 - sglang/srt/models/deepseek_v2.py +147 -44
 - sglang/srt/models/glm4_moe.py +322 -354
 - sglang/srt/models/glm4_moe_nextn.py +4 -14
 - sglang/srt/models/glm4v_moe.py +29 -196
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/qwen2.py +22 -1
 - sglang/srt/models/qwen3.py +34 -4
 - sglang/srt/models/qwen3_moe.py +2 -4
 - sglang/srt/multimodal/processors/base_processor.py +1 -0
 - sglang/srt/multimodal/processors/glm4v.py +1 -1
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
 - sglang/srt/parser/reasoning_parser.py +28 -1
 - sglang/srt/server_args.py +365 -186
 - sglang/srt/single_batch_overlap.py +2 -7
 - sglang/srt/utils/common.py +87 -42
 - sglang/srt/utils/hf_transformers_utils.py +7 -3
 - sglang/test/test_deterministic.py +235 -12
 - sglang/test/test_deterministic_utils.py +2 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
 - sglang/srt/models/vila.py +0 -306
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
 
sglang/srt/layers/moe/token_dispatcher/deepep.py
CHANGED

@@ -58,7 +58,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 logger = logging.getLogger(__name__)
 
 
-class DeepEPNormalOutput(NamedTuple):
+class DeepEPNormalDispatchOutput(NamedTuple):
     """DeepEP normal dispatch output."""
 
     hidden_states: torch.Tensor
@@ -72,7 +72,7 @@ class DeepEPNormalOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_NORMAL
 
 
-class DeepEPLLOutput(NamedTuple):
+class DeepEPLLDispatchOutput(NamedTuple):
     """DeepEP low latency dispatch output."""
 
     hidden_states: torch.Tensor
@@ -87,14 +87,17 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_LL
 
 
-assert isinstance(DeepEPNormalOutput, DispatchOutput)
-assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(DeepEPNormalDispatchOutput, DispatchOutput)
+assert isinstance(DeepEPLLDispatchOutput, DispatchOutput)
 
 
 class DeepEPNormalCombineInput(NamedTuple):
     """DeepEP normal combine input."""
 
-    pass
+    hidden_states: torch.Tensor
+    topk_ids: torch.Tensor
+    topk_weights: torch.Tensor
+    overlap_args: Optional[CombineOverlapArgs] = None
 
     @property
     def format(self) -> CombineInputFormat:
@@ -104,7 +107,10 @@ class DeepEPNormalCombineInput(NamedTuple):
 class DeepEPLLCombineInput(NamedTuple):
     """DeepEP low latency combine input."""
 
-    pass
+    hidden_states: torch.Tensor
+    topk_ids: torch.Tensor
+    topk_weights: torch.Tensor
+    overlap_args: Optional[CombineOverlapArgs] = None
 
     @property
     def format(self) -> CombineInputFormat:
@@ -383,7 +389,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         else:
             hidden_states_scale = None
 
-        return DeepEPNormalOutput(
+        return DeepEPNormalDispatchOutput(
             hidden_states,
             hidden_states_scale,
             topk_ids,
@@ -562,7 +568,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         else:
             hidden_states_scale = None
 
-        deepep_output = DeepEPLLOutput(
+        deepep_output = DeepEPLLDispatchOutput(
             hidden_states,
             hidden_states_scale,
             topk_ids,
@@ -756,18 +762,16 @@ class DeepEPDispatcher(BaseDispatcher):
         del self._dispatch_intermediate_state
         return self._get_impl().dispatch_b(*inner_state)
 
-    def combine(self, hidden_states, topk_ids, topk_weights, overlap_args=None) -> Tuple:
-        self.combine_a(hidden_states, topk_ids, topk_weights, overlap_args)
+    def combine(self, combine_input: CombineInput) -> Tuple:
+        self.combine_a(combine_input)
         ret = self.combine_b()
         return ret
 
     def combine_a(
         self,
-        hidden_states: torch.Tensor,
-        topk_ids: torch.Tensor,
-        topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"] = None,
+        combine_input: CombineInput,
     ):
+        hidden_states, topk_ids, topk_weights, overlap_args = combine_input
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,

sglang/srt/layers/moe/token_dispatcher/standard.py
CHANGED

@@ -88,7 +88,7 @@ class StandardDispatcher(BaseDispatcher):
             topk_output = topk_output._replace(
                 topk_ids=self.local_expert_mapping[topk_output.topk_ids]
             )
-        elif TopKOutputChecker.format_is_triton_kernel(topk_output):
+        elif TopKOutputChecker.format_is_triton_kernels(topk_output):
             raise NotImplementedError()
 
         return StandardDispatchOutput(
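The deepep.py combine refactor above collapses four loose arguments into a single NamedTuple payload that unpacks positionally inside combine_a. Below is a minimal, self-contained sketch of that pattern; the names echo the diff, but this is an illustration, not the sglang implementation (CombineOverlapArgs is stubbed with object, and the combine body is a toy stand-in):

# Sketch of the combine-input refactor: bundle the loose
# (hidden_states, topk_ids, topk_weights, overlap_args) arguments into
# one NamedTuple so combine() takes a single typed payload. Toy shapes only.
from typing import NamedTuple, Optional

import torch


class DeepEPNormalCombineInput(NamedTuple):
    hidden_states: torch.Tensor
    topk_ids: torch.Tensor
    topk_weights: torch.Tensor
    overlap_args: Optional[object] = None  # stands in for CombineOverlapArgs


class Dispatcher:
    def combine(self, combine_input: DeepEPNormalCombineInput) -> torch.Tensor:
        # A NamedTuple unpacks positionally, exactly as the new combine_a does.
        hidden_states, topk_ids, topk_weights, overlap_args = combine_input
        # Stand-in for the real all-to-all combine: weight each token's copies.
        return hidden_states * topk_weights.sum(dim=-1, keepdim=True)


dispatcher = Dispatcher()
payload = DeepEPNormalCombineInput(
    hidden_states=torch.randn(4, 8),
    topk_ids=torch.zeros(4, 2, dtype=torch.long),
    topk_weights=torch.full((4, 2), 0.5),
)
out = dispatcher.combine(payload)
assert out.shape == (4, 8)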
    
sglang/srt/layers/moe/topk.py
CHANGED

@@ -111,10 +111,10 @@ class TopKOutputChecker:
         return topk_output.format.is_standard()
 
     @staticmethod
-    def format_is_triton_kernel(
+    def format_is_triton_kernels(
         topk_output: TopKOutput,
     ) -> TypeGuard[TritonKernelTopKOutput]:
-        return topk_output.format.is_triton_kernel()
+        return topk_output.format.is_triton_kernels()
 
     @staticmethod
     def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
@@ -129,7 +129,7 @@ class TopKOutputFormat(Enum):
     def is_standard(self) -> bool:
         return self == TopKOutputFormat.STANDARD
 
-    def is_triton_kernel(self) -> bool:
+    def is_triton_kernels(self) -> bool:
         return self == TopKOutputFormat.TRITON_KERNEL
 
     def is_bypassed(self) -> bool:
@@ -254,7 +254,7 @@ class TopK(CustomOp):
     ) -> TopKOutput:
         if self.topk_config.output_format is not None:
             output_format = self.topk_config.output_format
-        elif get_moe_runner_backend().is_triton_kernel():
+        elif get_moe_runner_backend().is_triton_kernels():
             output_format = TopKOutputFormat.TRITON_KERNEL
         elif (
             should_use_flashinfer_trtllm_moe()
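The renamed format_is_triton_kernels helper is a TypeGuard predicate: a boolean check that also narrows the static type of topk_output for type checkers. A stand-alone sketch of the idea with toy stand-in classes, not sglang's actual TopKOutput variants (requires Python 3.10+ for typing.TypeGuard):

# Minimal sketch of the TypeGuard pattern used by TopKOutputChecker: a
# predicate that both tests the runtime format and narrows the static type.
from enum import Enum
from typing import TypeGuard, Union


class TopKOutputFormat(Enum):
    STANDARD = "standard"
    TRITON_KERNEL = "triton_kernel"


class StandardTopKOutput:
    format = TopKOutputFormat.STANDARD


class TritonKernelTopKOutput:
    format = TopKOutputFormat.TRITON_KERNEL
    routing_data = "..."  # a field only this variant carries


TopKOutput = Union[StandardTopKOutput, TritonKernelTopKOutput]


def format_is_triton_kernels(
    topk_output: TopKOutput,
) -> TypeGuard[TritonKernelTopKOutput]:
    # Returning True tells the type checker topk_output is
    # TritonKernelTopKOutput inside the guarded branch.
    return topk_output.format == TopKOutputFormat.TRITON_KERNEL


out: TopKOutput = TritonKernelTopKOutput()
if format_is_triton_kernels(out):
    print(out.routing_data)  # safely narrowed to the triton-kernels variant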
    
sglang/srt/layers/moe/utils.py
CHANGED

@@ -51,7 +51,7 @@ class MoeRunnerBackend(Enum):
     AUTO = "auto"
     DEEP_GEMM = "deep_gemm"
     TRITON = "triton"
-    TRITON_KERNEL = "triton_kernel"
+    TRITON_KERNELS = "triton_kernel"
     FLASHINFER_TRTLLM = "flashinfer_trtllm"
     FLASHINFER_CUTLASS = "flashinfer_cutlass"
     FLASHINFER_MXFP4 = "flashinfer_mxfp4"
@@ -67,8 +67,8 @@ class MoeRunnerBackend(Enum):
     def is_triton(self):
         return self == MoeRunnerBackend.TRITON
 
-    def is_triton_kernel(self):
-        return self == MoeRunnerBackend.TRITON_KERNEL
+    def is_triton_kernels(self):
+        return self == MoeRunnerBackend.TRITON_KERNELS
 
     def is_flashinfer_trtllm(self):
         return self == MoeRunnerBackend.FLASHINFER_TRTLLM
@@ -152,7 +152,6 @@ def initialize_moe_config(server_args: ServerArgs):
 def get_moe_a2a_backend() -> MoeA2ABackend:
     global MOE_A2A_BACKEND
     if MOE_A2A_BACKEND is None:
-        logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
         MOE_A2A_BACKEND = MoeA2ABackend.NONE
     return MOE_A2A_BACKEND
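Worth noting in the MoeRunnerBackend hunk: the member is renamed TRITON_KERNEL to TRITON_KERNELS while its string value stays "triton_kernel", so by-value lookups keep resolving. A toy reproduction of why that matters (the flag-string framing is an assumption about how the value gets used):

# The enum member was renamed but its value kept as "triton_kernel", so any
# by-value construction (e.g. resolving a backend name from a config or
# CLI string) still works after the rename.
from enum import Enum


class MoeRunnerBackend(Enum):
    TRITON = "triton"
    TRITON_KERNELS = "triton_kernel"  # renamed member, unchanged value

    def is_triton_kernels(self) -> bool:
        return self == MoeRunnerBackend.TRITON_KERNELS


backend = MoeRunnerBackend("triton_kernel")  # by-value lookup still resolves
assert backend is MoeRunnerBackend.TRITON_KERNELS
assert backend.is_triton_kernels()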
sglang/srt/layers/quantization/__init__.py
CHANGED

@@ -12,7 +12,6 @@ try:
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
     )
@@ -32,9 +31,7 @@ except ImportError as e:
 
     AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
         ExpertsInt8Config
-    ) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
-        DummyConfig
-    )
+    ) = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
 
 
 from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
@@ -45,6 +42,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
+from sglang.srt.layers.quantization.gguf import GGUFConfig
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
@@ -75,6 +73,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "w8a8_fp8": W8A8Fp8Config,
     "awq": AWQConfig,
     "awq_marlin": AWQMarlinConfig,
+    "gguf": GGUFConfig,
     "gptq": GPTQConfig,
     "gptq_marlin": GPTQMarlinConfig,
     "moe_wna16": MoeWNA16Config,
@@ -108,7 +107,6 @@ VLLM_QUANTIZATION_METHODS = {
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
     "marlin": MarlinConfig,
-    "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
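These hunks move GGUFConfig out of the vllm-backed tier (paired with the new native sglang/srt/layers/quantization/gguf.py in the file list above) into the base registry, so "gguf" resolves even when the optional vllm imports fail. A minimal sketch of that two-tier lookup; the lookup function and placeholder config classes are hypothetical:

# Sketch of the two-tier method registry the hunks above rearrange: "gguf"
# now lives in the base tier, so it resolves without the optional vllm tier.
from typing import Dict, Type


class QuantizationConfig: ...
class GGUFConfig(QuantizationConfig): ...
class MarlinConfig(QuantizationConfig): ...


BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "gguf": GGUFConfig,  # moved here from the vllm-dependent tier
}
VLLM_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "marlin": MarlinConfig,  # still requires the vllm import to be real
}


def get_quant_config_cls(name: str) -> Type[QuantizationConfig]:
    # Base methods are always importable; vllm-backed ones are the fallback.
    if name in BASE_QUANTIZATION_METHODS:
        return BASE_QUANTIZATION_METHODS[name]
    return VLLM_QUANTIZATION_METHODS[name]


assert get_quant_config_cls("gguf") is GGUFConfig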
sglang/srt/layers/quantization/awq.py
CHANGED

@@ -840,12 +840,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
             self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
-        # The input must currently be float16
         x = dispatch_output.hidden_states
         topk_output = dispatch_output.topk_output
-
         orig_dtype = x.dtype
-        x = x.half()
 
         topk_weights, topk_ids, router_logits = topk_output
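The awq.py hunk drops the unconditional x = x.half() cast while keeping orig_dtype, the record-then-restore pattern that lets downstream code cast results back to the caller's dtype. A tiny illustration of that pattern; the fp32 doubling stands in for the actual fused-experts kernel:

# Record the caller's dtype, compute in whatever dtype the kernel needs,
# then cast the result back. Stand-in math, not the AWQ MoE kernel.
import torch


def run_in_compute_dtype(x: torch.Tensor) -> torch.Tensor:
    orig_dtype = x.dtype
    y = x.float() * 2.0  # stand-in kernel running in fp32
    return y.to(orig_dtype)  # restore the caller's dtype


out = run_in_compute_dtype(torch.randn(2, 4, dtype=torch.bfloat16))
assert out.dtype == torch.bfloat16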
sglang/srt/layers/quantization/base_config.py
CHANGED

@@ -179,6 +179,13 @@ class QuantizationConfig(ABC):
         elif "NVFP4" in quant_algo or "FP4" in quant_algo:
             return "modelopt_fp4"
 
+        # The hf_quant_config may be a parsed quant config, so we need to check the
+        # quant_method.
+        if hf_quant_config.get("quant_method", "") == "modelopt_fp8":
+            return "modelopt_fp8"
+        elif hf_quant_config.get("quant_method", "") == "modelopt_fp4":
+            return "modelopt_fp4"
+
         return None
 
     @staticmethod
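The added branch covers an hf_quant_config that has already been parsed and therefore exposes a quant_method key instead of a raw quant_algo string. A self-contained sketch of the resulting detection order; where quant_algo sits inside the dict is an assumption for illustration:

# Sketch of the detection order the base_config.py hunk produces: try the
# raw "quant_algo" string first, then fall back to an already-parsed
# "quant_method" key. The dict shapes here are illustrative assumptions.
from typing import Optional


def detect_modelopt_method(hf_quant_config: dict) -> Optional[str]:
    quant_algo = hf_quant_config.get("quantization", {}).get("quant_algo", "")
    if "FP8" in quant_algo:
        return "modelopt_fp8"
    elif "NVFP4" in quant_algo or "FP4" in quant_algo:
        return "modelopt_fp4"

    # New in this release: a parsed config carries quant_method directly.
    if hf_quant_config.get("quant_method", "") == "modelopt_fp8":
        return "modelopt_fp8"
    elif hf_quant_config.get("quant_method", "") == "modelopt_fp4":
        return "modelopt_fp4"
    return None


assert detect_modelopt_method({"quant_method": "modelopt_fp4"}) == "modelopt_fp4"
assert detect_modelopt_method({"quantization": {"quant_algo": "FP8"}}) == "modelopt_fp8"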
sglang/srt/layers/quantization/fp8.py
CHANGED

@@ -33,6 +33,7 @@ from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -525,12 +526,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
         self.cutlass_fp8_supported = cutlass_fp8_supported()
-        self.use_cutlass_fused_experts_fp8 = (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        )
 
     def create_weights(
         self,
@@ -638,58 +633,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
-            if self.use_cutlass_fused_experts_fp8:
-                self.ab_strides1 = torch.full(
-                    (num_experts,),
-                    hidden_size,
-                    device=w13_weight.device,
-                    dtype=torch.int64,
-                )
-                self.c_strides1 = torch.full(
-                    (num_experts,),
-                    2 * intermediate_size_per_partition,
-                    device=w13_weight.device,
-                    dtype=torch.int64,
-                )
-                self.ab_strides2 = torch.full(
-                    (num_experts,),
-                    intermediate_size_per_partition,
-                    device=w2_weight.device,
-                    dtype=torch.int64,
-                )
-                self.c_strides2 = torch.full(
-                    (num_experts,),
-                    hidden_size,
-                    device=w2_weight.device,
-                    dtype=torch.int64,
-                )
-                self.workspace = torch.empty(
-                    90000, device=w13_weight.device, dtype=torch.uint8
-                )
-                self.a_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.b_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.out_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.a_scales_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.b_scales_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.expert_offsets = torch.empty(
-                    num_experts + 1, device=w13_weight.device, dtype=torch.int32
-                )
-                self.problem_sizes1 = torch.empty(
-                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
-                )
-                self.problem_sizes2 = torch.empty(
-                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
-                )
+            if self._should_use_cutlass_fused_experts():
+                self._ensure_cutlass_buffers_initialized(layer)
 
         else:
             # Allocate 2 scales for w1 and w3 respectively.
@@ -1039,13 +984,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
         x = dispatch_output.hidden_states
-        topk_output = dispatch_output.topk_output
         moe_runner_config = self.moe_runner_config
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
             x, topk_weights = apply_topk_weights_cpu(
                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
@@ -1072,17 +1016,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         ret = self.maybe_apply_hip_fused_experts(
             layer,
             x,
-            topk_output,
+            dispatch_output.topk_output,
             moe_runner_config.activation,
             moe_runner_config.no_combine,
         )
         if ret is not None:
             return StandardCombineInput(hidden_states=ret)
 
-        if self.use_cutlass_fused_experts_fp8:
+        if self._should_use_cutlass_fused_experts():
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 
-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
             output = cutlass_fused_experts_fp8(
                 x,
                 layer.w13_weight.transpose(1, 2),
@@ -1171,6 +1115,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         return self.runner.run(dispatch_output, quant_info)
 
+    def _should_use_cutlass_fused_experts(self) -> bool:
+        """Decide whether to use Cutlass FP8 fused-experts path based on moe runner backend,
+        with env var override via `SGLANG_CUTLASS_MOE`.
+        """
+        backend = get_moe_runner_backend()
+        env_force = get_bool_env_var("SGLANG_CUTLASS_MOE")
+        # TODO: remove env var in the future, it should be handled by moe runner backend
+        if env_force:
+            return True
+        return (
+            backend.is_flashinfer_cutlass()
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and (is_sm100_supported() or is_sm90_supported())
+        )
+
+    def _ensure_cutlass_buffers_initialized(self, layer: Module) -> None:
+        if getattr(self, "_cutlass_buffers_ready", False):
+            return
+
+        device = layer.w13_weight.device
+        num_experts = layer.w13_weight.shape[0]
+        hidden_size = layer.w2_weight.shape[1]
+        intermediate_size_per_partition = layer.intermediate_size_per_partition
+
+        self.ab_strides1 = torch.full(
+            (num_experts,), hidden_size, device=device, dtype=torch.int64
+        )
+        self.c_strides1 = torch.full(
+            (num_experts,),
+            2 * intermediate_size_per_partition,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.ab_strides2 = torch.full(
+            (num_experts,),
+            intermediate_size_per_partition,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.c_strides2 = torch.full(
+            (num_experts,), hidden_size, device=device, dtype=torch.int64
+        )
+        self.workspace = torch.empty(90000, device=device, dtype=torch.uint8)
+        self.a_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.b_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.out_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.a_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.b_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.expert_offsets = torch.empty(
+            num_experts + 1, device=device, dtype=torch.int32
+        )
+        self.problem_sizes1 = torch.empty(
+            num_experts, 3, device=device, dtype=torch.int32
+        )
+        self.problem_sizes2 = torch.empty(
+            num_experts, 3, device=device, dtype=torch.int32
+        )
+
+        self._cutlass_buffers_ready = True
+
     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
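Net effect of the fp8.py refactor: the Cutlass buffer setup moves out of weight processing into an idempotent, flag-guarded helper, and the gate now consults the MoE runner backend with SGLANG_CUTLASS_MOE kept as an override. A compact sketch of the flag-guarded lazy-init pattern itself, with toy buffers rather than the real Cutlass workspace:

# Allocate once on first use, guarded by a boolean flag, instead of eagerly
# during weight processing; repeated calls are no-ops.
import torch


class CutlassBuffers:
    def ensure_initialized(self, num_experts: int, device: str = "cpu") -> None:
        if getattr(self, "_ready", False):
            return  # already allocated
        self.expert_offsets = torch.empty(
            num_experts + 1, device=device, dtype=torch.int32
        )
        self.problem_sizes = torch.empty(
            num_experts, 3, device=device, dtype=torch.int32
        )
        self._ready = True


buf = CutlassBuffers()
buf.ensure_initialized(num_experts=8)
first = buf.expert_offsets
buf.ensure_initialized(num_experts=8)  # no reallocation on the second call
assert buf.expert_offsets is first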