sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +1 -0
- sglang/bench_serving.py +9 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +51 -5
- sglang/srt/layers/attention/triton_backend.py +16 -25
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +5 -50
- sglang/srt/layers/quantization/fp8.py +221 -36
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/radix_attention.py +8 -1
- sglang/srt/layers/sampler.py +27 -5
- sglang/srt/layers/torchao_utils.py +31 -0
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +54 -34
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +171 -136
- sglang/srt/managers/tokenizer_manager.py +184 -133
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -11
- sglang/srt/model_executor/model_runner.py +28 -14
- sglang/srt/model_parallel.py +66 -5
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +73 -9
- sglang/srt/models/llama.py +22 -0
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +8 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/server.py +2 -1
- sglang/srt/server_args.py +19 -9
- sglang/srt/utils.py +40 -54
- sglang/test/test_block_fp8.py +341 -0
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
| @@ -1,12 +1,15 @@ | |
| 1 1 | 
             
            # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
         | 
| 2 2 |  | 
| 3 3 | 
             
            import logging
         | 
| 4 | 
            +
            import os
         | 
| 4 5 | 
             
            from typing import Any, Callable, Dict, List, Optional
         | 
| 5 6 |  | 
| 6 7 | 
             
            import torch
         | 
| 8 | 
            +
            import torch.nn.functional as F
         | 
| 7 9 | 
             
            from torch.nn import Module
         | 
| 8 10 | 
             
            from torch.nn.parameter import Parameter
         | 
| 9 11 | 
             
            from vllm import _custom_ops as ops
         | 
| 12 | 
            +
            from vllm.distributed import get_tensor_model_parallel_world_size
         | 
| 10 13 | 
             
            from vllm.model_executor.layers.linear import LinearBase
         | 
| 11 14 | 
             
            from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
         | 
| 12 15 | 
             
            from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
         | 
| @@ -24,17 +27,17 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( | |
| 24 27 | 
             
            )
         | 
| 25 28 | 
             
            from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
         | 
| 26 29 |  | 
| 27 | 
            -
            from sglang.srt.layers.fused_moe_triton import (
         | 
| 28 | 
            -
                FusedMoE,
         | 
| 29 | 
            -
                FusedMoEMethodBase,
         | 
| 30 | 
            -
                FusedMoeWeightScaleSupported,
         | 
| 31 | 
            -
            )
         | 
| 32 30 | 
             
            from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
         | 
| 31 | 
            +
            from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
         | 
| 33 32 | 
             
            from sglang.srt.layers.quantization.base_config import (
         | 
| 34 33 | 
             
                QuantizationConfig,
         | 
| 35 34 | 
             
                QuantizeMethodBase,
         | 
| 36 35 | 
             
            )
         | 
| 37 | 
            -
            from sglang.srt.layers.quantization.fp8_utils import  | 
| 36 | 
            +
            from sglang.srt.layers.quantization.fp8_utils import (
         | 
| 37 | 
            +
                BlockQuantScaleParameter,
         | 
| 38 | 
            +
                apply_w8a8_block_fp8_linear,
         | 
| 39 | 
            +
                normalize_e4m3fn_to_e4m3fnuz,
         | 
| 40 | 
            +
            )
         | 
| 38 41 | 
             
            from sglang.srt.utils import (
         | 
| 39 42 | 
             
                get_bool_env_var,
         | 
| 40 43 | 
             
                is_hip,
         | 
| @@ -55,6 +58,7 @@ class Fp8Config(QuantizationConfig): | |
| 55 58 | 
             
                    is_checkpoint_fp8_serialized: bool = False,
         | 
| 56 59 | 
             
                    activation_scheme: str = "dynamic",
         | 
| 57 60 | 
             
                    ignored_layers: Optional[List[str]] = None,
         | 
| 61 | 
            +
                    weight_block_size: List[int] = None,
         | 
| 58 62 | 
             
                ) -> None:
         | 
| 59 63 | 
             
                    self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         | 
| 60 64 | 
             
                    if is_checkpoint_fp8_serialized:
         | 
| @@ -66,6 +70,20 @@ class Fp8Config(QuantizationConfig): | |
| 66 70 | 
             
                        raise ValueError(f"Unsupported activation scheme {activation_scheme}")
         | 
| 67 71 | 
             
                    self.activation_scheme = activation_scheme
         | 
| 68 72 | 
             
                    self.ignored_layers = ignored_layers or []
         | 
