sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +0 -12
- sglang/bench_one_batch.py +0 -12
- sglang/bench_serving.py +11 -2
- sglang/lang/backend/openai.py +10 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -0
- sglang/srt/layers/attention/flashinfer_backend.py +49 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
- sglang/srt/layers/moe/topk.py +205 -0
- sglang/srt/layers/quantization/__init__.py +3 -3
- sglang/srt/layers/quantization/fp8.py +169 -32
- sglang/srt/layers/quantization/fp8_kernel.py +292 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/torchao_utils.py +11 -15
- sglang/srt/managers/schedule_batch.py +16 -10
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +13 -16
- sglang/srt/managers/tokenizer_manager.py +130 -111
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_loader/loader.py +22 -11
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +19 -0
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +23 -0
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_params.py +9 -2
- sglang/srt/server.py +21 -37
- sglang/srt/utils.py +33 -44
- sglang/test/test_block_fp8.py +341 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
| @@ -0,0 +1,205 @@ | |
| 1 | 
            +
            # Copyright 2024 SGLang Team
         | 
| 2 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 3 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 4 | 
            +
            # You may obtain a copy of the License at
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 9 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 10 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 11 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 12 | 
            +
            # limitations under the License.
         | 
| 13 | 
            +
            # ==============================================================================
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from typing import Callable, Optional
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import torch
         | 
| 18 | 
            +
            import torch.nn.functional as F
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def fused_topk_native(
                hidden_states: torch.Tensor,
                gating_output: torch.Tensor,
                topk: int,
                renormalize: bool,
            ):
                """Select the top-k experts per token using plain PyTorch ops (no custom kernel).

                Args:
                    hidden_states: Token activations; only its first dimension (the token
                        count) is used, to validate it against ``gating_output``.
                    gating_output: Router logits with one row per token.
                    topk: Number of experts to keep for each token.
                    renormalize: If True, rescale the kept weights so each row sums to 1.

                Returns:
                    ``(topk_weights, topk_ids)``: the float32 softmax probabilities of the
                    selected experts and their indices, exactly as produced by
                    ``torch.topk`` (sorted by descending weight).
                """
                assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
                # Upcast the logits to float32 before the softmax. No output buffers are
                # pre-allocated: torch.topk creates both result tensors itself.
                probs = F.softmax(gating_output.float(), dim=-1)
                topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
                if renormalize:
                    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
                return topk_weights, topk_ids
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            def fused_topk(
                hidden_states: torch.Tensor,
                gating_output: torch.Tensor,
                topk: int,
                renormalize: bool,
            ):
                """Select the top-k experts per token with vLLM's ``topk_softmax`` kernel.

                Args:
                    hidden_states: Token activations; must be 2-D, and its first dimension
                        (the token count) must match ``gating_output``.
                    gating_output: Router logits with one row per token; upcast to float32
                        before being handed to the kernel.
                    topk: Number of experts to keep for each token.
                    renormalize: If True, rescale the kept weights so each row sums to 1.

                Returns:
                    ``(topk_weights, topk_ids)`` as float32 / int32 tensors of shape
                    ``(num_tokens, topk)`` on the same device as ``hidden_states``.
                """
                from vllm import _custom_ops as ops

                assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

                num_tokens, _ = hidden_states.shape
                device = hidden_states.device

                def _out_buffer(dtype):
                    # Uninitialized (num_tokens, topk) buffer for the kernel to fill in place.
                    return torch.empty(num_tokens, topk, dtype=dtype, device=device)

                topk_weights = _out_buffer(torch.float32)
                topk_ids = _out_buffer(torch.int32)
                # Third output of the kernel; this function discards it after the call.
                token_expert_indices = _out_buffer(torch.int32)

                ops.topk_softmax(
                    topk_weights,
                    topk_ids,
                    token_expert_indices,
                    gating_output.float(),
                )
                del token_expert_indices

                if not renormalize:
                    return topk_weights, topk_ids
                row_totals = topk_weights.sum(dim=-1, keepdim=True)
                return topk_weights / row_totals, topk_ids
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            # This is used by the Deepseek-V2 model
         | 
