sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +16 -7
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +21 -5
- sglang/srt/layers/linear.py +89 -47
- sglang/srt/layers/logits_processor.py +6 -6
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +439 -0
- sglang/srt/layers/quantization/__init__.py +5 -2
- sglang/srt/layers/quantization/fp8.py +107 -53
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/vocab_parallel_embedding.py +16 -3
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/configure_logging.py +43 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/detokenizer_manager.py +0 -2
- sglang/srt/managers/io_struct.py +29 -13
- sglang/srt/managers/schedule_batch.py +7 -1
- sglang/srt/managers/scheduler.py +58 -15
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +109 -45
- sglang/srt/mem_cache/memory_pool.py +313 -53
- sglang/srt/metrics/collector.py +32 -35
- sglang/srt/model_executor/cuda_graph_runner.py +14 -7
- sglang/srt/model_executor/forward_batch_info.py +20 -15
- sglang/srt/model_executor/model_runner.py +53 -10
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/grok.py +25 -16
- sglang/srt/models/llama.py +46 -4
- sglang/srt/models/qwen2.py +11 -0
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +15 -5
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +125 -69
- sglang/srt/server_args.py +39 -19
- sglang/srt/speculative/eagle_utils.py +93 -85
- sglang/srt/speculative/eagle_worker.py +48 -33
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +61 -5
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

@@ -1,7 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-import os
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
@@ -25,9 +24,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -40,12 +39,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
+    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)
 
 
@@ -161,7 +163,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip():
+        if is_hip_:
             self.use_marlin = False
 
         self.block_quant = self.quant_config.weight_block_size is not None
@@ -273,7 +275,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
@@ -330,7 +332,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale = layer.weight_scale
 
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -567,7 +569,7 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -594,7 +596,7 @@ class Fp8MoEMethod:
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
@@ -616,18 +618,30 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
- [12 lines removed; content not shown in the source view]
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif get_bool_env_var("MOE_PADDING"):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
@@ -658,7 +672,7 @@ class Fp8MoEMethod:
                 )
 
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -708,18 +722,30 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )
 
- [12 lines removed; content not shown in the source view]
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif get_bool_env_var("MOE_PADDING"):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return
 
     def apply(
@@ -752,27 +778,55 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
- [21 lines removed; content not shown in the source view]
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+
+        else:
+            # Expert fusion with FP8 quantization
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
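The fp8.py change caches `is_hip()` into a module-level `is_hip_` flag and adds two ROCm-only paths gated by environment variables: with `CK_MOE` set, MoE weights are permuted for ater's `fused_experts_ck`; with `MOE_PADDING` set, `w13`/`w2` are zero-padded. Below is a minimal, self-contained sketch of that gating, assuming a simple boolean env-var parser; `get_bool_env_var` and `is_hip_` here are stand-ins for the sglang.srt.utils helpers, whose exact behavior may differ.

```python
import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Stand-in for sglang.srt.utils.get_bool_env_var; exact parsing may differ.
    return os.getenv(name, default).lower() in ("1", "true")


is_hip_ = False  # stand-in for sglang.srt.utils.is_hip(); True only on ROCm builds


def describe_rocm_moe_paths() -> str:
    # Mirrors the branches added in Fp8MoEMethod: CK_MOE takes priority, then
    # MOE_PADDING; apply() dispatches to fused_experts_ck only on the CK_MOE path.
    if not is_hip_:
        return "CUDA: Triton fused_experts"
    if get_bool_env_var("CK_MOE"):
        return "ROCm + CK_MOE: permute weights, apply() uses ater fused_experts_ck"
    if get_bool_env_var("MOE_PADDING"):
        return "ROCm + MOE_PADDING: pad w13/w2, apply() uses Triton fused_experts"
    return "ROCm: Triton fused_experts"


print(describe_rocm_moe_paths())
```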
sglang/srt/layers/quantization/fp8_utils.py

@@ -1,8 +1,8 @@
 from typing import List, Optional, Tuple
 
 import torch
-from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
 
+from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
sglang/srt/layers/quantization/int8_kernel.py (new file)

@@ -0,0 +1,54 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _per_token_quant_int8(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    BLOCK: tl.constexpr,
+):
+    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
+    row_id = tl.program_id(0)
+
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+def per_token_quant_int8(x):
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+
+    assert x.is_contiguous()
+    _per_token_quant_int8[(M,)](
+        x,
+        x_q,
+        scales,
+        stride_x=x.stride(-2),
+        stride_xq=x_q.stride(-2),
+        N=N,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+    return x_q, scales
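The new Triton kernel performs dynamic, per-token symmetric int8 quantization: for each row it takes the absolute maximum, stores `scale = absmax / 127`, and rounds `x * 127 / absmax` to int8. The following pure-PyTorch reference mirrors that math and can be used to sanity-check `per_token_quant_int8` on small inputs; it is only an illustration, not the kernel itself.

```python
import torch


def per_token_quant_int8_ref(x: torch.Tensor):
    # Reference for the per-token symmetric int8 scheme in _per_token_quant_int8.
    x_f = x.float()
    absmax = x_f.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scale = absmax / 127                                  # one fp32 scale per token
    x_q = torch.round(x_f * (127 / absmax)).to(torch.int8)
    return x_q, scale


x = torch.randn(4, 16, dtype=torch.float16)
x_q, scale = per_token_quant_int8_ref(x)
# Dequantizing should roughly reproduce the input (per-token rounding error only).
print(torch.allclose(x_q.float() * scale, x.float(), atol=scale.max().item()))
```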
sglang/srt/layers/quantization/modelopt_quant.py (new file)

@@ -0,0 +1,174 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
+
+import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    apply_fp8_linear,
+    cutlass_fp8_supported,
+    requantize_with_max_scale,
+)
+
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+
+# Initialize logger for the module
+logger = logging.getLogger(__name__)
+
+# Supported activation schemes for the current configuration
+ACTIVATION_SCHEMES = ["static"]
+
+
+class ModelOptFp8Config(QuantizationConfig):
+    """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
+
+    def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
+        """
+        Args:
+            is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
+        """
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        if is_checkpoint_fp8_serialized:
+            logger.warning(
+                "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
+            )
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "modelopt"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 89  # Minimum hardware capability (e.g., Hopper GPUs).
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
+        quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
+
+        if "FP8" not in quant_method:
+            raise ValueError(
+                "ModelOpt only supports static FP8 quantization in SGLang. "
+                "Check the `hf_quant_config.json` file for your model's configuration."
+            )
+
+        return cls(is_checkpoint_fp8_serialized=True)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        return ModelOptFp8LinearMethod(self) if isinstance(layer, LinearBase) else None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp8LinearMethod(LinearMethodBase):
+    """Linear method for ModelOpt static FP8 quantization.
+
+    Supports loading FP8 checkpoints with static weight and activation scales.
+    Future support may include dynamic scales.
+
+    **Limitations**:
+    1. Only supports per-tensor quantization due to `torch._scaled_mm` limitations.
+    2. Only supports the `float8_e4m3fn` data type.
+
+    Args:
+        quant_config (ModelOptFp8Config): The ModelOpt quantization configuration.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        super().__init__()
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        """Creates and registers weights, weight scales, and input scales for FP8 quantization."""
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+
+        # Set layer attributes
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+
+        # Register weight
+        layer.register_parameter(
+            "weight",
+            ModelWeightParameter(
+                data=torch.empty(
+                    output_size_per_partition,
+                    input_size_per_partition,
+                    dtype=weight_dtype,
+                ),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=weight_loader,
+            ),
+        )
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # Register weight and input scales
+            for scale_name in ["weight_scale", "input_scale"]:
+                layer.register_parameter(
+                    scale_name,
+                    PerTensorScaleParameter(
+                        data=torch.full(
+                            (len(output_partition_sizes),),
+                            torch.finfo(torch.float32).min,
+                            dtype=torch.float32,
+                        ),
+                        weight_loader=weight_loader,
+                    ),
+                )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Requantizes weights after loading using the maximum scale."""
+        max_w_scale, quantized_weight = requantize_with_max_scale(
+            layer.weight, layer.weight_scale, layer.logical_widths
+        )
+        layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Applies FP8 linear transformation."""
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+        )
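`ModelOptFp8Config.from_config` reads `hf_quant_config.json` and only looks at the `quant_algo` entry of the `quantization` section, rejecting anything that is not FP8. A sketch of the minimal config shape it accepts is shown below; any fields beyond those two keys are illustrative assumptions, and the check is inlined here rather than calling the class so the snippet runs standalone.

```python
# Minimal sketch of the hf_quant_config.json content from_config inspects.
hf_quant_config = {
    "quantization": {
        "quant_algo": "FP8",  # must mention FP8, otherwise from_config raises ValueError
    }
}

quant_algo = hf_quant_config["quantization"].get("quant_algo")
if "FP8" not in quant_algo:
    raise ValueError("ModelOpt only supports static FP8 quantization in SGLang.")
# from_config would now return ModelOptFp8Config(is_checkpoint_fp8_serialized=True)
print("FP8 checkpoint accepted:", quant_algo)
```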
sglang/srt/layers/quantization/w8a8_int8.py (new file)

@@ -0,0 +1,117 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import int8_scaled_mm
+
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+
+
+class W8A8Int8Config(QuantizationConfig):
+    """Config class for W8A8 Int8 Quantization.
+
+    - Weight: static, per-channel, symmetric
+    - Activation: dynamic, per-token, symmetric
+    """
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 75
+
+    @classmethod
+    def get_name(self) -> str:
+        return "w8a8_int8"
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
+        return cls()
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.model_executor.layers.linear import LinearBase
+
+        if isinstance(layer, LinearBase):
+            return W8A8Int8LinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class W8A8Int8LinearMethod(LinearMethodBase):
+
+    def __init__(self, quantization_config: W8A8Int8Config):
+        self.quantization_config = quantization_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs
+    ):
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        self.logical_widths = output_partition_sizes
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        weight_scale = ChannelQuantScaleParameter(
+            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        x_q, x_scale = per_token_quant_int8(x)
+
+        return int8_scaled_mm(
+            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
+        )
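`W8A8Int8LinearMethod.apply` quantizes activations per token at runtime with `per_token_quant_int8` and multiplies them against the per-channel int8 weight via the fused `sgl_kernel.int8_scaled_mm`. The pure-PyTorch reference below only illustrates the math that the fused kernel computes; it assumes the weight is still in its as-loaded `[N, K]` layout (the real layer stores it transposed after `process_weights_after_loading`).

```python
import torch


def w8a8_int8_linear_ref(x_q, x_scale, w_q, w_scale, bias=None, out_dtype=torch.float16):
    # x_q: [M, K] int8, x_scale: [M, 1] fp32 (per-token)
    # w_q: [N, K] int8, w_scale: [N, 1] fp32 (per-channel)
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32).t()  # exact int32 accumulation
    y = acc.float() * x_scale * w_scale.t()              # apply both scales
    if bias is not None:
        y = y + bias.float()
    return y.to(out_dtype)
```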
sglang/srt/layers/vocab_parallel_embedding.py

@@ -12,8 +12,8 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.parameter import BasevLLMParameter
 
+from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -220,6 +220,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.quant_config = quant_config
@@ -236,6 +237,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
         num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.use_presharded_weights = use_presharded_weights
+        if use_presharded_weights:
+            assert (
+                num_added_embeddings == 0
+            ), "Lora is not supported with presharded weights."
+
         self.org_vocab_size_padded = pad_vocab_size(
             self.org_vocab_size, self.padding_size
         )
@@ -447,10 +454,14 @@ class VocabParallelEmbedding(torch.nn.Module):
             start_idx = start_idx // packed_factor
             shard_size = shard_size // packed_factor
         else:
-            assert loaded_weight.shape[output_dim] == self.org_vocab_size
+            assert loaded_weight.shape[output_dim] == (
+                self.org_vocab_size
+                // (self.tp_size if self.use_presharded_weights else 1)
+            )
 
         # Copy the data.
-        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0] :].data.fill_(0)
 
@@ -514,6 +525,7 @@ class ParallelLMHead(VocabParallelEmbedding):
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
@@ -523,6 +535,7 @@ class ParallelLMHead(VocabParallelEmbedding):
             padding_size,
             quant_config,
             prefix,
+            use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
         if bias:
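The `use_presharded_weights` flag tells the embedding loader that the checkpoint already stores each tensor-parallel rank's slice of the vocabulary, so it skips the `narrow()` step and only checks that the loaded shard has `org_vocab_size // tp_size` rows. A shape-level sketch of the two loading paths, with made-up sizes purely for illustration:

```python
import torch

org_vocab_size, hidden, tp_size = 32000, 4096, 4
shard_size = org_vocab_size // tp_size

# Regular checkpoint: the file holds the full vocab and each rank narrows out its slice.
full_weight = torch.empty(org_vocab_size, hidden)
rank = 1
regular_shard = full_weight.narrow(0, rank * shard_size, shard_size)

# Presharded checkpoint: the file for this rank already holds shard_size rows,
# so the loader copies it directly after the size check added above.
presharded_weight = torch.empty(shard_size, hidden)
assert presharded_weight.shape[0] == org_vocab_size // tp_size

print(regular_shard.shape, presharded_weight.shape)
```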