| 73 | 
            +
                    if weight_block_size is not None:
         | 
| 74 | 
            +
                        if not is_checkpoint_fp8_serialized:
         | 
| 75 | 
            +
                            raise ValueError(
         | 
| 76 | 
            +
                                f"The block-wise quantization only supports fp8-serialized checkpoint for now."
         | 
| 77 | 
            +
                            )
         | 
| 78 | 
            +
                        if len(weight_block_size) != 2:
         | 
| 79 | 
            +
                            raise ValueError(
         | 
| 80 | 
            +
                                f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
         | 
| 81 | 
            +
                            )
         | 
| 82 | 
            +
                        if activation_scheme != "dynamic":
         | 
| 83 | 
            +
                            raise ValueError(
         | 
| 84 | 
            +
                                f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
         | 
| 85 | 
            +
                            )
         | 
| 86 | 
            +
                    self.weight_block_size = weight_block_size
         | 
| 69 87 |  | 
| 70 88 | 
             
                @classmethod
         | 
| 71 89 | 
             
                def get_name(cls) -> str:
         | 
| @@ -89,10 +107,12 @@ class Fp8Config(QuantizationConfig): | |
| 89 107 | 
             
                    is_checkpoint_fp8_serialized = "fp8" in quant_method
         | 
| 90 108 | 
             
                    activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
         | 
| 91 109 | 
             
                    ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
         | 
| 110 | 
            +
                    weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
         | 
| 92 111 | 
             
                    return cls(
         | 
| 93 112 | 
             
                        is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
         | 
| 94 113 | 
             
                        activation_scheme=activation_scheme,
         | 
| 95 114 | 
             
                        ignored_layers=ignored_layers,
         | 
| 115 | 
            +
                        weight_block_size=weight_block_size,
         | 
| 96 116 | 
             
                    )
         | 
| 97 117 |  | 
| 98 118 | 
             
                def get_quant_method(
         | 
| @@ -100,6 +120,8 @@ class Fp8Config(QuantizationConfig): | |
| 100 120 | 
             
                ) -> Optional["QuantizeMethodBase"]:
         | 
| 101 121 | 
             
                    from vllm.attention.layer import Attention  # Avoid circular import
         | 
| 102 122 |  | 
| 123 | 
            +
                    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         | 
| 124 | 
            +
             | 
| 103 125 | 
             
                    if isinstance(layer, LinearBase):
         | 
| 104 126 | 
             
                        if is_layer_skipped(prefix, self.ignored_layers):
         | 
| 105 127 | 
             
                            return UnquantizedLinearMethod()
         | 
| @@ -143,6 +165,11 @@ class Fp8LinearMethod(LinearMethodBase): | |
| 143 165 | 
             
                    if is_hip():
         | 
| 144 166 | 
             
                        self.use_marlin = False
         | 
| 145 167 |  | 
| 168 | 
            +
                    self.block_quant = self.quant_config.weight_block_size is not None
         | 
| 169 | 
            +
                    if self.block_quant:
         | 
| 170 | 
            +
                        # Marlin doesn't support block-wise fp8
         | 
| 171 | 
            +
                        self.use_marlin = False
         | 
| 172 | 
            +
             | 
| 146 173 | 
             
                def create_weights(
         | 
| 147 174 | 
             
                    self,
         | 
| 148 175 | 
             
                    layer: torch.nn.Module,
         | 
| @@ -153,10 +180,35 @@ class Fp8LinearMethod(LinearMethodBase): | |
| 153 180 | 
             
                    params_dtype: torch.dtype,
         | 
| 154 181 | 
             
                    **extra_weight_attrs,
         | 
| 155 182 | 
             
                ):
         | 
| 156 | 
            -
                    del input_size, output_size
         | 
| 157 183 | 
             
                    output_size_per_partition = sum(output_partition_sizes)
         | 
| 158 184 | 
             
                    weight_loader = extra_weight_attrs.get("weight_loader")
         | 
| 159 185 |  | 
| 186 | 
            +
                    tp_size = get_tensor_model_parallel_world_size()
         | 
| 187 | 
            +
                    if self.block_quant:
         | 
| 188 | 
            +
                        block_n, block_k = (
         | 
| 189 | 
            +
                            self.quant_config.weight_block_size[0],
         | 
| 190 | 
            +
                            self.quant_config.weight_block_size[1],
         | 
| 191 | 
            +
                        )
         | 
