sglang 0.4.2.post3__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/check_env.py +1 -0
- sglang/srt/constrained/outlines_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +34 -41
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -3
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/lora/backend/__init__.py +25 -5
- sglang/srt/lora/backend/base_backend.py +31 -9
- sglang/srt/lora/backend/flashinfer_backend.py +41 -4
- sglang/srt/lora/backend/triton_backend.py +34 -4
- sglang/srt/lora/layers.py +293 -0
- sglang/srt/lora/lora.py +101 -326
- sglang/srt/lora/lora_manager.py +101 -269
- sglang/srt/lora/mem_pool.py +174 -0
- sglang/srt/lora/triton_ops/__init__.py +7 -1
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
- sglang/srt/lora/utils.py +141 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +64 -21
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +41 -24
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/triton_ops/gate_up_lora_b.py
ADDED

@@ -0,0 +1,170 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.lora.utils import LoRABatchInfo
+
+
+@triton.jit
+def _gate_up_lora_b_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Parameters of size
+    K,  # K = R
+    output_dim,
+    # Strides
+    x_stride_0,
+    x_stride_1,
+    w_stride_0,
+    w_stride_1,
+    w_stride_2,
+    output_stride_0,
+    output_stride_1,
+    # Information on sequence lengths and weight id
+    seg_lens,
+    seg_indptr,
+    weight_indices,
+    # Meta parameters
+    BLOCK_S: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # For fused output scaling and adding
+    fuse_scaling_add,
+    scaling,
+):
+    # This kernel packs 2 sgemms (gate/up) into a single kernel.
+
+    # x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
+    # weights: (num_lora, 2 * output_dim, K)
+    # output: (s, 2 * output_dim)
+    # output_dim >> K
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len.
+    # gate_up_id decides which of gate or up (0: gate, 1: up)
+    batch_id = tl.program_id(axis=2)
+    gate_up_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+    n_start = gate_up_id * output_dim  # offset on output dim
+
+    # The tile in output matrix will have (pid_s, pid_n) as id
+    num_pid_n = tl.cdiv(output_dim, BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first block of x and weights
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+
+    x_ptrs = (x + seg_start * x_stride_0 + (gate_up_id * K) * x_stride_1) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iterate to compute the block in output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K)
+            and (n_offset[None, :] < output_dim),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum *= scaling
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
+    if fuse_scaling_add:
+        partial_sum += tl.load(output_ptr, mask=output_mask)
+    tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+def gate_up_lora_b_fwd(
+    x: torch.Tensor,
+    gate_up_lora_b: torch.Tensor,
+    batch_info: LoRABatchInfo,
+    output_dim: int,
+    base_output: torch.Tensor = None,
+    scaling: float = 1.0,
+) -> torch.Tensor:
+
+    # x: (s, 2 * r)
+    # gate_up_lora_b: (num_lora, 2 * output_dim, r)
+    # output: (s, 2 * output_dim)
+
+    # Compute lora_output with shape (s, output_dim) as follows:
+    # lora_output[:, :output_dim] = sgemm(x[:, :r], gate_up_lora_b[:, :output_dim, :])
+    # lora_output[:, output_dim:]
+    #      = sgemm(x[:, r:], gate_up_lora_b[:, output_dim:, :])
+
+    # Get dims
+    s = x.shape[0]
+    input_dim = x.shape[1]
+    r = gate_up_lora_b.shape[-1]
+    assert input_dim == 2 * r
+
+    BLOCK_S = 16
+    BLOCK_R = 16
+    BLOCK_OUT = 64
+
+    grid_b = (
+        triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(output_dim, BLOCK_OUT),
+        2,  # this dimension decides current block computes on gate or up proj
+        batch_info.bs,
+    )
+
+    if base_output is None:
+        output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
+        fuse_scaling_add = False
+    else:
+        output = base_output
+        fuse_scaling_add = True
+
+    _gate_up_lora_b_kernel[grid_b](
+        x,
+        gate_up_lora_b,
+        output,
+        r,
+        output_dim,
+        x.stride(0),
+        x.stride(1),
+        gate_up_lora_b.stride(0),
+        gate_up_lora_b.stride(1),
+        gate_up_lora_b.stride(2),
+        output.stride(0),
+        output.stride(1),
+        batch_info.seg_lens,
+        batch_info.seg_indptr,
+        batch_info.weight_indices,
+        BLOCK_S,
+        BLOCK_OUT,
+        BLOCK_R,
+        fuse_scaling_add,
+        scaling,
+    )
+
+    return output
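
The new gate_up_lora_b_fwd entry point launches the kernel over a (tiles, gate/up, batch) grid, so a caller only has to supply per-sequence segment metadata through LoRABatchInfo. Below is a minimal sketch of how it might be invoked, assuming a CUDA device; the sizes (two sequences of lengths 3 and 5, rank r=16, one adapter, output_dim=1024) are made up for illustration and are not taken from this release.

# Illustrative call only; shapes and values are hypothetical.
import torch

from sglang.srt.lora.triton_ops.gate_up_lora_b import gate_up_lora_b_fwd
from sglang.srt.lora.utils import LoRABatchInfo

batch_info = LoRABatchInfo(
    bs=2,
    seg_lens=torch.tensor([3, 5], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda"),
    max_len=5,
    weight_indices=torch.zeros(2, dtype=torch.int32, device="cuda"),
)

r, output_dim, num_lora = 16, 1024, 1
x = torch.randn(8, 2 * r, dtype=torch.float16, device="cuda")  # (s, 2 * r), s = 3 + 5
lora_b = torch.randn(num_lora, 2 * output_dim, r, dtype=torch.float16, device="cuda")

out = gate_up_lora_b_fwd(x, lora_b, batch_info, output_dim, scaling=0.5)
print(out.shape)  # torch.Size([8, 2048]); gate half in [:, :1024], up half in [:, 1024:]

Passing base_output instead takes the fuse_scaling_add path, accumulating the scaled LoRA delta into an existing tensor rather than allocating a fresh output.
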
sglang/srt/lora/triton_ops/qkv_lora_b.py
CHANGED

@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.lora.
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 @triton.jit
@@ -108,7 +108,7 @@ def _qkv_lora_b_kernel(
 def qkv_lora_b_fwd(
     x: torch.Tensor,
     qkv_lora_b: torch.Tensor,
-    batch_info:
+    batch_info: LoRABatchInfo,
     output_offset: torch.Tensor,
     max_qkv_out_dim: int,
     base_output: torch.Tensor = None,
@@ -123,11 +123,11 @@ def qkv_lora_b_fwd(
     # output: (s, output_dim_q + 2 * output_dim_kv)
 
     # Compute lora_output with shape (s, output_dim) as follows:
-    # lora_output[:, :output_dim_q] = sgemm(
+    # lora_output[:, :output_dim_q] = sgemm(x[:, :r], qkv_lora_b[:, :output_dim_q, :])
     # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
-    #      = sgemm(
+    #      = sgemm(x[:, r: 2 * r], qkv_lora_b[:, output_dim_q: output_dim_q + output_dim_kv, :])
     # lora_output[:, output_dim_q + output_dim_kv: ]
-    #      = sgemm(
+    #      = sgemm(x[:, 2 * r:], qkv_lora_b[:, output_dim_q + output_dim_kv:, :])
 
     # Get dims
     s = x.shape[0]

sglang/srt/lora/triton_ops/sgemm_lora_a.py
CHANGED

@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.lora.
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 @triton.jit
@@ -91,7 +91,7 @@ def _sgemm_lora_a_kernel(
 
 
 def sgemm_lora_a_fwd(
-    x: torch.Tensor, weights: torch.Tensor, batch_info:
+    x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
 ) -> torch.Tensor:
     # x: (s, input_dim)
     # weights: (num_lora, r, input_dim)

sglang/srt/lora/triton_ops/sgemm_lora_b.py
CHANGED

@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.lora.
+from sglang.srt.lora.utils import LoRABatchInfo
 
 
 @triton.jit
@@ -98,7 +98,7 @@ def _sgemm_lora_b_kernel(
 def sgemm_lora_b_fwd(
     x: torch.Tensor,
     weights: torch.Tensor,
-    batch_info:
+    batch_info: LoRABatchInfo,
     base_output: torch.Tensor = None,
     scaling: float = 1.0,
 ) -> torch.Tensor:

    
        sglang/srt/lora/utils.py
    ADDED
    
| @@ -0,0 +1,141 @@
+import re
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional, Set, Tuple
+
+import torch
+
+from sglang.srt.hf_transformers_utils import AutoConfig
+
+
+@dataclass
+class LoRABatchInfo:
+    # Batch size
+    bs: int
+
+    # Lengths of each sequence in shape (bs,)
+    seg_lens: torch.Tensor
+
+    # Indice pointers of each sequence in shape (bs + 1, )
+    seg_indptr: torch.Tensor
+
+    # Maximum sequence length of current batch
+    max_len: int
+
+    # The index of lora adapter used by each sequence, in shape (bs,)
+    weight_indices: torch.Tensor
+
+
+class LoRAType(Enum):
+    LORA_A = 0
+    LORA_B = 1
+
+
+def get_layer_id(name: str) -> int:
+    """
+    Extract integer id of layer from its name in string.
+    """
+    match = re.search(r"layers\.(\d+)\.", name)
+    if match is None:
+        return None
+    return int(match.group(1))
+
+
+def get_customized_names_from_hf_names(
+    hf_module_names: Set[str], base_model: torch.nn.Module
+) -> Set[str]:
+    """
+    This function takes in a set of huggingface style module names:
+         e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+    and outputs a set of module names of customized sglang layers:
+         e.g., {"qkv_proj", "o_proj"}
+    """
+    if hasattr(base_model, "get_module_name"):
+        return {base_model.get_module_name(name) for name in hf_module_names}
+    else:
+        """
+        Fallback solution of mapping from config module name to module name in model class.
+        Please check if it aligns with your base model.
+        Please implement the function in the model class if it is not.
+        You can reference this function in llama.py.
+        """
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return {params_mapping.get(name, name) for name in hf_module_names}
+
+
+def get_hidden_dim(
+    module_name: str, config: AutoConfig, base_model: torch.nn.Module
+) -> Tuple[int]:
+    """
+    Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+    """
+
+    if hasattr(base_model, "get_hidden_dim"):
+        return base_model.get_hidden_dim(module_name)
+    else:
+        """
+        WARNING: get_hidden_dim() is not defined,
+        which is used to get the hidden dim for different lora modules
+        Use the default one, but please check if it is correct for your model.
+        Please implement the function in the model class if it is not.
+        You can reference this function in llama.py.
+        """
+        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
+            return config.hidden_size, config.hidden_size
+        elif module_name in ["kv_proj"]:
+            return config.hidden_size, config.hidden_size // (
+                config.num_attention_heads // config.num_key_value_heads
+            )
+        elif module_name == "gate_up_proj":
+            return config.hidden_size, config.intermediate_size
+        elif module_name == "down_proj":
+            return config.intermediate_size, config.hidden_size
+        else:
+            raise NotImplementedError()
+
+
+def get_stacked_name(name: str) -> Tuple[str]:
+    """
+    Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
+    """
+    params_mapping = {
+        "q_proj": ("qkv_proj", "q_proj"),
+        "k_proj": ("qkv_proj", "kv_proj"),
+        "v_proj": ("qkv_proj", "kv_proj"),
+        "gate_proj": ("gate_up_proj", "gate_up_proj"),
+        "up_proj": ("gate_up_proj", "gate_up_proj"),
+    }
+    return params_mapping.get(name, (name, name))
+
+
+def get_stacked_multiply(module_name: str) -> int:
+    """
+    Mapping a lora module name to its magnification at output dimension
+    """
+    stacked_rank = {
+        "qkv_proj": 3,
+        "kv_proj": 2,
+        "gate_up_proj": 2,
+    }
+    return stacked_rank[module_name] if module_name in stacked_rank else 1
+
+
+def get_weight_name(
+    target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
+) -> Optional[str]:
+    """
+    target_name is name of a given module,
+    lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
+    If there is a weight name in lora_weight_names that can match target_name, return this name
+    Else return None
+    """
+    idx = 0 if lora_type == LoRAType.LORA_A else 1
+    for weight_name_pair in lora_weight_names:
+        if weight_name_pair[idx] in target_name:
+            return weight_name_pair[idx]
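
These helpers are small and side-effect free, so their behavior is easiest to see by example. The sketch below uses assumed llama-style module names; the names and layer index are illustrative and not part of this diff.

# Illustrative usage; module paths below are hypothetical.
from sglang.srt.lora.utils import (
    LoRAType,
    get_layer_id,
    get_stacked_multiply,
    get_stacked_name,
    get_weight_name,
)

print(get_layer_id("model.layers.12.self_attn.qkv_proj"))  # 12
print(get_stacked_name("k_proj"))                          # ('qkv_proj', 'kv_proj')
print(get_stacked_multiply("gate_up_proj"))                # 2 (gate and up are stacked)

# get_weight_name returns the A- or B-side stacked name that matches a module path.
lora_weight_names = {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
print(
    get_weight_name(
        "model.layers.0.self_attn.qkv_proj", lora_weight_names, LoRAType.LORA_A
    )
)  # 'qkv_proj'
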
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED

@@ -237,6 +237,7 @@ class CudaGraphRunner:
                 "1. disable cuda graph by --disable-cuda-graph\n"
                 "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
                 "3. disable torch compile by not using --enable-torch-compile\n"
+                "4. set --cuda-graph-max-bs to a smaller value (e.g., 32)\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
 
@@ -462,8 +463,11 @@ class CudaGraphRunner:
                     ),
                     positions=None,
                     retrive_index=None,
+                    retrive_next_token=None,
+                    retrive_next_sibling=None,
                     retrive_cum_len=None,
                     draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                    spec_steps=self.model_runner.server_args.speculative_num_steps,
                     capture_hidden_mode=CaptureHiddenMode.FULL,
                 )
 
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
CHANGED

@@ -85,6 +85,7 @@ class EAGLEDraftCudaGraphRunner:
                 "1. disable cuda graph by --disable-cuda-graph\n"
                 "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
                 "3. disable torch compile by not using --enable-torch-compile\n"
+                "4. specify --dtype to the same dtype (e.g. bfloat16)\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
 
sglang/srt/speculative/eagle_utils.py
CHANGED

@@ -4,6 +4,7 @@ import dataclasses
 from typing import TYPE_CHECKING, List
 
 import torch
+import torch.nn.functional as F
 import triton
 import triton.language as tl
 
@@ -11,7 +12,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
-from sglang.srt.speculative.build_eagle_tree import
+from sglang.srt.speculative.build_eagle_tree import (
+    build_tree_kernel,
+    build_tree_kernel_efficient,
+)
+from sglang.srt.utils import is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import tree_speculative_sampling_target_only
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -160,8 +168,11 @@ class EagleVerifyInput:
     custom_mask: torch.Tensor
     positions: torch.Tensor
     retrive_index: torch.Tensor
+    retrive_next_token: torch.Tensor
+    retrive_next_sibling: torch.Tensor
     retrive_cum_len: torch.Tensor
     draft_token_num: int
+    spec_steps: int
     capture_hidden_mode: CaptureHiddenMode
 
     @classmethod
@@ -175,10 +186,45 @@ class EagleVerifyInput:
         seq_lens_sum: int,
         topk: int,
         spec_steps: int,
-
+        num_verify_tokens: int,
+        is_all_greedy: bool,
     ):
-
-
+        if is_all_greedy:
+            tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
+                build_tree_kernel(
+                    verified_id,
+                    score_list,  # b, n, topk; n= 1 + (num_steps-1) * self.topk
+                    token_list,
+                    parents_list,
+                    seq_lens,
+                    seq_lens_sum,
+                    topk,
+                    spec_steps,
+                    num_verify_tokens,
+                )
+            )
+
+            return cls(
+                draft_tokens,
+                tree_mask,
+                position,
+                retrive_index,
+                None,
+                None,
+                retrive_cum_len,
+                num_verify_tokens,
+                spec_steps,
+                CaptureHiddenMode.FULL,
+            )
+        else:
+            (
+                tree_mask,
+                position,
+                retrive_index,
+                retrive_next_token,
+                retrive_next_sibling,
+                draft_tokens,
+            ) = build_tree_kernel_efficient(
                 verified_id,
                 score_list,
                 token_list,
@@ -187,18 +233,21 @@ class EagleVerifyInput:
                 seq_lens_sum,
                 topk,
                 spec_steps,
-
+                num_verify_tokens,
+            )
+
+            return cls(
+                draft_tokens,
+                tree_mask,
+                position,
+                retrive_index,
+                retrive_next_token,
+                retrive_next_sibling,
+                None,
+                num_verify_tokens,
+                spec_steps,
+                CaptureHiddenMode.FULL,
            )
-        )
-        return cls(
-            draft_tokens,
-            tree_mask,
-            position,
-            retrive_index,
-            retrive_cum_len,
-            num_verify_token,
-            CaptureHiddenMode.FULL,
-        )
 
     def prepare_for_verify(self, batch: ScheduleBatch):
         batch.input_ids = self.draft_token
@@ -313,12 +362,6 @@ class EagleVerifyInput:
             uniform_samples=coins,
             target_probs=target_probs,
            draft_probs=draft_probs,
-            threshold_single=global_server_args_dict[
-                "speculative_accept_threshold_single"
-            ],
-            threshold_acc=global_server_args_dict[
-                "speculative_accept_threshold_acc"
-            ],
             deterministic=True,
         )
 
    
        sglang/version.py
    CHANGED
    
@@ -1 +1 @@
-__version__ = "0.4.2.post3"
+__version__ = "0.4.2.post4"

{sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: sglang
-Version: 0.4.2.post3
+Version: 0.4.2.post4
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License:                                  Apache License
                                    Version 2.0, January 2004
@@ -239,11 +239,11 @@ Requires-Dist: xgrammar>=0.1.10; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
 Requires-Dist: cuda-python; extra == "srt"
-Requires-Dist: sgl-kernel>=0.0.3.
+Requires-Dist: sgl-kernel>=0.0.3.post3; extra == "srt"
 Requires-Dist: torch; extra == "srt"
-Requires-Dist: vllm
+Requires-Dist: vllm<=0.7.2,>=0.6.4.post1; extra == "srt"
 Requires-Dist: flashinfer_python>=0.2.0.post2; extra == "srt"
-Requires-Dist: outlines
+Requires-Dist: outlines<=0.1.11,>=0.0.44; extra == "srt"
 Provides-Extra: srt-hip
 Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
 Requires-Dist: torch; extra == "srt-hip"