| 75 | 
            +
            def grouped_topk(
                hidden_states: torch.Tensor,
                gating_output: torch.Tensor,
                topk: int,
                renormalize: bool,
                num_expert_group: int = 0,
                topk_group: int = 0,
            ):
                """Group-limited top-k routing over softmax router scores.

                The experts are split into ``num_expert_group`` equal, contiguous groups.
                Each group is ranked by its single best expert score. Only the
                ``topk_group`` best groups stay eligible, and ``topk`` experts are then
                picked among the eligible ones.

                Args:
                    hidden_states: Token activations; only its first dimension is used,
                        for a token-count sanity check.
                    gating_output: Router logits with one row per token.
                    topk: Number of experts to keep for each token.
                    renormalize: If True, rescale the kept weights so each row sums to 1.
                    num_expert_group: Number of expert groups (must divide the expert count).
                    topk_group: Number of groups that remain eligible per token.

                Returns:
                    ``(topk_weights, topk_ids)`` as float32 / int32 tensors. The order
                    within a row is not guaranteed (``sorted=False``).
                """
                assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

                probs = torch.softmax(gating_output, dim=-1)
                n_tokens = probs.shape[0]
                experts_per_group = probs.shape[-1] // num_expert_group

                # Strength of each group = its best expert: [n_tokens, n_groups].
                best_in_group = probs.view(n_tokens, num_expert_group, -1).amax(dim=-1)
                kept_groups = best_in_group.topk(k=topk_group, dim=-1, sorted=False).indices

                # 1 for eligible groups, 0 otherwise; then widen to one entry per expert.
                group_keep = torch.zeros_like(best_in_group).scatter_(1, kept_groups, 1)
                expert_keep = group_keep.repeat_interleave(experts_per_group, dim=1)

                # Experts of dropped groups get 0.0, which is never above an eligible
                # (non-negative) softmax score.
                eligible = probs.masked_fill(expert_keep == 0, 0.0)
                topk_weights, topk_ids = eligible.topk(k=topk, dim=-1, sorted=False)

                if renormalize:
                    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

                return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
         | 
| 107 | 
            +
             | 
| 108 | 
            +
             | 
| 109 | 
            +
            def biased_grouped_topk(
                hidden_states: torch.Tensor,
                gating_output: torch.Tensor,
                correction_bias: torch.Tensor,
                topk: int,
                renormalize: bool,
                num_expert_group: int = 0,
                topk_group: int = 0,
            ):
                """Group-limited top-k routing on sigmoid scores with a per-expert selection bias.

                ``correction_bias`` shifts only the scores used to *choose* groups and
                experts. The returned weights are the unbiased sigmoid scores of the
                chosen experts.

                Args:
                    hidden_states: Token activations; only its first dimension is used,
                        for a token-count sanity check.
                    gating_output: Router logits with one row per token.
                    correction_bias: Per-expert bias added to the sigmoid scores for selection.
                    topk: Number of experts to keep for each token.
                    renormalize: If True, rescale the kept weights so each row sums to 1.
                    num_expert_group: Number of equal, contiguous expert groups. Each group
                        needs at least 2 experts, because groups are ranked by the sum of
                        their top-2 biased scores.
                    topk_group: Number of groups that remain eligible per token.

                Returns:
                    ``(topk_weights, topk_ids)`` as float32 / int32 tensors. The order
                    within a row is not guaranteed (``sorted=False``).
                """
                assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

                scores = gating_output.sigmoid()
                num_token = scores.shape[0]
                scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
                group_scores = (
                    scores_for_choice.view(num_token, num_expert_group, -1)
                    .topk(2, dim=-1)[0]
                    .sum(dim=-1)
                )  # [n, n_group]
                group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
                    1
                ]  # [n, top_k_group]
                group_mask = torch.zeros_like(group_scores)  # [n, n_group]
                group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
                score_mask = (
                    group_mask.unsqueeze(-1)
                    .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
                    .reshape(num_token, -1)
                )  # [n, e]
                # Mask with -inf, not 0.0: the biased scores can be negative, and a 0.0
                # fill would let experts from discarded groups outrank eligible ones.
                tmp_scores = scores_for_choice.masked_fill(
                    ~score_mask.bool(), float("-inf")
                )  # [n, e]
                _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
                # Weights come from the unbiased sigmoid scores, so -inf never leaks out.
                topk_weights = scores.gather(1, topk_ids)

                if renormalize:
                    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

                return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 148 | 
            +
            def select_experts(
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                top_k: int,
                use_grouped_topk: bool,
                renormalize: bool,
                topk_group: Optional[int] = None,
                num_expert_group: Optional[int] = None,
                custom_routing_function: Optional[Callable] = None,
                correction_bias: Optional[torch.Tensor] = None,
                torch_native: bool = False,
            ):
                """Route tokens to experts using the implementation chosen by the flags.

                Precedence of the branches:
                  1. ``use_grouped_topk``: ``grouped_topk``, or ``biased_grouped_topk``
                     when ``correction_bias`` is given. Both ``topk_group`` and
                     ``num_expert_group`` must be set.
                  2. ``torch_native``: ``fused_topk_native``.
                  3. No ``custom_routing_function``: ``fused_topk`` (vLLM kernel).
                  4. Otherwise: ``custom_routing_function``.

                Returns:
                    ``(topk_weights, topk_ids)`` as a tuple, from whichever branch ran.
                """
                # Arguments shared by every routing implementation.
                common = dict(
                    hidden_states=hidden_states,
                    gating_output=router_logits,
                    topk=top_k,
                    renormalize=renormalize,
                )

                # DeepSeek-V2 uses grouped top-k routing.
                if use_grouped_topk:
                    assert topk_group is not None
                    assert num_expert_group is not None
                    if correction_bias is None:
                        weights, ids = grouped_topk(
                            **common,
                            num_expert_group=num_expert_group,
                            topk_group=topk_group,
                        )
                    else:
                        weights, ids = biased_grouped_topk(
                            **common,
                            correction_bias=correction_bias,
                            num_expert_group=num_expert_group,
                            topk_group=topk_group,
                        )
                elif torch_native:
                    weights, ids = fused_topk_native(**common)
                elif custom_routing_function is not None:
                    weights, ids = custom_routing_function(**common)
                else:
                    weights, ids = fused_topk(**common)

                return weights, ids
         | 