| 192 | 
            +
                        # Required by row parallel
         | 
| 193 | 
            +
                        if tp_size > 1 and input_size // input_size_per_partition == tp_size:
         | 
| 194 | 
            +
                            if input_size_per_partition % block_k != 0:
         | 
| 195 | 
            +
                                raise ValueError(
         | 
| 196 | 
            +
                                    f"Weight input_size_per_partition = "
         | 
| 197 | 
            +
                                    f"{input_size_per_partition} is not divisible by "
         | 
| 198 | 
            +
                                    f"weight quantization block_k = {block_k}."
         | 
| 199 | 
            +
                                )
         | 
| 200 | 
            +
                        # Required by column parallel or enabling merged weights
         | 
| 201 | 
            +
                        if (
         | 
| 202 | 
            +
                            tp_size > 1 and output_size // output_size_per_partition == tp_size
         | 
| 203 | 
            +
                        ) or len(output_partition_sizes) > 1:
         | 
| 204 | 
            +
                            for output_partition_size in output_partition_sizes:
         | 
| 205 | 
            +
                                if output_partition_size % block_n != 0:
         | 
| 206 | 
            +
                                    raise ValueError(
         | 
| 207 | 
            +
                                        f"Weight output_partition_size = "
         | 
| 208 | 
            +
                                        f"{output_partition_size} is not divisible by "
         | 
| 209 | 
            +
                                        f"weight quantization block_n = {block_n}."
         | 
| 210 | 
            +
                                    )
         | 
| 211 | 
            +
             | 
| 160 212 | 
             
                    layer.logical_widths = output_partition_sizes
         | 
| 161 213 |  | 
| 162 214 | 
             
                    layer.input_size_per_partition = input_size_per_partition
         | 
| @@ -184,13 +236,27 @@ class Fp8LinearMethod(LinearMethodBase): | |
| 184 236 | 
             
                    # Otherwise, wait until process_weights_after_loading.
         | 
| 185 237 | 
             
                    if self.quant_config.is_checkpoint_fp8_serialized:
         | 
| 186 238 | 
             
                        # WEIGHT SCALE
         | 
| 187 | 
            -
                         | 
| 188 | 
            -
                             | 
| 189 | 
            -
                             | 
| 190 | 
            -
             | 
| 191 | 
            -
             | 
| 192 | 
            -
             | 
| 193 | 
            -
             | 
| 239 | 
            +
                        if self.block_quant:
         | 
| 240 | 
            +
                            assert self.quant_config.activation_scheme == "dynamic"
         | 
| 241 | 
            +
                            scale = BlockQuantScaleParameter(
         | 
| 242 | 
            +
                                data=torch.empty(
         | 
| 243 | 
            +
                                    (output_size_per_partition + block_n - 1) // block_n,
         | 
| 244 | 
            +
                                    (input_size_per_partition + block_k - 1) // block_k,
         | 
| 245 | 
            +
                                    dtype=torch.float32,
         | 
| 246 | 
            +
                                ),
         | 
| 247 | 
            +
                                input_dim=1,
         | 
| 248 | 
            +
                                output_dim=0,
         | 
| 249 | 
            +
                                weight_loader=weight_loader,
         | 
| 250 | 
            +
                            )
         | 
| 251 | 
            +
                            scale[:] = torch.finfo(torch.float32).min
         | 
| 252 | 
            +
                            layer.register_parameter("weight_scale_inv", scale)
         | 
| 253 | 
            +
                        else:
         | 
| 254 | 
            +
                            scale = PerTensorScaleParameter(
         | 
| 255 | 
            +
                                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
         | 
| 256 | 
            +
                                weight_loader=weight_loader,
         | 
| 257 | 
            +
                            )
         | 
| 258 | 
            +
                            scale[:] = torch.finfo(torch.float32).min
         | 
| 259 | 
            +
                            layer.register_parameter("weight_scale", scale)
         | 
| 194 260 |  | 
| 195 261 | 
             
                        # INPUT ACTIVATION SCALE
         | 
| 196 262 | 
             
                        if self.quant_config.activation_scheme == "static":
         | 
| @@ -205,6 +271,9 @@ class Fp8LinearMethod(LinearMethodBase): | |
| 205 271 | 
             
                            layer.register_parameter("input_scale", None)
         | 
| 206 272 |  | 
| 207 273 | 
             
                def process_weights_after_loading(self, layer: Module) -> None:
         | 
