sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +23 -3
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +5 -16
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +218 -79
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/topk.py +30 -3
- sglang/srt/layers/quantization/__init__.py +134 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +12 -0
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/schedule_batch.py +1 -0
- sglang/srt/managers/scheduler.py +25 -19
- sglang/srt/managers/tokenizer_manager.py +0 -1
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -8
- sglang/srt/model_executor/model_runner.py +9 -6
- sglang/srt/model_loader/loader.py +11 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +151 -26
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +6 -0
- sglang/srt/openai_api/adapter.py +88 -87
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/server_args.py +21 -11
- sglang/srt/speculative/eagle_worker.py +1 -1
- sglang/srt/utils.py +33 -0
- sglang/test/runners.py +27 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +8 -4
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +57 -53
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py CHANGED
@@ -24,6 +24,7 @@ import triton.language as tl
 
 from sglang.srt.utils import (
     direct_register_custom_op,
+    get_bool_env_var,
     get_device_core_count,
     get_device_name,
     get_device_sm,
@@ -43,7 +44,7 @@ if _is_cuda:
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
 
     sm_version = get_device_sm()
-    if sm_version >= 90 and
+    if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
        _enable_jit_deepgemm = True
 
 
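With this change, JIT DeepGEMM on SM90+ GPUs is controlled by the `SGL_ENABLE_JIT_DEEPGEMM` environment variable and is enabled by default. A minimal sketch of how a boolean env-var helper of this kind typically behaves (the real `sglang.srt.utils.get_bool_env_var` may accept different spellings; this is an assumption, not its implementation):

import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Treat common truthy spellings as True; everything else as False.
    value = os.getenv(name, default).strip().lower()
    return value in ("1", "true", "yes", "on")

# Usage mirroring the diff: enabled by default, disable with SGL_ENABLE_JIT_DEEPGEMM=0
enable_jit_deepgemm = get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")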
sglang/srt/layers/quantization/gptq.py CHANGED
@@ -11,12 +11,29 @@ from sglang.srt.utils import is_cuda
 _is_cuda = is_cuda()
 
 try:
-    import
+    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+        check_marlin_supported,
+    )
+    from vllm.scalar_type import scalar_types
 
     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False
 
+    GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
+
+    class scalar_types:
+        uint4b8 = "uint4b8"
+        uint8b128 = "uint8b128"
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -117,12 +134,8 @@ class GPTQConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
-
-            raise ImportError("vllm is not installed")
-
-        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-
+    ) -> Optional[GPTQLinearMethod]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.quantization import get_linear_quant_method
 
         return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
@@ -131,16 +144,11 @@ class GPTQConfig(QuantizationConfig):
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
-
-
-
-
-
-            (4, True): scalar_types.uint4b8,
-            (8, True): scalar_types.uint8b128,
-        }
-    else:
-        raise ImportError("vllm is not installed")
+    # (num_bits, is_sym) -> quant_type
+    TYPE_MAP = {
+        (4, True): scalar_types.uint4b8,
+        (8, True): scalar_types.uint8b128,
+    }
 
     def __init__(
         self,
@@ -197,6 +205,7 @@ class GPTQMarlinConfig(QuantizationConfig):
                 "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
             )
 
+        # (num_bits, is_sym) -> quant_type
         self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
 
     def __repr__(self) -> str:
@@ -278,15 +287,8 @@ class GPTQMarlinConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
-
-            raise ImportError("vllm is not installed")
-
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinLinearMethod,
-            GPTQMarlinMoEMethod,
-        )
-
+    ) -> Optional[QuantizeMethodBase]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.quantization import get_linear_quant_method
 
@@ -304,19 +306,12 @@ class GPTQMarlinConfig(QuantizationConfig):
 
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
-        if not VLLM_AVAILABLE:
-            return False
-
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
         sym = quant_config.get("sym")
         desc_act = quant_config.get("desc_act")
 
-        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-            check_marlin_supported,
-        )
-
         if not _is_cuda:
             return False
 
@@ -427,13 +422,8 @@ class MarlinConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
-
-            raise ImportError("vllm is not installed")
-
-        from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
-
-        # Delay import to avoid circular dependency
+    ) -> Optional[MarlinLinearMethod]:
+        # Delay the import to avoid circular dependency
         from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 
         if isinstance(layer, LinearBase) or (
sglang/srt/layers/quantization/w8a8_fp8.py CHANGED
@@ -37,7 +37,7 @@ class W8A8Fp8Config(QuantizationConfig):
     Note:
     - For models without offline quantization, weights will be quantized during model loading
     - If CUTLASS is supported: Per-channel weight quantization is used
-    - If CUTLASS is not supported: Falls back to per-
+    - If CUTLASS is not supported: Falls back to per-tensor weight quantization
     """
 
     def __init__(self, is_checkpoint_fp8_serialized: bool = False):
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -651,6 +651,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if _is_cuda_available:
+            return self.forward_cuda(positions, query, key, offsets)
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
+    def forward_native(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
         query_rot = query[..., : self.rotary_dim]
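The new `forward` keeps the previous body as `forward_native` and dispatches to a CUDA path when one is available. A compact sketch of that dispatch pattern, under the assumption that both paths return the same `(query, key)` pair (illustrative class, not the sglang one):

import torch

class RotaryDispatchExample(torch.nn.Module):
    """Illustrative only; shows the CUDA-vs-native dispatch shape."""

    def forward(self, positions, query, key, offsets=None):
        # Prefer the optimized CUDA path when tensors live on a GPU,
        # otherwise fall back to the pure-PyTorch reference path.
        if torch.cuda.is_available() and query.is_cuda:
            return self.forward_cuda(positions, query, key, offsets)
        return self.forward_native(positions, query, key, offsets)

    def forward_cuda(self, positions, query, key, offsets=None):
        # Stand-in for a fused kernel; here it just reuses the native path.
        return self.forward_native(positions, query, key, offsets)

    def forward_native(self, positions, query, key, offsets=None):
        # Reference implementation placeholder.
        return query, key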
sglang/srt/lora/backend/base_backend.py CHANGED
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.lora.utils import LoRABatchInfo
 
 
-def
+def get_fuse_output_add_from_name(name: str) -> bool:
     mapping = {
         "triton": True,
         "flashinfer": False,
@@ -28,14 +28,14 @@ class BaseLoRABackend:
     Args:
         name: name of backend
         batch_info: information of current batch for use
-
-                     and the operation of
+        fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
+                         and the operation of adding will be fused into kernel
     """
 
     def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         self.name = name
         self.batch_info = batch_info
-        self.
+        self.fuse_output_add = get_fuse_output_add_from_name(name)
         self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
 
     def run_lora_a_sgemm(
sglang/srt/lora/backend/flashinfer_backend.py CHANGED
@@ -37,13 +37,16 @@ class FlashInferLoRABackend(BaseLoRABackend):
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
     ) -> torch.Tensor:
 
-        return
-
-
-
-
-
-
+        return (
+            self.segment_gemm.run(
+                x=x,
+                weights=weights,
+                batch_size=self.batch_info.bs,
+                weight_column_major=True,
+                seg_indptr=self.batch_info.seg_indptr,
+                weight_indices=self.batch_info.weight_indices,
+            )
+            * self.batch_info.scalings[0]
         )
 
     def run_qkv_lora(
@@ -90,7 +93,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
             weights=kv_lora_b[1],
         )
 
-        return lora_output
+        return lora_output * self.batch_info.scalings[0]
 
     def run_gate_up_lora(
         self,
@@ -125,4 +128,4 @@ class FlashInferLoRABackend(BaseLoRABackend):
             weights=gate_up_lora_b[1],
         )
 
-        return lora_output
+        return lora_output * self.batch_info.scalings[0]
sglang/srt/lora/backend/triton_backend.py CHANGED
@@ -25,11 +25,10 @@ class TritonLoRABackend(BaseLoRABackend):
         x: torch.Tensor,
         weights: torch.Tensor,
         base_output: torch.Tensor = None,
-        scaling: float = 1.0,
         *args,
         **kwargs
     ) -> torch.Tensor:
-        return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output
+        return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)
 
     def run_qkv_lora(
         self,
@@ -39,7 +38,6 @@ class TritonLoRABackend(BaseLoRABackend):
         output_offset: torch.Tensor,
         max_qkv_out_dim: int,
         base_output: torch.Tensor = None,
-        scaling: float = 1.0,
         *args,
         **kwargs
     ) -> torch.Tensor:
@@ -49,7 +47,7 @@ class TritonLoRABackend(BaseLoRABackend):
         # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
         assert isinstance(qkv_lora_b, torch.Tensor)
 
-        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
+        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info, stack_num=3)
         lora_output = qkv_lora_b_fwd(
             lora_a_output,
             qkv_lora_b,
@@ -57,7 +55,6 @@ class TritonLoRABackend(BaseLoRABackend):
             output_offset,
             max_qkv_out_dim,
             base_output,
-            scaling,
         )
         return lora_output
 
@@ -67,7 +64,6 @@ class TritonLoRABackend(BaseLoRABackend):
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: torch.Tensor,
         base_output: torch.Tensor = None,
-        scaling: float = 1.0,
         *args,
         **kwargs
     ) -> torch.Tensor:
@@ -79,13 +75,14 @@ class TritonLoRABackend(BaseLoRABackend):
         output_dim = gate_up_lora_b.shape[-2] // 2
 
         # lora_a_output: (s, 2 * r)
-        lora_a_output = sgemm_lora_a_fwd(
+        lora_a_output = sgemm_lora_a_fwd(
+            x, gate_up_lora_a, self.batch_info, stack_num=2
+        )
         lora_output = gate_up_lora_b_fwd(
             lora_a_output,
             gate_up_lora_b,
             self.batch_info,
             output_dim,
             base_output,
-            scaling,
         )
         return lora_output
sglang/srt/lora/layers.py CHANGED
@@ -23,14 +23,10 @@ class BaseLayerWithLoRA(nn.Module):
     def __init__(
         self,
         base_layer: nn.Module,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ):
         super().__init__()
         self.base_layer: nn.Module = base_layer
-        self.lora_rank: int = lora_rank
-        self.scaling: float = scaling
         self.set_lora: bool = False
         self.lora_backend: BaseLoRABackend = lora_backend
 
@@ -59,11 +55,9 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: VocabParallelEmbedding,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
         self.weight = base_layer.weight
 
 
@@ -71,11 +65,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: ColumnParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -87,7 +79,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output
+        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             lora_a_output,
@@ -96,8 +88,8 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.
-            else base_output + lora_output
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def forward(self, input_: torch.Tensor):
@@ -132,11 +124,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
         self,
         base_layer: MergedColumnParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -155,7 +145,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output
+        backend_kwargs = {"base_output": base_output}
 
         lora_output = self.lora_backend.run_gate_up_lora(
             x,
@@ -165,8 +155,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.
-            else base_output + lora_output
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
@@ -184,11 +174,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
         self,
         base_layer: QKVParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -230,7 +218,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output
+        backend_kwargs = {"base_output": base_output}
         if self.lora_backend.fuse_stacked_lora_b:
             backend_kwargs["output_offset"] = self.output_offset
             backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
@@ -243,8 +231,8 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.
-            else base_output + lora_output
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
@@ -273,11 +261,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: RowParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
         self.set_lora = True
@@ -285,7 +271,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output
+        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             lora_a_output,
@@ -294,8 +280,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.
-            else base_output + lora_output
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def forward(self, input_: torch.Tensor):
@@ -344,7 +330,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
 
 
 def get_lora_layer(
-    layer: nn.Module,
+    layer: nn.Module, lora_backend: BaseLoRABackend
 ) -> BaseLayerWithLoRA:
     supported_layer_types = {
         # the order matters
@@ -356,6 +342,6 @@ def get_lora_layer(
     }
     for src_layer_type, lora_layer_type in supported_layer_types.items():
         if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer,
+            ret = lora_layer_type(layer, lora_backend)
             return ret
     raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
sglang/srt/lora/lora_manager.py CHANGED
@@ -103,11 +103,14 @@ class LoRAManager:
             self.loras[name] = lora_adapter
 
         # misc lora configs
-        # FIXME remove the restrictions after implementing unified paging
         self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-
-
-
+
+        if self.lora_backend == "flashinfer":
+            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
+            scaling = list(self.loras.values())[0].scaling
+            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
+            assert all(x.scaling == scaling for x in self.loras.values())
 
         # Convert original model layers to layers with LoRA
         self.convert_to_lora_layers()
@@ -148,8 +151,18 @@ class LoRAManager:
         seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         max_len = int(torch.max(seg_lens))
         weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+        lora_ranks = torch.empty(
+            (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+        )
+        scalings = torch.empty(
+            (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+        )
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+            lora = self.loras[lora_path]
+            lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+            scalings[weight_indices[i]] = lora.scaling
 
         batch_info = LoRABatchInfo(
             bs=bs,
@@ -157,6 +170,8 @@ class LoRAManager:
             seg_indptr=seg_indptr,
             max_len=max_len,
             weight_indices=weight_indices,
+            lora_ranks=lora_ranks,
+            scalings=scalings,
         )
         self.lora_backend.set_batch_info(batch_info)
 
@@ -189,9 +204,7 @@ class LoRAManager:
         )
 
     def set_lora_module(self, module_name, module):
-        lora_module = get_lora_layer(
-            module, self.max_lora_dim, self.scaling, self.lora_backend
-        )
+        lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
sglang/srt/lora/mem_pool.py CHANGED
@@ -163,10 +163,11 @@ class LoRAMemoryPool:
         if uid is None:
             for i in range(self.num_layer):
                 for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id]
+                    self.A_buffer[k][i][buffer_id] = 0
             return
 
         assert lora_adapter is not None
+        lora_rank = lora_adapter.config.hf_config["r"]
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, torch.Tensor] = {}
@@ -208,17 +209,22 @@ class LoRAMemoryPool:
             )
 
             for name, weights in temp_A_buffer.items():
-
+                c = get_stacked_multiply(name)
+                self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
+                    weights
+                )
 
             for name, weights in temp_B_buffer.items():
                 c = get_stacked_multiply(name)
                 if c > 1:
                     for stacked_id in range(c):
-                        self.B_buffer[name][layer_id][stacked_id][buffer_id]
-
-                        )
+                        self.B_buffer[name][layer_id][stacked_id][buffer_id][
+                            :, :lora_rank
+                        ].copy_(weights[stacked_id])
                 else:
-                    self.B_buffer[name][layer_id][0][buffer_id].copy_(
+                    self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
+                        weights
+                    )
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
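The memory-pool change lets adapters of different ranks share the same preallocated buffers: each adapter's weights are copied into the leading `lora_rank` slice of its slot, and the kernels later read only that slice. A small sketch of the idea with made-up shapes (not the sglang pool itself):

import torch

max_rank, hidden = 16, 64
num_slots = 4

# Pool buffer sized for the largest supported rank.
a_buffer = torch.zeros(num_slots, max_rank, hidden)

def load_adapter(buffer_id: int, adapter_a: torch.Tensor) -> None:
    # adapter_a: (rank, hidden) with rank <= max_rank
    rank = adapter_a.shape[0]
    a_buffer[buffer_id].zero_()                      # clear any stale weights in the slot
    a_buffer[buffer_id][:rank, :].copy_(adapter_a)   # fill only the leading `rank` rows

load_adapter(0, torch.randn(8, hidden))   # a rank-8 adapter
load_adapter(1, torch.randn(16, hidden))  # a rank-16 adapter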
sglang/srt/lora/triton_ops/gate_up_lora_b.py CHANGED
@@ -22,17 +22,18 @@ def _gate_up_lora_b_kernel(
     w_stride_2,
     output_stride_0,
     output_stride_1,
-    # Information on sequence lengths and weight id
+    # Information on sequence lengths,ranks and weight id
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-
+    scalings,
 ):
     # This kernel packs 2 sgemms (gate/up) into a single kernel.
 
@@ -51,6 +52,11 @@ def _gate_up_lora_b_kernel(
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = gate_up_id * output_dim  # offset on output dim
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(output_dim, BLOCK_N)
@@ -109,7 +115,6 @@ def gate_up_lora_b_fwd(
     batch_info: LoRABatchInfo,
     output_dim: int,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
 
     # x: (s, 2 * r)
@@ -160,11 +165,12 @@ def gate_up_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
         fuse_scaling_add,
-
+        batch_info.scalings,
     )
 
     return output
sglang/srt/lora/triton_ops/qkv_lora_b.py CHANGED
@@ -26,6 +26,7 @@ def _qkv_lora_b_kernel(
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Offsets of q/k/v slice on output dimension
     n_offs,
     # Meta parameters
@@ -34,7 +35,7 @@ def _qkv_lora_b_kernel(
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-
+    scalings,
 ):
     # This kernel packs 3 sgemms (q/k/v) into a single kernel.
 
@@ -54,6 +55,10 @@ def _qkv_lora_b_kernel(
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = tl.load(n_offs + qkv_id)
     n_size = tl.load(n_offs + qkv_id + 1) - n_start
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
@@ -112,7 +117,6 @@ def qkv_lora_b_fwd(
     output_offset: torch.Tensor,
     max_qkv_out_dim: int,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
 
     # x: (s, 3 * r)
@@ -171,12 +175,13 @@ def qkv_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         output_offset,
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
         fuse_scaling_add,
-
+        batch_info.scalings,
     )
 
     return output
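Taken together, the Triton kernels now look up each sequence's adapter rank and scaling from `batch_info.lora_ranks` and `batch_info.scalings` (indexed by `weight_indices`), clamp the reduction dimension to that rank, and apply the scaling inside the kernel instead of taking a single Python-level `scaling` argument. A plain-PyTorch reference of the same per-segment computation, with illustrative shapes and names (not the sglang kernels):

import torch

def lora_b_reference(
    x: torch.Tensor,               # (total_tokens, max_rank) output of the lora_a GEMM
    b_weights: torch.Tensor,       # (num_slots, output_dim, max_rank) pool of lora_b weights
    seg_indptr: torch.Tensor,      # (bs + 1,) token offsets per sequence
    weight_indices: torch.Tensor,  # (bs,) adapter slot per sequence
    lora_ranks: torch.Tensor,      # (num_slots,) actual rank per adapter
    scalings: torch.Tensor,        # (num_slots,) lora_alpha / rank per adapter
    base_output: torch.Tensor,     # (total_tokens, output_dim) to accumulate into
) -> torch.Tensor:
    out = base_output.clone()
    for i in range(weight_indices.numel()):
        s, e = seg_indptr[i].item(), seg_indptr[i + 1].item()
        slot = weight_indices[i].item()
        r = int(lora_ranks[slot])      # only the leading r columns are valid for this adapter
        w = b_weights[slot, :, :r]     # (output_dim, r)
        out[s:e] += scalings[slot] * (x[s:e, :r] @ w.t())
    return out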