| @@ -60,8 +60,8 @@ def fp8_get_quant_method(self, layer, prefix): | |
| 60 60 | 
             
                    is_layer_skipped,
         | 
| 61 61 | 
             
                )
         | 
| 62 62 |  | 
| 63 | 
            -
                from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
         | 
| 64 63 | 
             
                from sglang.srt.layers.linear import UnquantizedLinearMethod
         | 
| 64 | 
            +
                from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
         | 
| 65 65 | 
             
                from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
         | 
| 66 66 |  | 
| 67 67 | 
             
                if isinstance(layer, LinearBase):
         | 
| @@ -80,7 +80,7 @@ def gptq_get_quant_method(self, layer, prefix): | |
| 80 80 | 
             
                    GPTQMarlinMoEMethod,
         | 
| 81 81 | 
             
                )
         | 
| 82 82 |  | 
| 83 | 
            -
                from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
         | 
| 83 | 
            +
                from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
         | 
| 84 84 |  | 
| 85 85 | 
             
                if isinstance(layer, LinearBase):
         | 
| 86 86 | 
             
                    return GPTQMarlinLinearMethod(self)
         | 
| @@ -96,7 +96,7 @@ def awq_get_quant_method(self, layer, prefix): | |
| 96 96 | 
             
                    AWQMoEMethod,
         | 
| 97 97 | 
             
                )
         | 
| 98 98 |  | 
| 99 | 
            -
                from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
         | 
| 99 | 
            +
                from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
         | 
| 100 100 |  | 
| 101 101 | 
             
                if isinstance(layer, LinearBase):
         | 
| 102 102 | 
             
                    return AWQMarlinLinearMethod(self)
         | 
| @@ -9,6 +9,7 @@ import torch.nn.functional as F | |
| 9 9 | 
             
            from torch.nn import Module
         | 
| 10 10 | 
             
            from torch.nn.parameter import Parameter
         | 
| 11 11 | 
             
            from vllm import _custom_ops as ops
         | 
| 12 | 
            +
            from vllm.distributed import get_tensor_model_parallel_world_size
         | 
| 12 13 | 
             
            from vllm.model_executor.layers.linear import LinearBase
         | 
| 13 14 | 
             
            from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
         | 
| 14 15 | 
             
            from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
         | 
| @@ -26,13 +27,17 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( | |
| 26 27 | 
             
            )
         | 
| 27 28 | 
             
            from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
         | 
| 28 29 |  | 
| 29 | 
            -
            from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
         | 
| 30 30 | 
             
            from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
         | 
| 31 | 
            +
            from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
         | 
| 31 32 | 
             
            from sglang.srt.layers.quantization.base_config import (
         | 
| 32 33 | 
             
                QuantizationConfig,
         | 
| 33 34 | 
             
                QuantizeMethodBase,
         | 
| 34 35 | 
             
            )
         | 
| 35 | 
            -
            from sglang.srt.layers.quantization.fp8_utils import  | 
| 36 | 
            +
            from sglang.srt.layers.quantization.fp8_utils import (
         | 
| 37 | 
            +
                BlockQuantScaleParameter,
         | 
| 38 | 
            +
                apply_w8a8_block_fp8_linear,
         | 
| 39 | 
            +
                normalize_e4m3fn_to_e4m3fnuz,
         | 
| 40 | 
            +
            )
         | 