| 274 | 
            +
                    # Block quant doesn't need to process weights after loading
         | 
| 275 | 
            +
                    if self.block_quant:
         | 
| 276 | 
            +
                        return
         | 
| 208 277 | 
             
                    layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         | 
| 209 278 | 
             
                    # If checkpoint not serialized fp8, quantize the weights.
         | 
| 210 279 | 
             
                    if not self.quant_config.is_checkpoint_fp8_serialized:
         | 
| @@ -295,6 +364,16 @@ class Fp8LinearMethod(LinearMethodBase): | |
| 295 364 | 
             
                            bias=bias,
         | 
| 296 365 | 
             
                        )
         | 
| 297 366 |  | 
| 367 | 
            +
                    if self.block_quant:
         | 
| 368 | 
            +
                        return apply_w8a8_block_fp8_linear(
         | 
| 369 | 
            +
                            input=x,
         | 
| 370 | 
            +
                            weight=layer.weight,
         | 
| 371 | 
            +
                            block_size=self.quant_config.weight_block_size,
         | 
| 372 | 
            +
                            weight_scale=layer.weight_scale_inv,
         | 
| 373 | 
            +
                            input_scale=layer.input_scale,
         | 
| 374 | 
            +
                            bias=bias,
         | 
| 375 | 
            +
                        )
         | 
| 376 | 
            +
             | 
| 298 377 | 
             
                    return apply_fp8_linear(
         | 
| 299 378 | 
             
                        input=x,
         | 
| 300 379 | 
             
                        weight=layer.weight,
         | 
| @@ -306,7 +385,7 @@ class Fp8LinearMethod(LinearMethodBase): | |
| 306 385 | 
             
                    )
         | 
| 307 386 |  | 
| 308 387 |  | 
| 309 | 
            -
            class Fp8MoEMethod | 
| 388 | 
            +
            class Fp8MoEMethod:
         | 
| 310 389 | 
             
                """MoE method for FP8.
         | 
| 311 390 | 
             
                Supports loading FP8 checkpoints with static weight scale and
         | 
| 312 391 | 
             
                dynamic/static activation scale.
         | 
| @@ -319,8 +398,27 @@ class Fp8MoEMethod(FusedMoEMethodBase): | |
| 319 398 | 
             
                    quant_config: The quantization config.
         | 
| 320 399 | 
             
                """
         | 
| 321 400 |  | 
| 322 | 
            -
                def  | 
| 401 | 
            +
                def __new__(cls, *args, **kwargs):
         | 
| 402 | 
            +
                    from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                    if not hasattr(cls, "_initialized"):
         | 
| 405 | 
            +
                        original_init = cls.__init__
         | 
| 406 | 
            +
                        new_cls = type(
         | 
| 407 | 
            +
                            cls.__name__,
         | 
| 408 | 
            +
                            (FusedMoEMethodBase,),
         | 
| 409 | 
            +
                            {
         | 
| 410 | 
            +
                                "__init__": original_init,
         | 
| 411 | 
            +
                                **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
         | 
| 412 | 
            +
                            },
         | 
| 413 | 
            +
                        )
         | 
| 414 | 
            +
                        obj = super(new_cls, new_cls).__new__(new_cls)
         | 
| 415 | 
            +
                        obj.__init__(*args, **kwargs)
         | 
| 416 | 
            +
                        return obj
         | 
| 417 | 
            +
                    return super().__new__(cls)
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                def __init__(self, quant_config):
         | 
| 323 420 | 
             
                    self.quant_config = quant_config
         | 
| 421 | 
            +
                    self.block_quant = self.quant_config.weight_block_size is not None
         | 
| 324 422 |  | 
| 325 423 | 
             
                def create_weights(
         | 
| 326 424 | 
             
                    self,
         | 
| @@ -331,9 +429,32 @@ class Fp8MoEMethod(FusedMoEMethodBase): | |
| 331 429 | 
             
                    params_dtype: torch.dtype,
         | 
| 332 430 | 
             
                    **extra_weight_attrs,
         | 
| 333 431 | 
             
                ):
         | 
| 432 | 
            +
                    from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
         | 
| 334 433 |  | 
| 335 434 | 
             
                    if self.quant_config.is_checkpoint_fp8_serialized:
         | 
| 336 435 | 
             
                        params_dtype = torch.float8_e4m3fn
         | 
