sglang 0.4.2.post3__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- sglang/check_env.py +1 -0
- sglang/srt/constrained/outlines_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +34 -41
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -3
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/lora/backend/__init__.py +25 -5
- sglang/srt/lora/backend/base_backend.py +31 -9
- sglang/srt/lora/backend/flashinfer_backend.py +41 -4
- sglang/srt/lora/backend/triton_backend.py +34 -4
- sglang/srt/lora/layers.py +293 -0
- sglang/srt/lora/lora.py +101 -326
- sglang/srt/lora/lora_manager.py +101 -269
- sglang/srt/lora/mem_pool.py +174 -0
- sglang/srt/lora/triton_ops/__init__.py +7 -1
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
- sglang/srt/lora/utils.py +141 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +64 -21
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +41 -24
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
    
sglang/srt/lora/lora.py  CHANGED

@@ -19,282 +19,25 @@
 # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py

 import re
-from   [truncated in source]
+from typing import Dict, List

 import torch
 from torch import nn

-from sglang.srt.  [truncated in source]
-  [3 lines truncated in source]
-    RowParallelLinear,
-)
-from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.hf_transformers_utils import AutoConfig
+from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader


-@dataclass
-class LoraBatchInfo:
-    # Batch size
-    bs: int
-
-    # Lengths of each sequence in shape (bs,)
-    seg_lens: torch.Tensor
-
-    # Indice pointers of each sequence in shape (bs + 1, )
-    seg_indptr: torch.Tensor
-
-    # Maximum sequence length of current batch
-    max_len: int
-
-    # The index of lora adapter used by each sequence, in shape (bs,)
-    weight_indices: torch.Tensor
-
-
-class BaseLayerWithLoRA(nn.Module):
-    def __init__(self, base_layer, lora_rank, scaling, lora_backend):
-        super().__init__()
-        self.base_layer = base_layer
-        self.lora_rank = lora_rank
-        self.scaling = scaling
-        self.set_lora = False
-        self.lora_backend = lora_backend
-
-    def forward(self, x: torch.Tensor):
-        return self.base_layer.forward(x)
-
-    def set_lora_info(self, *args):
-        pass
-
-
-class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
-    def __init__(
-        self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-        self.weight = base_layer.weight
-
-
-class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
-    def __init__(
-        self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        # TODO
-        return output
-
-    def forward(self, input_: torch.Tensor):
-        # duplicate the logic in ColumnParallelLinear
-        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
-        output_parallel = self.base_layer.quant_method.apply(
-            self.base_layer, input_, bias
-        )
-
-        if self.set_lora:
-            output_parallel = self.apply_lora(output_parallel, input_)
-
-        if self.base_layer.gather_output:
-            output = tensor_model_parallel_all_gather(output_parallel)
-        else:
-            output = output_parallel
-        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
-        return output, output_bias
-
-
-class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def __init__(
-        self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def set_lora_info(
-        self,
-        A_buffer,
-        B_buffer,
-    ):
-        self.set_lora = True
-        self.A_buffer = A_buffer
-        self.B_buffer = B_buffer
-
-    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
-
-        output_dim = base_output.shape[-1]
-        lora_output = torch.empty_like(base_output)
-        lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
-            x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
-            weights=self.B_buffer[0],
-        )
-
-        lora_output[:, output_dim : 2 * output_dim] = (
-            self.lora_backend.run_lora_b_sgemm(
-                x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
-                weights=self.B_buffer[1],
-            )
-        )
-
-        return base_output + lora_output * self.scaling
-
-
-class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def init__(
-        self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def set_lora_info(
-        self,
-        A_buffer_qkv,
-        B_buffer_q,
-        B_buffer_kv,
-    ):
-        self.set_lora = True
-        self.A_buffer_qkv = A_buffer_qkv
-
-        if self.lora_backend.fuse_qkv_lora_b:
-            assert (
-                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
-            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
-            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
-
-            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-            self.B_buffer_qkv = torch.cat(
-                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
-            ).contiguous()
-
-            # Offsets of q/k/v in output dimension
-            self.output_offset = torch.tensor(
-                [
-                    0,
-                    output_dim_q,
-                    output_dim_q + output_dim_kv,
-                    output_dim_q + 2 * output_dim_kv,
-                ],
-                dtype=torch.int32,
-                device=B_buffer_q.device,
-            )
-            # For computing number of launched blocks
-            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
-        else:
-            self.B_buffer_qkv = (
-                B_buffer_q,
-                B_buffer_kv,
-            )
-            self.output_offset = None
-            self.max_qkv_out_dim = None
-
-    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_output = self.lora_backend.run_qkv_lora(
-            x,
-            self.A_buffer_qkv,
-            self.B_buffer_qkv,
-            output_offset=self.output_offset,
-            max_qkv_out_dim=self.max_qkv_out_dim,
-            base_output=base_output,
-            scaling=self.scaling,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
-        )
-
-
-class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
-    def __init__(
-        self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
-    ) -> None:
-        super().__init__(base_layer, lora_rank, scaling, lora_backend)
-
-    def set_lora_info(self, A_buffer, B_buffer):
-        self.set_lora = True
-        self.A_buffer = A_buffer
-        self.B_buffer = B_buffer
-
-    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
-        lora_output = self.lora_backend.run_lora_b_sgemm(
-            lora_a_output,
-            self.B_buffer[0],
-            base_output=base_output,
-            scaling=self.scaling,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_scaling_add
-            else base_output + lora_output * self.scaling
-        )
-
-    def forward(self, input_):
-        # duplicate the logic in RowParallelLinear
-        if self.base_layer.input_is_parallel:
-            input_parallel = input_
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.base_layer.tp_size
-            )
-            input_parallel = splitted_input[tp_rank].contiguous()
-        output_parallel = self.base_layer.quant_method.apply(
-            self.base_layer, input_parallel
-        )
-
-        if self.set_lora:
-            output_parallel = self.apply_lora(output_parallel, input_parallel)
-
-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
-            output_ = tensor_model_parallel_all_reduce(output_parallel)
-        else:
-            output_ = output_parallel
-
-        if not self.base_layer.skip_bias_add:
-            output = (
-                output_ + self.base_layer.bias
-                if self.base_layer.bias is not None
-                else output_
-            )
-            output_bias = None
-        else:
-            output = output_
-            output_bias = self.base_layer.bias
-        return output, output_bias
-
-
-def get_lora_layer(
-    layer: nn.Module, lora_rank, scaling, lora_backend
-) -> BaseLayerWithLoRA:
-    supported_layer_types = {
-        # the order matters
-        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
-        QKVParallelLinear: QKVParallelLinearWithLoRA,
-        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
-        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
-        RowParallelLinear: RowParallelLinearWithLoRA,
-    }
-    for src_layer_type, lora_layer_type in supported_layer_types.items():
-        if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
-            return ret
-    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
-
-
-def get_mapped_params(module_names):
-    ret = set()
-    for module_name in module_names:
-        ret.add(params_mapping(module_name))
-    return list(ret)
-
-
 class LoRALayer(nn.Module):
-    def __init__(self, config, base_hf_config):
+    def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
         super().__init__()
-        self.config = config
-        self.base_hf_config = base_hf_config
-        self.weights = {}
-        self.weight_gpu = {}
+        self.config: LoRAConfig = config
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.weights: Dict[str, torch.Tensor] = {}
+        self.weight_gpu: Dict[str, torch.Tensor] = {}

     def load_to_gpu(self):
         for name, weight in self.weights.items():
@@ -306,33 +49,32 @@ class LoRALayer(nn.Module):


 class LoRAAdapter(nn.Module):
-    def __init__(  [truncated in source]
+    def __init__(
+        self,
+        uid: str,
+        config: LoRAConfig,
+        base_hf_config: AutoConfig,
+        load_config: LoadConfig,
+        lora_backend: BaseLoRABackend,
+    ):
         super().__init__()
-        self.uid = uid
-        self.config = config
+        self.uid: str = uid
+        self.config: LoRAConfig = config
         assert self.config.hf_config["peft_type"].lower() == "lora"
-        self.base_hf_config = base_hf_config
-        self.load_config = load_config
-        self.lora_backend = lora_backend
-        self.scaling = self.config.lora_alpha / self.config.r
+        self.base_hf_config: AutoConfig = base_hf_config
+        self.load_config: LoadConfig = load_config
+        self.lora_backend: BaseLoRABackend = lora_backend
+        self.scaling: float = self.config.lora_alpha / self.config.r

-        self.layers = nn.ModuleList(
+        self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
                 for i in range(base_hf_config.num_hidden_layers)
             ]
         )

-        self.weights = {}
-        self.weights_gpu = {}
-
-    def get_stacked_multiply(self, module_name):
-        stacked_rank = {
-            "qkv_proj": 3,
-            "kv_proj": 2,
-            "gate_up_proj": 2,
-        }
-        return stacked_rank[module_name] if module_name in stacked_rank else 1
+        self.weights: Dict[str, torch.Tensor] = {}
+        self.weights_gpu: Dict[str, torch.Tensor] = {}

     def load_to_gpu(self):
         for name, weight in self.weights.items():
@@ -367,44 +109,77 @@ class LoRAAdapter(nn.Module):
         for i in range(self.base_hf_config.num_hidden_layers):
             layer = self.layers[i]
             weight_names = [name for name, _ in layer.weights.items()]
-  [41 removed lines truncated in source]
+            self.stack_qkv_proj(weight_names, layer.weights)
+            self.stack_gate_up_proj(weight_names, layer.weights)
+
+    def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
+
+        # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
+        target_module = set()
+        for weight_name in weight_names:
+            if "k_proj" in weight_name:
+                target_module.add("k_proj")
+            if "q_proj" in weight_name:
+                target_module.add("q_proj")
+            if "v_proj" in weight_name:
+                target_module.add("v_proj")
+        if len(target_module) == 0:
+            return
+
+        for weight_name in weight_names:
+            # We assume every lora adaptor should contain lora modules for q_proj
+            if "q_proj" in weight_name:
+                q_name = weight_name
+                k_name = weight_name.replace("q_proj", "k_proj")
+                v_name = weight_name.replace("q_proj", "v_proj")
+                kv_name = weight_name.replace("q_proj", "kv_proj")
+                qkv_name = weight_name.replace("q_proj", "qkv_proj")
+
+                # If k_proj doesn't have lora, initialize it to zero
+                k_proj_weight = (
+                    weights[k_name]
+                    if "k_proj" in target_module
+                    else torch.zeros_like(weights[v_name])
+                )
+                if "lora_A" in weight_name:
+                    weights[qkv_name] = torch.cat(
+                        (
+                            weights[q_name],
+                            k_proj_weight,
+                            weights[v_name],
+                        ),
+                        0,
+                    )
+                    weights.pop(q_name)
+                    if "k_proj" in target_module:
+                        weights.pop(k_name)
+                    weights.pop(v_name)
+                else:
+                    weights[kv_name] = torch.stack(
+                        [
+                            k_proj_weight,
+                            weights[v_name],
+                        ],
+                        dim=0,
+                    )
+                    if "k_proj" in target_module:
+                        weights.pop(k_name)
+                    weights.pop(v_name)
+
+    def stack_gate_up_proj(
+        self, weight_names: List[str], weights: Dict[str, torch.Tensor]
+    ):
+        for weight_name in weight_names:
+            if "gate_proj" in weight_name:
+                up_name = weight_name.replace("gate_proj", "up_proj")
+                gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
+                if "lora_A" in weight_name:
+                    weights[gate_up_name] = torch.cat(
+                        (weights[weight_name], weights[up_name]), 0
+                    )
+                else:
+                    weights[gate_up_name] = torch.stack(
+                        [weights[weight_name], weights[up_name]], dim=0
+                    )
+                weights.pop(weight_name)
+                weights.pop(up_name)