| 36 41 | 
             
            from sglang.srt.utils import (
         | 
| 37 42 | 
             
                get_bool_env_var,
         | 
| 38 43 | 
             
                is_hip,
         | 
| @@ -53,6 +58,7 @@ class Fp8Config(QuantizationConfig): | |
| 53 58 | 
             
                    is_checkpoint_fp8_serialized: bool = False,
         | 
| 54 59 | 
             
                    activation_scheme: str = "dynamic",
         | 
| 55 60 | 
             
                    ignored_layers: Optional[List[str]] = None,
         | 
| 61 | 
            +
                    weight_block_size: List[int] = None,
         | 
| 56 62 | 
             
                ) -> None:
         | 
| 57 63 | 
             
                    self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         | 
| 58 64 | 
             
                    if is_checkpoint_fp8_serialized:
         | 
| @@ -64,6 +70,20 @@ class Fp8Config(QuantizationConfig): | |
| 64 70 | 
             
                        raise ValueError(f"Unsupported activation scheme {activation_scheme}")
         | 
| 65 71 | 
             
                    self.activation_scheme = activation_scheme
         | 
| 66 72 | 
             
                    self.ignored_layers = ignored_layers or []
         | 
| 73 | 
            +
                    if weight_block_size is not None:
         | 
| 74 | 
            +
                        if not is_checkpoint_fp8_serialized:
         | 
| 75 | 
            +
                            raise ValueError(
         | 
| 76 | 
            +
                                f"The block-wise quantization only supports fp8-serialized checkpoint for now."
         | 
| 77 | 
            +
                            )
         | 
| 78 | 
            +
                        if len(weight_block_size) != 2:
         | 
| 79 | 
            +
                            raise ValueError(
         | 
| 80 | 
            +
                                f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
         | 
| 81 | 
            +
                            )
         | 
| 82 | 
            +
                        if activation_scheme != "dynamic":
         | 
| 83 | 
            +
                            raise ValueError(
         | 
| 84 | 
            +
                                f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
         | 
| 85 | 
            +
                            )
         | 
| 86 | 
            +
                    self.weight_block_size = weight_block_size
         | 
| 67 87 |  | 
| 68 88 | 
             
                @classmethod
         | 
| 69 89 | 
             
                def get_name(cls) -> str:
         | 
| @@ -87,10 +107,12 @@ class Fp8Config(QuantizationConfig): | |
| 87 107 | 
             
                    is_checkpoint_fp8_serialized = "fp8" in quant_method
         | 
| 88 108 | 
             
                    activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
         | 
| 89 109 | 
             
                    ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
         | 
| 110 | 
            +
                    weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
         | 
| 90 111 | 
             
                    return cls(
         | 
| 91 112 | 
             
                        is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
         | 
| 92 113 | 
             
                        activation_scheme=activation_scheme,
         | 
| 93 114 | 
             
                        ignored_layers=ignored_layers,
         | 
| 115 | 
            +
                        weight_block_size=weight_block_size,
         | 
| 94 116 | 
             
                    )
         | 
| 95 117 |  | 
| 96 118 | 
             
                def get_quant_method(
         | 
| @@ -98,7 +120,7 @@ class Fp8Config(QuantizationConfig): | |
| 98 120 | 
             
                ) -> Optional["QuantizeMethodBase"]:
         | 
| 99 121 | 
             
                    from vllm.attention.layer import Attention  # Avoid circular import
         | 
| 100 122 |  | 
| 101 | 
            -
                    from sglang.srt.layers.fused_moe_triton import FusedMoE
         | 
| 123 | 
            +
                    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         | 
| 102 124 |  | 
| 103 125 | 
             
                    if isinstance(layer, LinearBase):
         | 
| 104 126 | 
             
                        if is_layer_skipped(prefix, self.ignored_layers):
         | 
| @@ -143,6 +165,11 @@ class Fp8LinearMethod(LinearMethodBase): | |
| 143 165 | 
             
                    if is_hip():
         | 
| 144 166 | 
             
                        self.use_marlin = False
         | 
| 145 167 |  | 
| 168 | 
            +
                    self.block_quant = self.quant_config.weight_block_size is not None
         | 
| 169 | 
            +
                    if self.block_quant:
         | 
| 170 | 
            +
                        # Marlin doesn't support block-wise fp8
         | 
| 171 | 
            +
                        self.use_marlin = False
         | 
| 172 | 
            +
             | 
| 146 173 | 
             
                def create_weights(
         | 
| 147 174 | 
             
                    self,
         | 
| 148 175 | 
             
                    layer: torch.nn.Module,
         | 
| @@ -153,10 +180,35 @@ class Fp8LinearMethod(LinearMethodBase): | |
| 153 180 | 
             
                    params_dtype: torch.dtype,
         | 
| 154 181 | 
             
                    **extra_weight_attrs,
         | 
| 155 182 | 
             
                ):
         | 