| 436 | 
            +
                    tp_size = get_tensor_model_parallel_world_size()
         | 
| 437 | 
            +
                    if self.block_quant:
         | 
| 438 | 
            +
                        block_n, block_k = (
         | 
| 439 | 
            +
                            self.quant_config.weight_block_size[0],
         | 
| 440 | 
            +
                            self.quant_config.weight_block_size[1],
         | 
| 441 | 
            +
                        )
         | 
| 442 | 
            +
                        # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
         | 
| 443 | 
            +
                        # Required by column parallel or enabling merged weights
         | 
| 444 | 
            +
                        if intermediate_size % block_n != 0:
         | 
| 445 | 
            +
                            raise ValueError(
         | 
| 446 | 
            +
                                f"The output_size of gate's and up's weight = "
         | 
| 447 | 
            +
                                f"{intermediate_size} is not divisible by "
         | 
| 448 | 
            +
                                f"weight quantization block_n = {block_n}."
         | 
| 449 | 
            +
                            )
         | 
| 450 | 
            +
                        if tp_size > 1:
         | 
| 451 | 
            +
                            # Required by row parallel
         | 
| 452 | 
            +
                            if intermediate_size % block_k != 0:
         | 
| 453 | 
            +
                                raise ValueError(
         | 
| 454 | 
            +
                                    f"The input_size of down's weight = "
         | 
| 455 | 
            +
                                    f"{intermediate_size} is not divisible by "
         | 
| 456 | 
            +
                                    f"weight quantization block_k = {block_k}."
         | 
| 457 | 
            +
                                )
         | 
| 337 458 |  | 
| 338 459 | 
             
                    # WEIGHTS
         | 
