sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/constrained/outlines_backend.py +9 -1
- sglang/srt/custom_op.py +40 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/activation.py +10 -5
- sglang/srt/layers/attention/flashinfer_backend.py +284 -39
- sglang/srt/layers/attention/triton_backend.py +71 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
- sglang/srt/layers/layernorm.py +1 -5
- sglang/srt/layers/moe/ep_moe/layer.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
- sglang/srt/layers/moe/topk.py +4 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8_kernel.py +140 -2
- sglang/srt/layers/rotary_embedding.py +1 -3
- sglang/srt/layers/sampler.py +4 -4
- sglang/srt/lora/backend/__init__.py +8 -0
- sglang/srt/lora/backend/base_backend.py +95 -0
- sglang/srt/lora/backend/flashinfer_backend.py +91 -0
- sglang/srt/lora/backend/triton_backend.py +61 -0
- sglang/srt/lora/lora.py +127 -112
- sglang/srt/lora/lora_manager.py +50 -18
- sglang/srt/lora/triton_ops/__init__.py +5 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
- sglang/srt/model_executor/cuda_graph_runner.py +77 -80
- sglang/srt/model_executor/forward_batch_info.py +58 -59
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/server_args.py +13 -2
- sglang/srt/speculative/build_eagle_tree.py +4 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
- sglang/srt/speculative/eagle_utils.py +361 -372
- sglang/srt/speculative/eagle_worker.py +177 -45
- sglang/srt/utils.py +7 -0
- sglang/test/runners.py +2 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +15 -6
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +72 -33
- sglang/srt/layers/custom_op_util.py +0 -25
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/backend/flashinfer_backend.py
ADDED
@@ -0,0 +1,91 @@
+from typing import Tuple
+
+import torch
+
+from sglang.srt.lora.backend import BaseLoraBackend
+from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.utils import is_flashinfer_available
+
+if is_flashinfer_available():
+    from flashinfer import SegmentGEMMWrapper
+
+
+class FlashInferLoraBackend(BaseLoraBackend):
+
+    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+        super().__init__(name, batch_info)
+
+        # Set up SGemm Wrapper from flashinfer
+        # FIXME wait for flashinfer segment gemm update
+        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
+        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
+
+    def run_lora_a_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+
+        return self.segment_gemm.run(
+            x=x,
+            weights=weights,
+            batch_size=self.batch_info.bs,
+            weight_column_major=True,
+            seg_indptr=self.batch_info.seg_indptr,
+            weight_indices=self.batch_info.weight_indices,
+        )
+
+    def run_lora_b_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+
+        return self.segment_gemm.run(
+            x=x,
+            weights=weights,
+            batch_size=self.batch_info.bs,
+            weight_column_major=True,
+            seg_indptr=self.batch_info.seg_indptr,
+            weight_indices=self.batch_info.weight_indices,
+        )
+
+    def run_qkv_lora(
+        self,
+        x: torch.Tensor,
+        qkv_lora_a: torch.Tensor,
+        qkv_lora_b: Tuple[torch.Tensor],
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        # Shape of lora_a_output: (s, 3 * r)
+        lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
+
+        q_lora_b, kv_lora_b = qkv_lora_b
+        lora_rank = kv_lora_b.shape[-1]
+        output_dim_q = q_lora_b.shape[-2]
+        output_dim_kv = kv_lora_b.shape[-2]
+        lora_output = torch.empty(
+            (x.shape[0], output_dim_q + 2 * output_dim_kv),
+            device=x.device,
+            dtype=x.dtype,
+        )
+
+        # q
+        lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
+        )
+
+        # kv
+        lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
+            self.run_lora_b_sgemm(
+                x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
+                weights=kv_lora_b[0],
+            )
+        )
+
+        lora_output[
+            :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
+        ] = self.run_lora_b_sgemm(
+            x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
+            weights=kv_lora_b[1],
+        )
+
+        return lora_output
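The flashinfer backend routes every LoRA matmul through flashinfer's segmented GEMM. As a reading aid, the loop below is a plain-PyTorch sketch of the semantics of such a call: each sequence in the packed batch is multiplied by the adapter weight picked out by weight_indices. The (num_lora, out_dim, in_dim) layout and the transpose are assumptions consistent with weight_column_major=True and the shape comments in the Triton backend below; this is an illustration, not the flashinfer kernel itself.

import torch

def segment_gemm_reference(
    x: torch.Tensor,               # (s, in_dim), tokens of all sequences concatenated
    weights: torch.Tensor,         # (num_lora, out_dim, in_dim), assumed layout
    seg_indptr: torch.Tensor,      # (bs + 1,), token offsets per sequence
    weight_indices: torch.Tensor,  # (bs,), adapter slot per sequence
) -> torch.Tensor:
    out = x.new_empty(x.shape[0], weights.shape[1])
    for i in range(seg_indptr.numel() - 1):
        start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        w = weights[int(weight_indices[i])]        # (out_dim, in_dim)
        out[start:end] = x[start:end] @ w.t()      # per-sequence LoRA matmul
    return out

# Toy batch: two sequences of lengths 3 and 2, two adapters of rank 4.
x = torch.randn(5, 16)
lora_a = torch.randn(2, 4, 16)
seg_indptr = torch.tensor([0, 3, 5], dtype=torch.int32)
weight_indices = torch.tensor([1, 0], dtype=torch.int64)
print(segment_gemm_reference(x, lora_a, seg_indptr, weight_indices).shape)  # torch.Size([5, 4])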
sglang/srt/lora/backend/triton_backend.py
ADDED
@@ -0,0 +1,61 @@
+import torch
+
+from sglang.srt.lora.backend import BaseLoraBackend
+from sglang.srt.lora.lora import LoraBatchInfo
+from sglang.srt.lora.triton_ops import (
+    qkv_lora_b_fwd,
+    sgemm_lora_a_fwd,
+    sgemm_lora_b_fwd,
+)
+
+
+class TritonLoraBackend(BaseLoraBackend):
+
+    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+        super().__init__(name, batch_info)
+
+    def run_lora_a_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+        return sgemm_lora_a_fwd(x, weights, self.batch_info)
+
+    def run_lora_b_sgemm(
+        self,
+        x: torch.Tensor,
+        weights: torch.Tensor,
+        base_output: torch.Tensor = None,
+        scaling: float = 1.0,
+        *args,
+        **kwargs
+    ) -> torch.Tensor:
+        return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)
+
+    def run_qkv_lora(
+        self,
+        x: torch.Tensor,
+        qkv_lora_a: torch.Tensor,
+        qkv_lora_b: torch.Tensor,
+        output_offset: torch.Tensor,
+        max_qkv_out_dim: int,
+        base_output: torch.Tensor = None,
+        scaling: float = 1.0,
+        *args,
+        **kwargs
+    ) -> torch.Tensor:
+
+        # x: (s, input_dim)
+        # qkv_lora_a: (num_lora, 3 * r, input_dim)
+        # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
+        assert isinstance(qkv_lora_b, torch.Tensor)
+
+        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
+        lora_output = qkv_lora_b_fwd(
+            lora_a_output,
+            qkv_lora_b,
+            self.batch_info,
+            output_offset,
+            max_qkv_out_dim,
+            base_output,
+            scaling,
+        )
+        return lora_output
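run_qkv_lora above takes a fused qkv_lora_b tensor together with output_offset and max_qkv_out_dim. The CPU-only sketch below shows how a caller could derive those arguments from separate q and kv lora_b buffers, mirroring what QKVParallelLinearWithLoRA.set_lora_info does in the lora.py diff further down; the dimensions here are made-up example values.

import torch

# Hypothetical dims; real values come from the base QKV projection and adapter rank.
output_dim_q, output_dim_kv, rank, num_lora = 2048, 256, 16, 2

B_buffer_q = torch.randn(1, num_lora, output_dim_q, rank)    # stacked dim c = 1
B_buffer_kv = torch.randn(2, num_lora, output_dim_kv, rank)  # stacked dim c = 2 (k and v)

# Concatenate q/k/v lora_b weights along the output dimension:
# (num_lora, output_dim_q + 2 * output_dim_kv, rank)
B_buffer_qkv = torch.cat(
    (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()

# Offsets of the q/k/v slices in the fused output, plus the widest slice
# (used to size the Triton launch grid).
output_offset = torch.tensor(
    [0, output_dim_q, output_dim_q + output_dim_kv, output_dim_q + 2 * output_dim_kv],
    dtype=torch.int32,
)
max_qkv_out_dim = max(output_dim_q, output_dim_kv)
print(B_buffer_qkv.shape, output_offset.tolist(), max_qkv_out_dim)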
sglang/srt/lora/lora.py
CHANGED
@@ -18,12 +18,11 @@
 # LoRA layers class inheritance adapted from:
 # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py

-
 import re
+from dataclasses import dataclass

 import torch
 from torch import nn
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding

 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -31,17 +30,36 @@ from sglang.srt.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_loader.loader import DefaultModelLoader


+@dataclass
+class LoraBatchInfo:
+    # Batch size
+    bs: int
+
+    # Lengths of each sequence in shape (bs,)
+    seg_lens: torch.Tensor
+
+    # Indice pointers of each sequence in shape (bs + 1, )
+    seg_indptr: torch.Tensor
+
+    # Maximum sequence length of current batch
+    max_len: int
+
+    # The index of lora adapter used by each sequence, in shape (bs,)
+    weight_indices: torch.Tensor
+
+
 class BaseLayerWithLoRA(nn.Module):
-    def __init__(self, base_layer,
+    def __init__(self, base_layer, lora_rank, scaling, lora_backend):
         super().__init__()
         self.base_layer = base_layer
-        self.segment_gemm = segment_gemm
         self.lora_rank = lora_rank
         self.scaling = scaling
         self.set_lora = False
+        self.lora_backend = lora_backend

     def forward(self, x: torch.Tensor):
         return self.base_layer.forward(x)
@@ -52,17 +70,17 @@ class BaseLayerWithLoRA(nn.Module):

 class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     def __init__(
-        self, base_layer: VocabParallelEmbedding,
+        self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)
         self.weight = base_layer.weight


 class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
-        self, base_layer: ColumnParallelLinear,
+        self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)

     def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         # TODO
@@ -88,136 +106,127 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):

 class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
-        self, base_layer: MergedColumnParallelLinear,
+        self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)

-    def set_lora_info(
+    def set_lora_info(
+        self,
+        A_buffer,
+        B_buffer,
+    ):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
-        self.bs = bs
-        self.seg_indptr = seg_indptr
-        self.weight_indices = weight_indices

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_a_output = self.
-
-
-            batch_size=self.bs,
-            weight_column_major=True,
-            seg_indptr=self.seg_indptr,
-            weight_indices=self.weight_indices,
-        )
-        # FIXME
+        lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
+
+        output_dim = base_output.shape[-1]
         lora_output = torch.empty_like(base_output)
-        output_dim =
-
-
-
-
-
-
-            ].contiguous(),
-            weights=self.B_buffer[
-            batch_size=self.bs,
-            weight_column_major=True,
-            seg_indptr=self.seg_indptr,
-            weight_indices=self.weight_indices,
+        lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
+            x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
+            weights=self.B_buffer[0],
+        )
+
+        lora_output[:, output_dim : 2 * output_dim] = (
+            self.lora_backend.run_lora_b_sgemm(
+                x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
+                weights=self.B_buffer[1],
             )
+        )
+
         return base_output + lora_output * self.scaling


 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def
-        self, base_layer: QKVParallelLinear,
+    def __init__(
+        self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)

     def set_lora_info(
-        self,
+        self,
+        A_buffer_qkv,
+        B_buffer_q,
+        B_buffer_kv,
     ):
         self.set_lora = True
         self.A_buffer_qkv = A_buffer_qkv
-
-        self.
-
-
-
+
+        if self.lora_backend.fuse_qkv_lora_b:
+            assert (
+                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
+            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
+            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
+
+            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
+            self.B_buffer_qkv = torch.cat(
+                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
+            ).contiguous()
+
+            # Offsets of q/k/v in output dimension
+            self.output_offset = torch.tensor(
+                [
+                    0,
+                    output_dim_q,
+                    output_dim_q + output_dim_kv,
+                    output_dim_q + 2 * output_dim_kv,
+                ],
+                dtype=torch.int32,
+                device=B_buffer_q.device,
+            )
+            # For computing number of launched blocks
+            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
+        else:
+            self.B_buffer_qkv = (
+                B_buffer_q,
+                B_buffer_kv,
+            )
+            self.output_offset = None
+            self.max_qkv_out_dim = None

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-
-        x
-
-
-
-
-
+        lora_output = self.lora_backend.run_qkv_lora(
+            x,
+            self.A_buffer_qkv,
+            self.B_buffer_qkv,
+            output_offset=self.output_offset,
+            max_qkv_out_dim=self.max_qkv_out_dim,
+            base_output=base_output,
+            scaling=self.scaling,
         )
-
-
-
-
-        lora_output[:, :output_dim_q] = self.segment_gemm.run(
-            x=lora_a_output[:, : self.lora_rank].contiguous(),
-            weights=self.B_buffer_q,
-            batch_size=self.bs,
-            weight_column_major=True,
-            seg_indptr=self.seg_indptr,
-            weight_indices=self.weight_indices,
+        return (
+            lora_output
+            if self.lora_backend.fuse_output_scaling_add
+            else base_output + lora_output * self.scaling
         )
-        # kv
-        output_dim_kv = self.B_buffer_kv.shape[-2] // 2
-        for i in range(2):
-            left = output_dim_kv * i
-            right = left + output_dim_kv
-            lora_output[:, output_dim_q + left : output_dim_q + right] = (
-                self.segment_gemm.run(
-                    x=lora_a_output[
-                        :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
-                    ].contiguous(),
-                    weights=self.B_buffer_kv[:, left:right, :].contiguous(),
-                    batch_size=self.bs,
-                    weight_column_major=True,
-                    seg_indptr=self.seg_indptr,
-                    weight_indices=self.weight_indices,
-                )
-            )
-        return base_output + lora_output * self.scaling


 class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
-        self, base_layer: RowParallelLinear,
+        self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)

-    def set_lora_info(self, A_buffer, B_buffer
+    def set_lora_info(self, A_buffer, B_buffer):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
-        self.bs = bs
-        self.seg_indptr = seg_indptr
-        self.weight_indices = weight_indices

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-
-
-
-
-
-
-            weight_indices=self.weight_indices,
+        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
+        lora_output = self.lora_backend.run_lora_b_sgemm(
+            lora_a_output,
+            self.B_buffer[0],
+            base_output=base_output,
+            scaling=self.scaling,
         )
-
-
-
-
-            weight_column_major=True,
-            seg_indptr=self.seg_indptr,
-            weight_indices=self.weight_indices,
+        return (
+            lora_output
+            if self.lora_backend.fuse_output_scaling_add
+            else base_output + lora_output * self.scaling
         )
-        return base_output + lora_output * self.scaling

     def forward(self, input_):
         # duplicate the logic in RowParallelLinear
@@ -255,7 +264,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):


 def get_lora_layer(
-    layer: nn.Module,
+    layer: nn.Module, lora_rank, scaling, lora_backend
 ) -> BaseLayerWithLoRA:
     supported_layer_types = {
         # the order matters
@@ -267,7 +276,7 @@ def get_lora_layer(
     }
     for src_layer_type, lora_layer_type in supported_layer_types.items():
         if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer,
+            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
             return ret
     raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")

@@ -297,13 +306,14 @@ class LoRALayer(nn.Module):


 class LoRAAdapter(nn.Module):
-    def __init__(self, uid, config, base_hf_config, load_config):
+    def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
         super().__init__()
         self.uid = uid
         self.config = config
         assert self.config.hf_config["peft_type"].lower() == "lora"
         self.base_hf_config = base_hf_config
         self.load_config = load_config
+        self.lora_backend = lora_backend
         self.scaling = self.config.lora_alpha / self.config.r

         self.layers = nn.ModuleList(
@@ -376,20 +386,25 @@ class LoRAAdapter(nn.Module):
                         layer.weights.pop(weight_name)
                         layer.weights.pop(v_name)
                     else:
-                        layer.weights[kv_name] = torch.
-
+                        layer.weights[kv_name] = torch.stack(
+                            [
                                 layer.weights[weight_name],
                                 layer.weights[v_name],
-
-                            0,
+                            ],
+                            dim=0,
                         )
                         layer.weights.pop(weight_name)
                         layer.weights.pop(v_name)
                 elif "gate_proj" in weight_name:
                     up_name = weight_name.replace("gate_proj", "up_proj")
                     gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
-
-
-
+                    if "lora_A" in weight_name:
+                        layer.weights[gate_up_name] = torch.cat(
+                            (layer.weights[weight_name], layer.weights[up_name]), 0
+                        )
+                    else:
+                        layer.weights[gate_up_name] = torch.stack(
+                            [layer.weights[weight_name], layer.weights[up_name]], dim=0
+                        )
                     layer.weights.pop(weight_name)
                     layer.weights.pop(up_name)
sglang/srt/lora/lora_manager.py
CHANGED
@@ -20,16 +20,14 @@ import re

 import torch

-from sglang.srt.lora.
+from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend
+from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available, replace_submodule

 logger = logging.getLogger(__name__)

-if is_flashinfer_available():
-    from flashinfer import SegmentGEMMWrapper
-

 def get_module_name(name):
     # Fallback solution of mapping from config module name to module name in model class.
@@ -77,6 +75,20 @@ def get_stacked_name(name):
     return params_mapping.get(name, (name, name))


+def get_backend_from_name(name):
+    backend_mapping = {
+        "triton": TritonLoraBackend,
+        "flashinfer": FlashInferLoraBackend,
+    }
+
+    if name in backend_mapping:
+        return backend_mapping[name]
+
+    raise Exception(
+        f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
+    )
+
+
 def get_layer_id(name):
     match = re.search(r"layers\.(\d+)\.", name)
     if match is None:
@@ -93,6 +105,7 @@ class LoRAManager:
         max_loras_per_batch,
         load_config,
         dtype,
+        lora_backend,
     ):
         self.base_model = base_model
         self.lora_paths = lora_paths
@@ -101,8 +114,9 @@ class LoRAManager:
         self.load_config = load_config
         self.dtype = dtype

-
-
+        logger.info(f"Using {lora_backend} as backend of Lora kernels.")
+        backend_type = get_backend_from_name(lora_backend)
+        self.lora_backend = backend_type(lora_backend)

         self.init_loras()
         self.init_lora_memory_pool()
@@ -123,7 +137,7 @@ class LoRAManager:

     def set_lora_module(self, module_name, module):
         lora_module = get_lora_layer(
-            module, self.
+            module, self.max_lora_dim, self.scaling, self.lora_backend
         )
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
@@ -162,7 +176,11 @@ class LoRAManager:
             self.lora_id[name] = len(self.loras)
             self.loras.append(
                 LoRAAdapter(
-                    name,
+                    name,
+                    self.configs[name],
+                    self.base_hf_config,
+                    self.load_config,
+                    self.lora_backend,
                 )
             )
             self.loras[-1].initialize_weights()
@@ -226,8 +244,9 @@ class LoRAManager:
             self.B_buffer[module_B] = [
                 torch.empty(
                     (
+                        c,
                         self.max_loras_per_batch,
-                        hidden_dim_B
+                        hidden_dim_B,
                         self.max_lora_dim,
                     ),
                     dtype=self.dtype,
@@ -263,7 +282,16 @@ class LoRAManager:
             else:
                 lora_weight_name = self.get_weight_name(name, 1)
                 if lora_weight_name:
-                    self.
+                    c = self.loras[-1].get_stacked_multiply(lora_weight_name)
+                    if c > 1:
+                        for j in range(c):
+                            self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
+                                weights[j]
+                            )
+                    else:
+                        self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
+                            weights
+                        )

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
@@ -292,20 +320,30 @@ class LoRAManager:
         if cur_uids == set([None]):
             return

-        #
+        # set up batch info shared by all lora moruldes
         bs = forward_batch.batch_size
         seg_lens = (
             forward_batch.extend_seq_lens
             if forward_batch.forward_mode.is_extend()
             else torch.ones(bs, device="cuda")
         )
-        # FIXME: reuse the data rather than recompute
         seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+        max_len = int(torch.max(seg_lens))
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.buffer_id[lora_path]

+        batch_info = LoraBatchInfo(
+            bs=bs,
+            seg_lens=seg_lens,
+            seg_indptr=seg_indptr,
+            max_len=max_len,
+            weight_indices=weight_indices,
+        )
+        self.lora_backend.set_batch_info(batch_info)
+
+        # call set_lora_info for each lora modules
         for module_name, module in self.lora_modules:
             layer_id = get_layer_id(module_name)

@@ -314,16 +352,10 @@ class LoRAManager:
                 module.set_lora_info(
                     self.A_buffer[weight_name][layer_id],
                     self.B_buffer[weight_name][layer_id],
-                    bs,
-                    seg_indptr,
-                    weight_indices,
                 )
             else:
                 module.set_lora_info(
                     self.A_buffer["qkv_proj"][layer_id],
                     self.B_buffer["q_proj"][layer_id],
                     self.B_buffer["kv_proj"][layer_id],
-                    bs,
-                    seg_indptr,
-                    weight_indices,
                 )