| 156 | 
            -
                    del input_size, output_size
         | 
| 157 183 | 
             
                    output_size_per_partition = sum(output_partition_sizes)
         | 
| 158 184 | 
             
                    weight_loader = extra_weight_attrs.get("weight_loader")
         | 
| 159 185 |  | 
| 186 | 
            +
                    tp_size = get_tensor_model_parallel_world_size()
         | 
| 187 | 
            +
                    if self.block_quant:
         | 
| 188 | 
            +
                        block_n, block_k = (
         | 
| 189 | 
            +
                            self.quant_config.weight_block_size[0],
         | 
| 190 | 
            +
                            self.quant_config.weight_block_size[1],
         | 
| 191 | 
            +
                        )
         | 
| 192 | 
            +
                        # Required by row parallel
         | 
| 193 | 
            +
                        if tp_size > 1 and input_size // input_size_per_partition == tp_size:
         | 
| 194 | 
            +
                            if input_size_per_partition % block_k != 0:
         | 
| 195 | 
            +
                                raise ValueError(
         | 
| 196 | 
            +
                                    f"Weight input_size_per_partition = "
         | 
| 197 | 
            +
                                    f"{input_size_per_partition} is not divisible by "
         | 
| 198 | 
            +
                                    f"weight quantization block_k = {block_k}."
         | 
| 199 | 
            +
                                )
         | 
| 200 | 
            +
                # Required by column parallel or enabling merged weights
         | 
| 201 | 
            +
                        if (
         | 
| 202 | 
            +
                            tp_size > 1 and output_size // output_size_per_partition == tp_size
         | 
| 203 | 
            +
                        ) or len(output_partition_sizes) > 1:
         | 
| 204 | 
            +
                            for output_partition_size in output_partition_sizes:
         | 
| 205 | 
            +
                                if output_partition_size % block_n != 0:
         | 
| 206 | 
            +
                                    raise ValueError(
         | 
| 207 | 
            +
                                        f"Weight output_partition_size = "
         | 
| 208 | 
            +
                                        f"{output_partition_size} is not divisible by "
         | 
| 209 | 
            +
                                        f"weight quantization block_n = {block_n}."
         | 
| 210 | 
            +
                                    )
         | 
| 211 | 
            +
             | 
| 160 212 | 
             
                    layer.logical_widths = output_partition_sizes
         | 
| 161 213 |  | 
| 162 214 | 
             
                    layer.input_size_per_partition = input_size_per_partition
         | 
| @@ -184,13 +236,27 @@ class Fp8LinearMethod(LinearMethodBase): | |
| 184 236 | 
             
                    # Otherwise, wait until process_weights_after_loading.
         | 
| 185 237 | 
             
                    if self.quant_config.is_checkpoint_fp8_serialized:
         | 
| 186 238 | 
             
                        # WEIGHT SCALE
         | 
| 187 | 
            -
                         | 
| 188 | 
            -
                             | 
| 189 | 
            -
                             | 
| 190 | 
            -
             | 
| 191 | 
            -
             | 
| 192 | 
            -
             | 
| 193 | 
            -
             | 
| 239 | 
            +
                        if self.block_quant:
         | 
| 240 | 
            +
                            assert self.quant_config.activation_scheme == "dynamic"
         | 
| 241 | 
            +
                            scale = BlockQuantScaleParameter(
         | 
| 242 | 
            +
                                data=torch.empty(
         | 
| 243 | 
            +
                                    (output_size_per_partition + block_n - 1) // block_n,
         | 
| 244 | 
            +
                                    (input_size_per_partition + block_k - 1) // block_k,
         | 
| 245 | 
            +
                                    dtype=torch.float32,
         | 
| 246 | 
            +
                                ),
         | 
| 247 | 
            +
                                input_dim=1,
         | 
| 248 | 
            +
                                output_dim=0,
         | 
| 249 | 
            +
                                weight_loader=weight_loader,
         | 
| 250 | 
            +
                            )
         | 
| 251 | 
            +
                            scale[:] = torch.finfo(torch.float32).min
         | 
| 252 | 
            +
                            layer.register_parameter("weight_scale_inv", scale)
         | 
| 253 | 
            +
                        else:
         | 
| 254 | 
            +
                            scale = PerTensorScaleParameter(
         | 
| 255 | 
            +
                                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
         | 
| 256 | 
            +
                                weight_loader=weight_loader,
         | 
| 257 | 
            +
                            )
         | 