| 339 460 | 
             
                    w13_weight = torch.nn.Parameter(
         | 
| @@ -355,21 +476,45 @@ class Fp8MoEMethod(FusedMoEMethodBase): | |
| 355 476 | 
             
                    set_weight_attrs(w2_weight, extra_weight_attrs)
         | 
| 356 477 |  | 
| 357 478 | 
             
                    # WEIGHT_SCALES
         | 
| 358 | 
            -
                     | 
| 359 | 
            -
             | 
| 360 | 
            -
             | 
| 361 | 
            -
             | 
| 362 | 
            -
             | 
| 363 | 
            -
             | 
| 364 | 
            -
             | 
| 365 | 
            -
             | 
| 366 | 
            -
             | 
| 367 | 
            -
             | 
| 368 | 
            -
             | 
| 479 | 
            +
                    if self.block_quant:
         | 
| 480 | 
            +
                        w13_weight_scale = torch.nn.Parameter(
         | 
| 481 | 
            +
                            torch.ones(
         | 
| 482 | 
            +
                                num_experts,
         | 
| 483 | 
            +
                                2 * ((intermediate_size + block_n - 1) // block_n),
         | 
| 484 | 
            +
                                (hidden_size + block_k - 1) // block_k,
         | 
| 485 | 
            +
                                dtype=torch.float32,
         | 
| 486 | 
            +
                            ),
         | 
| 487 | 
            +
                            requires_grad=False,
         | 
| 488 | 
            +
                        )
         | 
| 489 | 
            +
                        w2_weight_scale = torch.nn.Parameter(
         | 
| 490 | 
            +
                            torch.ones(
         | 
| 491 | 
            +
                                num_experts,
         | 
| 492 | 
            +
                                (hidden_size + block_n - 1) // block_n,
         | 
| 493 | 
            +
                                (intermediate_size + block_k - 1) // block_k,
         | 
| 494 | 
            +
                                dtype=torch.float32,
         | 
| 495 | 
            +
                            ),
         | 
| 496 | 
            +
                            requires_grad=False,
         | 
| 497 | 
            +
                        )
         | 
| 498 | 
            +
                        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
         | 
| 499 | 
            +
                        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
         | 
| 500 | 
            +
                        assert self.quant_config.activation_scheme == "dynamic"
         | 
| 501 | 
            +
                    else:
         | 
| 502 | 
            +
                        # Allocate 2 scales for w1 and w3 respectively.
         | 
| 503 | 
            +
                        # They will be combined to a single scale after weight loading.
         | 
| 504 | 
            +
                        w13_weight_scale = torch.nn.Parameter(
         | 
| 505 | 
            +
                            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
         | 
| 506 | 
            +
                        )
         | 
| 507 | 
            +
                        w2_weight_scale = torch.nn.Parameter(
         | 
| 508 | 
            +
                            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
         | 
| 509 | 
            +
                        )
         | 
| 510 | 
            +
                        layer.register_parameter("w13_weight_scale", w13_weight_scale)
         | 
| 511 | 
            +
                        layer.register_parameter("w2_weight_scale", w2_weight_scale)
         | 
| 369 512 | 
             
                    # Add the quantization method used (per tensor/grouped/channel)
         | 
| 370 513 | 
             
                    # to ensure the weight scales are loaded in properly
         | 
| 371 514 | 
             
                    extra_weight_attrs.update(
         | 
| 372 | 
            -
                        {"quant_method": FusedMoeWeightScaleSupported. | 
| 515 | 
            +
                        {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
         | 
| 516 | 
            +
                        if self.block_quant
         | 
| 517 | 
            +
                        else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
         | 
| 373 518 | 
             
                    )
         | 
| 374 519 | 
             
                    # If loading fp8 checkpoint, pass the weight loaders.
         | 
| 375 520 | 
             
                    # If loading an fp16 checkpoint, do not (we will quantize in
         | 
| @@ -403,8 +548,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): | |
| 403 548 | 
             
                        layer.w2_input_scale = None
         | 
| 404 549 |  | 
| 405 550 | 
             
                def process_weights_after_loading(self, layer: Module) -> None:
         | 
| 406 | 
            -
             | 
| 407 | 
            -
                     | 
| 551 | 
            +
                    # Block quant doesn't need to process weights after loading
         | 
| 552 | 
            +
                    if self.block_quant:
         | 
| 553 | 
            +
                        return
         | 
| 554 | 
            +
                    # If checkpoint is fp16 or bfloat16, quantize in place.
         | 
| 408 555 | 
             
                    if not self.quant_config.is_checkpoint_fp8_serialized:
         | 
| 409 556 | 
             
                        # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
         | 
| 410 557 | 
             
                        fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
         | 
| @@ -428,6 +575,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): | |
| 428 575 | 
             
                            )
         | 
| 429 576 | 
             
                        layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
         | 
| 430 577 | 
             
                        layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
         | 
| 578 | 
            +
             | 
| 579 | 
            +
                        # If ROCm, apply weight padding (min. Mem channel contention) only if set
         | 
| 580 | 
            +
                        if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
         | 
| 581 | 
            +
                            layer.w13_weight = torch.nn.Parameter(
         | 
| 582 | 
            +
                                F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
         | 
| 583 | 
            +
                                requires_grad=False,
         | 
| 584 | 
            +
                            )
         | 
| 585 | 
            +
                            torch.cuda.empty_cache()
         | 
| 586 | 
            +
                            layer.w2_weight = torch.nn.Parameter(
         | 
| 587 | 
            +
                                F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
         | 
| 588 | 
            +
                                requires_grad=False,
         | 
| 589 | 
            +
                            )
         | 
| 590 | 
            +
                            torch.cuda.empty_cache()
         | 
| 431 591 | 
             
                        return
         | 
| 432 592 |  | 
| 433 593 | 
             
                    # If checkpoint is fp8, we need to handle that the
         | 
| @@ -456,6 +616,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): | |
| 456 616 | 
             
                            layer.w2_input_scale = torch.nn.Parameter(
         | 
| 457 617 | 
             
                                layer.w2_input_scale.max(), requires_grad=False
         | 
| 458 618 | 
             
                            )
         | 
| 619 | 
            +
             | 
| 459 620 | 
             
                        # If ROCm, normalize the weights and scales to e4m3fnuz
         | 
| 460 621 | 
             
                        if is_hip():
         | 
| 461 622 | 
             
                            # Normalize the weights and scales
         | 
| @@ -486,7 +647,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): | |
| 486 647 | 
             
                                layer.w2_input_scale = torch.nn.Parameter(
         | 
| 487 648 | 
             
                                    w2_input_scale, requires_grad=False
         | 
| 488 649 | 
             
                                )
         | 
| 489 | 
            -
             | 
| 490 650 | 
             
                        # Fp8 moe kernel needs single weight scale for w13 per expert.
         | 
| 491 651 | 
             
                        # We take the max then dequant and requant each expert.
         | 
| 492 652 | 
             
                        assert layer.w13_weight_scale is not None
         | 