| 258 | 
            +
                            scale[:] = torch.finfo(torch.float32).min
         | 
| 259 | 
            +
                            layer.register_parameter("weight_scale", scale)
         | 
| 194 260 |  | 
| 195 261 | 
             
                        # INPUT ACTIVATION SCALE
         | 
| 196 262 | 
             
                        if self.quant_config.activation_scheme == "static":
         | 
| @@ -205,6 +271,9 @@ class Fp8LinearMethod(LinearMethodBase): | |
| 205 271 | 
             
                            layer.register_parameter("input_scale", None)
         | 
| 206 272 |  | 
| 207 273 | 
             
                def process_weights_after_loading(self, layer: Module) -> None:
         | 
| 274 | 
            +
                    # Block quant doesn't need to process weights after loading
         | 
| 275 | 
            +
                    if self.block_quant:
         | 
| 276 | 
            +
                        return
         | 
| 208 277 | 
             
                    layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         | 
| 209 278 | 
             
                    # If checkpoint not serialized fp8, quantize the weights.
         | 
| 210 279 | 
             
                    if not self.quant_config.is_checkpoint_fp8_serialized:
         | 
| @@ -295,6 +364,16 @@ class Fp8LinearMethod(LinearMethodBase): | |
| 295 364 | 
             
                            bias=bias,
         | 
| 296 365 | 
             
                        )
         | 
| 297 366 |  | 
| 367 | 
            +
                    if self.block_quant:
         | 
| 368 | 
            +
                        return apply_w8a8_block_fp8_linear(
         | 
| 369 | 
            +
                            input=x,
         | 
| 370 | 
            +
                            weight=layer.weight,
         | 
| 371 | 
            +
                            block_size=self.quant_config.weight_block_size,
         | 
| 372 | 
            +
                            weight_scale=layer.weight_scale_inv,
         | 
| 373 | 
            +
                            input_scale=layer.input_scale,
         | 
| 374 | 
            +
                            bias=bias,
         | 
| 375 | 
            +
                        )
         | 
| 376 | 
            +
             | 
| 298 377 | 
             
                    return apply_fp8_linear(
         | 
| 299 378 | 
             
                        input=x,
         | 
| 300 379 | 
             
                        weight=layer.weight,
         | 
| @@ -320,7 +399,7 @@ class Fp8MoEMethod: | |
| 320 399 | 
             
                """
         | 
| 321 400 |  | 
| 322 401 | 
             
                def __new__(cls, *args, **kwargs):
         | 
| 323 | 
            -
                    from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
         | 
| 402 | 
            +
                    from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
         | 
| 324 403 |  | 
| 325 404 | 
             
                    if not hasattr(cls, "_initialized"):
         | 
| 326 405 | 
             
                        original_init = cls.__init__
         | 
| @@ -339,6 +418,7 @@ class Fp8MoEMethod: | |
| 339 418 |  | 
| 340 419 | 
             
                def __init__(self, quant_config):
         | 
| 341 420 | 
             
                    self.quant_config = quant_config
         | 
| 421 | 
            +
                    self.block_quant = self.quant_config.weight_block_size is not None
         | 
| 342 422 |  | 
| 343 423 | 
             
                def create_weights(
         | 
| 344 424 | 
             
                    self,
         | 
| @@ -349,10 +429,32 @@ class Fp8MoEMethod: | |
| 349 429 | 
             
                    params_dtype: torch.dtype,
         | 
| 350 430 | 
             
                    **extra_weight_attrs,
         | 
| 351 431 | 
             
                ):
         | 
| 352 | 
            -
                    from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
         | 
| 432 | 
            +
                    from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
         | 
| 353 433 |  | 
| 354 434 | 
             
                    if self.quant_config.is_checkpoint_fp8_serialized:
         | 
| 355 435 | 
             
                        params_dtype = torch.float8_e4m3fn
         | 
| 436 | 
            +
                    tp_size = get_tensor_model_parallel_world_size()
         | 
| 437 | 
            +
                    if self.block_quant:
         | 
| 438 | 
            +
                        block_n, block_k = (
         | 
| 439 | 
            +
                            self.quant_config.weight_block_size[0],
         | 
| 440 | 
            +
                            self.quant_config.weight_block_size[1],
         | 
| 441 | 
            +
                        )
         | 
| 442 | 
            +
                        # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
         | 
| 443 | 
            +
                        # Required by collum parallel or enabling merged weights
         | 
| 444 | 
            +
                        if intermediate_size % block_n != 0:
         | 
| 445 | 
            +
                            raise ValueError(
         | 
| 446 | 
            +
                                f"The output_size of gate's and up's weight = "
         | 
| 447 | 
            +
                                f"{intermediate_size} is not divisible by "
         | 
| 448 | 
            +
                                f"weight quantization block_n = {block_n}."
         | 
| 449 | 
            +
                            )
         | 
| 450 | 
            +
                        if tp_size > 1:
         | 
| 451 | 
            +
                            # Required by row parallel
         | 
| 452 | 
            +
                            if intermediate_size % block_k != 0:
         | 
| 453 | 
            +
                                raise ValueError(
         | 
| 454 | 
            +
                                    f"The input_size of down's weight = "
         | 
| 455 | 
            +
                                    f"{intermediate_size} is not divisible by "
         | 
| 456 | 
            +
                                    f"weight quantization block_k = {block_k}."
         | 
| 457 | 
            +
                                )
         | 
| 356 458 |  | 
| 357 459 | 
             
                    # WEIGHTS
         | 
| 358 460 | 
             
                    w13_weight = torch.nn.Parameter(
         | 
| @@ -374,21 +476,45 @@ class Fp8MoEMethod: | |
| 374 476 | 
             
                    set_weight_attrs(w2_weight, extra_weight_attrs)
         | 
| 375 477 |  | 
| 376 478 | 
             
                    # WEIGHT_SCALES
         | 
| 377 | 
            -
                     | 
| 378 | 
            -
             | 
| 379 | 
            -
             | 
| 380 | 
            -
             | 
| 381 | 
            -
             | 
| 382 | 
            -
             | 
| 383 | 
            -
             | 
| 384 | 
            -
             | 
| 385 | 
            -
             | 
| 386 | 
            -
             | 
| 387 | 
            -
             | 
| 479 | 
            +
                    if self.block_quant:
         | 
| 480 | 
            +
                        w13_weight_scale = torch.nn.Parameter(
         | 
| 481 | 
            +
                            torch.ones(
         | 
| 482 | 
            +
                                num_experts,
         | 
| 483 | 
            +
                                2 * ((intermediate_size + block_n - 1) // block_n),
         | 
| 484 | 
            +
                                (hidden_size + block_k - 1) // block_k,
         | 
| 485 | 
            +
                                dtype=torch.float32,
         | 
| 486 | 
            +
                            ),
         | 
| 487 | 
            +
                            requires_grad=False,
         | 
| 488 | 
            +
                        )
         | 
| 489 | 
            +
                        w2_weight_scale = torch.nn.Parameter(
         | 
| 490 | 
            +
                            torch.ones(
         | 
| 491 | 
            +
                                num_experts,
         | 
| 492 | 
            +
                                (hidden_size + block_n - 1) // block_n,
         | 
| 493 | 
            +
                                (intermediate_size + block_k - 1) // block_k,
         | 
| 494 | 
            +
                                dtype=torch.float32,
         | 
| 495 | 
            +
                            ),
         | 
| 496 | 
            +
                            requires_grad=False,
         | 
| 497 | 
            +
                        )
         | 
| 498 | 
            +
                        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
         | 
| 499 | 
            +
                        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
         | 
| 500 | 
            +
                        assert self.quant_config.activation_scheme == "dynamic"
         | 
| 501 | 
            +
                    else:
         | 
| 502 | 
            +
                        # Allocate 2 scales for w1 and w3 respectively.
         | 
| 503 | 
            +
                        # They will be combined to a single scale after weight loading.
         | 
| 504 | 
            +
                        w13_weight_scale = torch.nn.Parameter(
         | 
| 505 | 
            +
                            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
         | 
| 506 | 
            +
                        )
         | 
| 507 | 
            +
                        w2_weight_scale = torch.nn.Parameter(
         | 
| 508 | 
            +
                            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
         | 
| 509 | 
            +
                        )
         | 
| 510 | 
            +
                        layer.register_parameter("w13_weight_scale", w13_weight_scale)
         | 
| 511 | 
            +
                        layer.register_parameter("w2_weight_scale", w2_weight_scale)
         | 
| 388 512 | 
             
                    # Add the quantization method used (per tensor/grouped/channel)
         | 
| 389 513 | 
             
                    # to ensure the weight scales are loaded in properly
         | 
| 390 514 | 
             
                    extra_weight_attrs.update(
         | 
| 391 | 
            -
                        {"quant_method": FusedMoeWeightScaleSupported. | 
| 515 | 
            +
                        {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
         | 
| 516 | 
            +
                        if self.block_quant
         | 
| 517 | 
            +
                        else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
         | 
| 392 518 | 
             
                    )
         | 
| 393 519 | 
             
                    # If loading fp8 checkpoint, pass the weight loaders.
         | 
| 394 520 | 
             
                    # If loading an fp16 checkpoint, do not (we will quantize in
         | 
| @@ -422,7 +548,9 @@ class Fp8MoEMethod: | |
| 422 548 | 
             
                        layer.w2_input_scale = None
         | 
| 423 549 |  | 
| 424 550 | 
             
                def process_weights_after_loading(self, layer: Module) -> None:
         | 
| 425 | 
            -
             | 
| 551 | 
            +
                    # Block quant doesn't need to process weights after loading
         | 
| 552 | 
            +
                    if self.block_quant:
         | 
| 553 | 
            +
                        return
         | 
| 426 554 | 
             
                    # If checkpoint is fp16 or bfloat16, quantize in place.
         | 
| 427 555 | 
             
                    if not self.quant_config.is_checkpoint_fp8_serialized:
         | 
| 428 556 | 
             
                        # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
         | 
| @@ -519,7 +647,6 @@ class Fp8MoEMethod: | |
| 519 647 | 
             
                                layer.w2_input_scale = torch.nn.Parameter(
         | 
| 520 648 | 
             
                                    w2_input_scale, requires_grad=False
         | 
| 521 649 | 
             
                                )
         | 
| 522 | 
            -
             | 
| 523 650 | 
             
                        # Fp8 moe kernel needs single weight scale for w13 per expert.
         | 
| 524 651 | 
             
                        # We take the max then dequant and requant each expert.
         | 
| 525 652 | 
             
                        assert layer.w13_weight_scale is not None
         | 
| @@ -566,12 +693,14 @@ class Fp8MoEMethod: | |
| 566 693 | 
             
                    topk_group: Optional[int] = None,
         | 
| 567 694 | 
             
                    num_expert_group: Optional[int] = None,
         | 
| 568 695 | 
             
                    custom_routing_function: Optional[Callable] = None,
         | 
| 696 | 
            +
                    correction_bias: Optional[torch.Tensor] = None,
         | 
| 569 697 | 
             
                ) -> torch.Tensor:
         | 
| 570 | 
            -
                    from sglang.srt.layers.fused_moe_triton import FusedMoE
         | 
| 571 | 
            -
                    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
         | 
| 698 | 
            +
                    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         | 
| 699 | 
            +
                    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         | 
| 700 | 
            +
                    from sglang.srt.layers.moe.topk import select_experts
         | 
| 572 701 |  | 
| 573 702 | 
             
                    # Expert selection
         | 
| 574 | 
            -
                    topk_weights, topk_ids =  | 
| 703 | 
            +
                    topk_weights, topk_ids = select_experts(
         | 
| 575 704 | 
             
                        hidden_states=x,
         | 
| 576 705 | 
             
                        router_logits=router_logits,
         | 
| 577 706 | 
             
                        use_grouped_topk=use_grouped_topk,
         | 
| @@ -580,6 +709,7 @@ class Fp8MoEMethod: | |
| 580 709 | 
             
                        topk_group=topk_group,
         | 
| 581 710 | 
             
                        num_expert_group=num_expert_group,
         | 
| 582 711 | 
             
                        custom_routing_function=custom_routing_function,
         | 
| 712 | 
            +
                        correction_bias=correction_bias,
         | 
| 583 713 | 
             
                    )
         | 
| 584 714 |  | 
| 585 715 | 
             
                    # Expert fusion with FP8 quantization
         | 
| @@ -591,10 +721,17 @@ class Fp8MoEMethod: | |
| 591 721 | 
             
                        topk_ids=topk_ids,
         | 
| 592 722 | 
             
                        inplace=True,
         | 
| 593 723 | 
             
                        use_fp8_w8a8=True,
         | 
| 594 | 
            -
                        w1_scale= | 
| 595 | 
            -
             | 
| 724 | 
            +
                        w1_scale=(
         | 
| 725 | 
            +
                            layer.w13_weight_scale_inv
         | 
| 726 | 
            +
                            if self.block_quant
         | 
| 727 | 
            +
                            else layer.w13_weight_scale
         | 
| 728 | 
            +
                        ),
         | 
| 729 | 
            +
                        w2_scale=(
         | 
| 730 | 
            +
                            layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
         | 
| 731 | 
            +
                        ),
         | 
| 596 732 | 
             
                        a1_scale=layer.w13_input_scale,
         | 
| 597 733 | 
             
                        a2_scale=layer.w2_input_scale,
         | 
| 734 | 
            +
                        block_shape=self.quant_config.weight_block_size,
         | 
| 598 735 | 
             
                    )
         | 
| 599 736 |  | 
| 600 737 |  |