| @@ -507,6 +667,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): | |
| 507 667 | 
             
                        layer.w13_weight_scale = torch.nn.Parameter(
         | 
| 508 668 | 
             
                            max_w13_scales, requires_grad=False
         | 
| 509 669 | 
             
                        )
         | 
| 670 | 
            +
             | 
| 671 | 
            +
                        # If ROCm, apply weight padding (min. Mem channel contention) only if set
         | 
| 672 | 
            +
                        if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
         | 
| 673 | 
            +
                            layer.w13_weight = torch.nn.Parameter(
         | 
| 674 | 
            +
                                F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
         | 
| 675 | 
            +
                                requires_grad=False,
         | 
| 676 | 
            +
                            )
         | 
| 677 | 
            +
                            torch.cuda.empty_cache()
         | 
| 678 | 
            +
                            layer.w2_weight = torch.nn.Parameter(
         | 
| 679 | 
            +
                                F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
         | 
| 680 | 
            +
                                requires_grad=False,
         | 
| 681 | 
            +
                            )
         | 
| 682 | 
            +
                            torch.cuda.empty_cache()
         | 
| 510 683 | 
             
                        return
         | 
| 511 684 |  | 
| 512 685 | 
             
                def apply(
         | 
| @@ -520,11 +693,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): | |
| 520 693 | 
             
                    topk_group: Optional[int] = None,
         | 
| 521 694 | 
             
                    num_expert_group: Optional[int] = None,
         | 
| 522 695 | 
             
                    custom_routing_function: Optional[Callable] = None,
         | 
| 696 | 
            +
                    correction_bias: Optional[torch.Tensor] = None,
         | 
| 523 697 | 
             
                ) -> torch.Tensor:
         | 
| 698 | 
            +
                    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         | 
| 699 | 
            +
                    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         | 
| 700 | 
            +
                    from sglang.srt.layers.moe.topk import select_experts
         | 
| 524 701 |  | 
| 525 | 
            -
                     | 
| 526 | 
            -
             | 
| 527 | 
            -
                    topk_weights, topk_ids = FusedMoE.select_experts(
         | 
| 702 | 
            +
                    # Expert selection
         | 
| 703 | 
            +
                    topk_weights, topk_ids = select_experts(
         | 
| 528 704 | 
             
                        hidden_states=x,
         | 
| 529 705 | 
             
                        router_logits=router_logits,
         | 
| 530 706 | 
             
                        use_grouped_topk=use_grouped_topk,
         | 
| @@ -533,8 +709,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): | |
| 533 709 | 
             
                        topk_group=topk_group,
         | 
| 534 710 | 
             
                        num_expert_group=num_expert_group,
         | 
| 535 711 | 
             
                        custom_routing_function=custom_routing_function,
         | 
| 712 | 
            +
                        correction_bias=correction_bias,
         | 
| 536 713 | 
             
                    )
         | 
| 537 714 |  | 
| 715 | 
            +
                    # Expert fusion with FP8 quantization
         | 
| 538 716 | 
             
                    return fused_experts(
         | 
| 539 717 | 
             
                        x,
         | 
| 540 718 | 
             
                        layer.w13_weight,
         | 
| @@ -543,10 +721,17 @@ class Fp8MoEMethod(FusedMoEMethodBase): | |
| 543 721 | 
             
                        topk_ids=topk_ids,
         | 
| 544 722 | 
             
                        inplace=True,
         | 
| 545 723 | 
             
                        use_fp8_w8a8=True,
         | 
| 546 | 
            -
                        w1_scale=layer.w13_weight_scale,
         | 
| 547 | 
            -
                        w2_scale=layer.w2_weight_scale,
         | 
| 724 | 
            +
                        w1_scale=(
         | 
| 725 | 
            +
                            layer.w13_weight_scale_inv
         | 
| 726 | 
            +
                            if self.block_quant
         | 
| 727 | 
            +
                            else layer.w13_weight_scale
         | 
| 728 | 
            +
                        ),
         | 
| 729 | 
            +
                        w2_scale=(
         | 
| 730 | 
            +
                            layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
         | 
| 731 | 
            +
                        ),
         | 
| 548 732 | 
             
                        a1_scale=layer.w13_input_scale,
         | 
| 549 733 | 
             
                        a2_scale=layer.w2_input_scale,
         | 
| 734 | 
            +
                        block_shape=self.quant_config.weight_block_size,
         | 
| 550 735 | 
             
                    )
         | 
| 551 736 |  | 
| 552 737 |  |