sglang 0.4.2__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in the supported public registries, and is provided for informational purposes only.
Files changed (85)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/layers/activation.py +10 -5
  5. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  6. sglang/srt/layers/attention/triton_backend.py +71 -7
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  8. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  9. sglang/srt/layers/attention/vision.py +243 -40
  10. sglang/srt/layers/layernorm.py +1 -5
  11. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  22. sglang/srt/layers/moe/topk.py +4 -0
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/fp8.py +7 -0
  46. sglang/srt/layers/quantization/fp8_kernel.py +140 -2
  47. sglang/srt/layers/rotary_embedding.py +29 -15
  48. sglang/srt/layers/sampler.py +9 -6
  49. sglang/srt/lora/backend/__init__.py +8 -0
  50. sglang/srt/lora/backend/base_backend.py +95 -0
  51. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  52. sglang/srt/lora/backend/triton_backend.py +61 -0
  53. sglang/srt/lora/lora.py +127 -112
  54. sglang/srt/lora/lora_manager.py +50 -18
  55. sglang/srt/lora/triton_ops/__init__.py +5 -0
  56. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  57. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  59. sglang/srt/managers/image_processor.py +77 -38
  60. sglang/srt/managers/scheduler.py +17 -3
  61. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  62. sglang/srt/mem_cache/chunk_cache.py +3 -0
  63. sglang/srt/mem_cache/radix_cache.py +30 -1
  64. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  65. sglang/srt/model_executor/forward_batch_info.py +58 -59
  66. sglang/srt/model_executor/model_runner.py +2 -2
  67. sglang/srt/models/minicpmv.py +129 -76
  68. sglang/srt/models/mllama.py +16 -56
  69. sglang/srt/models/qwen2.py +4 -1
  70. sglang/srt/models/qwen2_vl.py +19 -9
  71. sglang/srt/server_args.py +19 -2
  72. sglang/srt/speculative/build_eagle_tree.py +4 -2
  73. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  74. sglang/srt/speculative/eagle_utils.py +361 -372
  75. sglang/srt/speculative/eagle_worker.py +177 -45
  76. sglang/srt/utils.py +7 -2
  77. sglang/test/runners.py +2 -0
  78. sglang/utils.py +42 -0
  79. sglang/version.py +1 -1
  80. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +16 -7
  81. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +84 -45
  82. sglang/srt/layers/custom_op_util.py +0 -25
  83. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
  84. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
  85. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/backend/triton_backend.py ADDED
@@ -0,0 +1,61 @@
+ import torch
+
+ from sglang.srt.lora.backend import BaseLoraBackend
+ from sglang.srt.lora.lora import LoraBatchInfo
+ from sglang.srt.lora.triton_ops import (
+     qkv_lora_b_fwd,
+     sgemm_lora_a_fwd,
+     sgemm_lora_b_fwd,
+ )
+
+
+ class TritonLoraBackend(BaseLoraBackend):
+
+     def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+         super().__init__(name, batch_info)
+
+     def run_lora_a_sgemm(
+         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+     ) -> torch.Tensor:
+         return sgemm_lora_a_fwd(x, weights, self.batch_info)
+
+     def run_lora_b_sgemm(
+         self,
+         x: torch.Tensor,
+         weights: torch.Tensor,
+         base_output: torch.Tensor = None,
+         scaling: float = 1.0,
+         *args,
+         **kwargs
+     ) -> torch.Tensor:
+         return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)
+
+     def run_qkv_lora(
+         self,
+         x: torch.Tensor,
+         qkv_lora_a: torch.Tensor,
+         qkv_lora_b: torch.Tensor,
+         output_offset: torch.Tensor,
+         max_qkv_out_dim: int,
+         base_output: torch.Tensor = None,
+         scaling: float = 1.0,
+         *args,
+         **kwargs
+     ) -> torch.Tensor:
+
+         # x: (s, input_dim)
+         # qkv_lora_a: (num_lora, 3 * r, input_dim)
+         # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
+         assert isinstance(qkv_lora_b, torch.Tensor)
+
+         lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
+         lora_output = qkv_lora_b_fwd(
+             lora_a_output,
+             qkv_lora_b,
+             self.batch_info,
+             output_offset,
+             max_qkv_out_dim,
+             base_output,
+             scaling,
+         )
+         return lora_output
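The Triton backend above is a thin adapter: each method forwards the shared LoraBatchInfo plus the weight buffers to one of the new Triton kernels. Below is a minimal, hedged sketch of driving it directly, outside of LoRAManager; the tensor shapes follow the comments in this file, but the sizes, dtypes, and the standalone setup are illustrative assumptions, and running it needs a CUDA device with the Triton kernels from this release installed.

# Sketch only: exercises TritonLoraBackend by hand with made-up sizes.
import torch

from sglang.srt.lora.backend import TritonLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo

s, input_dim, output_dim, r, num_lora = 6, 64, 64, 8, 2

# Two sequences of lengths 2 and 4, each mapped to one adapter slot.
seg_lens = torch.tensor([2, 4], dtype=torch.int32, device="cuda")
seg_indptr = torch.zeros(3, dtype=torch.int32, device="cuda")
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
batch_info = LoraBatchInfo(
    bs=2,
    seg_lens=seg_lens,
    seg_indptr=seg_indptr,
    max_len=4,
    weight_indices=torch.tensor([0, 1], dtype=torch.int64, device="cuda"),
)

backend = TritonLoraBackend("triton")
backend.set_batch_info(batch_info)  # same call LoRAManager.prepare_lora_batch makes

x = torch.randn(s, input_dim, dtype=torch.float16, device="cuda")
lora_a = torch.randn(num_lora, r, input_dim, dtype=torch.float16, device="cuda")   # (num_lora, r, input_dim)
lora_b = torch.randn(num_lora, output_dim, r, dtype=torch.float16, device="cuda")  # (num_lora, output_dim, r)
base_out = torch.zeros(s, output_dim, dtype=torch.float16, device="cuda")

# A then B, mirroring RowParallelLinearWithLoRA.apply_lora in lora.py below.
a_out = backend.run_lora_a_sgemm(x, lora_a)  # (s, r)
out = backend.run_lora_b_sgemm(a_out, lora_b, base_output=base_out, scaling=1.0)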
sglang/srt/lora/lora.py CHANGED
@@ -18,12 +18,11 @@
  # LoRA layers class inheritance adapted from:
  # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py

-
  import re
+ from dataclasses import dataclass

  import torch
  from torch import nn
- from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding

  from sglang.srt.layers.linear import (
      ColumnParallelLinear,
@@ -31,17 +30,36 @@ from sglang.srt.layers.linear import (
      QKVParallelLinear,
      RowParallelLinear,
  )
+ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_loader.loader import DefaultModelLoader


+ @dataclass
+ class LoraBatchInfo:
+     # Batch size
+     bs: int
+
+     # Lengths of each sequence in shape (bs,)
+     seg_lens: torch.Tensor
+
+     # Indice pointers of each sequence in shape (bs + 1, )
+     seg_indptr: torch.Tensor
+
+     # Maximum sequence length of current batch
+     max_len: int
+
+     # The index of lora adapter used by each sequence, in shape (bs,)
+     weight_indices: torch.Tensor
+
+
  class BaseLayerWithLoRA(nn.Module):
-     def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
+     def __init__(self, base_layer, lora_rank, scaling, lora_backend):
          super().__init__()
          self.base_layer = base_layer
-         self.segment_gemm = segment_gemm
          self.lora_rank = lora_rank
          self.scaling = scaling
          self.set_lora = False
+         self.lora_backend = lora_backend

      def forward(self, x: torch.Tensor):
          return self.base_layer.forward(x)
@@ -52,17 +70,17 @@

  class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
      def __init__(
-         self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
+         self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
      ) -> None:
-         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+         super().__init__(base_layer, lora_rank, scaling, lora_backend)
          self.weight = base_layer.weight


  class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
      def __init__(
-         self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
+         self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
      ) -> None:
-         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+         super().__init__(base_layer, lora_rank, scaling, lora_backend)

      def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
          # TODO
@@ -88,136 +106,127 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):

  class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
      def __init__(
-         self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
+         self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
      ) -> None:
-         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+         super().__init__(base_layer, lora_rank, scaling, lora_backend)

-     def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
+     def set_lora_info(
+         self,
+         A_buffer,
+         B_buffer,
+     ):
          self.set_lora = True
          self.A_buffer = A_buffer
          self.B_buffer = B_buffer
-         self.bs = bs
-         self.seg_indptr = seg_indptr
-         self.weight_indices = weight_indices

      def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-         lora_a_output = self.segment_gemm.run(
-             x=x,
-             weights=self.A_buffer,
-             batch_size=self.bs,
-             weight_column_major=True,
-             seg_indptr=self.seg_indptr,
-             weight_indices=self.weight_indices,
-         )
-         # FIXME
+         lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
+
+         output_dim = base_output.shape[-1]
          lora_output = torch.empty_like(base_output)
-         output_dim = lora_output.shape[-1] // 2
-         for i in range(2):
-             left = output_dim * i
-             right = left + output_dim
-             lora_output[:, left:right] = self.segment_gemm.run(
-                 x=lora_a_output[
-                     :, self.lora_rank * i : self.lora_rank * (i + 1)
-                 ].contiguous(),
-                 weights=self.B_buffer[:, left:right, :].contiguous(),
-                 batch_size=self.bs,
-                 weight_column_major=True,
-                 seg_indptr=self.seg_indptr,
-                 weight_indices=self.weight_indices,
+         lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
+             x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
+             weights=self.B_buffer[0],
+         )
+
+         lora_output[:, output_dim : 2 * output_dim] = (
+             self.lora_backend.run_lora_b_sgemm(
+                 x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
+                 weights=self.B_buffer[1],
              )
+         )
+
          return base_output + lora_output * self.scaling


  class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-     def __init__(
-         self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
+     def init__(
+         self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
      ) -> None:
-         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+         super().__init__(base_layer, lora_rank, scaling, lora_backend)

      def set_lora_info(
-         self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
+         self,
+         A_buffer_qkv,
+         B_buffer_q,
+         B_buffer_kv,
      ):
          self.set_lora = True
          self.A_buffer_qkv = A_buffer_qkv
-         self.B_buffer_q = B_buffer_q
-         self.B_buffer_kv = B_buffer_kv
-         self.bs = bs
-         self.seg_indptr = seg_indptr
-         self.weight_indices = weight_indices
+
+         if self.lora_backend.fuse_qkv_lora_b:
+             assert (
+                 B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
+             ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
+             output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
+
+             # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
+             self.B_buffer_qkv = torch.cat(
+                 (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
+             ).contiguous()
+
+             # Offsets of q/k/v in output dimension
+             self.output_offset = torch.tensor(
+                 [
+                     0,
+                     output_dim_q,
+                     output_dim_q + output_dim_kv,
+                     output_dim_q + 2 * output_dim_kv,
+                 ],
+                 dtype=torch.int32,
+                 device=B_buffer_q.device,
+             )
+             # For computing number of launched blocks
+             self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
+         else:
+             self.B_buffer_qkv = (
+                 B_buffer_q,
+                 B_buffer_kv,
+             )
+             self.output_offset = None
+             self.max_qkv_out_dim = None

      def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-         lora_a_output = self.segment_gemm.run(
-             x=x,
-             weights=self.A_buffer_qkv,
-             batch_size=self.bs,
-             weight_column_major=True,
-             seg_indptr=self.seg_indptr,
-             weight_indices=self.weight_indices,
+         lora_output = self.lora_backend.run_qkv_lora(
+             x,
+             self.A_buffer_qkv,
+             self.B_buffer_qkv,
+             output_offset=self.output_offset,
+             max_qkv_out_dim=self.max_qkv_out_dim,
+             base_output=base_output,
+             scaling=self.scaling,
          )
-         # FIXME parallelize qkv
-         lora_output = torch.empty_like(base_output)
-         # q
-         output_dim_q = self.B_buffer_q.shape[-2]
-         lora_output[:, :output_dim_q] = self.segment_gemm.run(
-             x=lora_a_output[:, : self.lora_rank].contiguous(),
-             weights=self.B_buffer_q,
-             batch_size=self.bs,
-             weight_column_major=True,
-             seg_indptr=self.seg_indptr,
-             weight_indices=self.weight_indices,
+         return (
+             lora_output
+             if self.lora_backend.fuse_output_scaling_add
+             else base_output + lora_output * self.scaling
          )
-         # kv
-         output_dim_kv = self.B_buffer_kv.shape[-2] // 2
-         for i in range(2):
-             left = output_dim_kv * i
-             right = left + output_dim_kv
-             lora_output[:, output_dim_q + left : output_dim_q + right] = (
-                 self.segment_gemm.run(
-                     x=lora_a_output[
-                         :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
-                     ].contiguous(),
-                     weights=self.B_buffer_kv[:, left:right, :].contiguous(),
-                     batch_size=self.bs,
-                     weight_column_major=True,
-                     seg_indptr=self.seg_indptr,
-                     weight_indices=self.weight_indices,
-                 )
-             )
-         return base_output + lora_output * self.scaling


  class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
      def __init__(
-         self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
+         self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
      ) -> None:
-         super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+         super().__init__(base_layer, lora_rank, scaling, lora_backend)

-     def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
+     def set_lora_info(self, A_buffer, B_buffer):
          self.set_lora = True
          self.A_buffer = A_buffer
          self.B_buffer = B_buffer
-         self.bs = bs
-         self.seg_indptr = seg_indptr
-         self.weight_indices = weight_indices

      def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-         lora_output = self.segment_gemm.run(
-             x=x,
-             weights=self.A_buffer,
-             batch_size=self.bs,
-             weight_column_major=True,
-             seg_indptr=self.seg_indptr,
-             weight_indices=self.weight_indices,
+         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
+         lora_output = self.lora_backend.run_lora_b_sgemm(
+             lora_a_output,
+             self.B_buffer[0],
+             base_output=base_output,
+             scaling=self.scaling,
          )
-         lora_output = self.segment_gemm.run(
-             x=lora_output,
-             weights=self.B_buffer,
-             batch_size=self.bs,
-             weight_column_major=True,
-             seg_indptr=self.seg_indptr,
-             weight_indices=self.weight_indices,
+         return (
+             lora_output
+             if self.lora_backend.fuse_output_scaling_add
+             else base_output + lora_output * self.scaling
          )
-         return base_output + lora_output * self.scaling

      def forward(self, input_):
          # duplicate the logic in RowParallelLinear
@@ -255,7 +264,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):


  def get_lora_layer(
-     layer: nn.Module, segment_gemm, lora_rank, scaling
+     layer: nn.Module, lora_rank, scaling, lora_backend
  ) -> BaseLayerWithLoRA:
      supported_layer_types = {
          # the order matters
@@ -267,7 +276,7 @@ def get_lora_layer(
      }
      for src_layer_type, lora_layer_type in supported_layer_types.items():
          if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-             ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
+             ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
              return ret
      raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")

@@ -297,13 +306,14 @@ class LoRALayer(nn.Module):


  class LoRAAdapter(nn.Module):
-     def __init__(self, uid, config, base_hf_config, load_config):
+     def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
          super().__init__()
          self.uid = uid
          self.config = config
          assert self.config.hf_config["peft_type"].lower() == "lora"
          self.base_hf_config = base_hf_config
          self.load_config = load_config
+         self.lora_backend = lora_backend
          self.scaling = self.config.lora_alpha / self.config.r

          self.layers = nn.ModuleList(
@@ -376,20 +386,25 @@ class LoRAAdapter(nn.Module):
                      layer.weights.pop(weight_name)
                      layer.weights.pop(v_name)
                  else:
-                     layer.weights[kv_name] = torch.cat(
-                         (
+                     layer.weights[kv_name] = torch.stack(
+                         [
                              layer.weights[weight_name],
                              layer.weights[v_name],
-                         ),
-                         0,
+                         ],
+                         dim=0,
                      )
                      layer.weights.pop(weight_name)
                      layer.weights.pop(v_name)
              elif "gate_proj" in weight_name:
                  up_name = weight_name.replace("gate_proj", "up_proj")
                  gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
-                 layer.weights[gate_up_name] = torch.cat(
-                     (layer.weights[weight_name], layer.weights[up_name]), 0
-                 )
+                 if "lora_A" in weight_name:
+                     layer.weights[gate_up_name] = torch.cat(
+                         (layer.weights[weight_name], layer.weights[up_name]), 0
+                     )
+                 else:
+                     layer.weights[gate_up_name] = torch.stack(
+                         [layer.weights[weight_name], layer.weights[up_name]], dim=0
+                     )
                  layer.weights.pop(weight_name)
                  layer.weights.pop(up_name)
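To make the fused-qkv bookkeeping in QKVParallelLinearWithLoRA.set_lora_info concrete: with hypothetical projection widths output_dim_q = 4096 and output_dim_kv = 1024 (illustrative numbers, not taken from this diff), the offsets and launch bound work out as follows.

# Illustrative arithmetic only; the formulas are the ones used in set_lora_info above.
output_dim_q, output_dim_kv = 4096, 1024

output_offset = [
    0,                                 # start of q rows in the fused lora_b weight
    output_dim_q,                      # 4096: start of k rows
    output_dim_q + output_dim_kv,      # 5120: start of v rows
    output_dim_q + 2 * output_dim_kv,  # 6144: total fused output width
]
max_qkv_out_dim = max(output_dim_q, output_dim_kv)  # 4096, used to size the Triton launch grid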
sglang/srt/lora/lora_manager.py CHANGED
@@ -20,16 +20,14 @@ import re

  import torch

- from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
+ from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend
+ from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer
  from sglang.srt.lora.lora_config import LoRAConfig
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.utils import is_flashinfer_available, replace_submodule

  logger = logging.getLogger(__name__)

- if is_flashinfer_available():
-     from flashinfer import SegmentGEMMWrapper
-

  def get_module_name(name):
      # Fallback solution of mapping from config module name to module name in model class.
@@ -77,6 +75,20 @@ def get_stacked_name(name):
      return params_mapping.get(name, (name, name))


+ def get_backend_from_name(name):
+     backend_mapping = {
+         "triton": TritonLoraBackend,
+         "flashinfer": FlashInferLoraBackend,
+     }
+
+     if name in backend_mapping:
+         return backend_mapping[name]
+
+     raise Exception(
+         f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
+     )
+
+
  def get_layer_id(name):
      match = re.search(r"layers\.(\d+)\.", name)
      if match is None:
@@ -93,6 +105,7 @@ class LoRAManager:
          max_loras_per_batch,
          load_config,
          dtype,
+         lora_backend,
      ):
          self.base_model = base_model
          self.lora_paths = lora_paths
@@ -101,8 +114,9 @@
          self.load_config = load_config
          self.dtype = dtype

-         workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
-         self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
+         logger.info(f"Using {lora_backend} as backend of Lora kernels.")
+         backend_type = get_backend_from_name(lora_backend)
+         self.lora_backend = backend_type(lora_backend)

          self.init_loras()
          self.init_lora_memory_pool()
@@ -123,7 +137,7 @@

      def set_lora_module(self, module_name, module):
          lora_module = get_lora_layer(
-             module, self.segment_gemm, self.max_lora_dim, self.scaling
+             module, self.max_lora_dim, self.scaling, self.lora_backend
          )
          replace_submodule(self.base_model, module_name, lora_module)
          return lora_module
@@ -162,7 +176,11 @@
              self.lora_id[name] = len(self.loras)
              self.loras.append(
                  LoRAAdapter(
-                     name, self.configs[name], self.base_hf_config, self.load_config
+                     name,
+                     self.configs[name],
+                     self.base_hf_config,
+                     self.load_config,
+                     self.lora_backend,
                  )
              )
              self.loras[-1].initialize_weights()
@@ -226,8 +244,9 @@
              self.B_buffer[module_B] = [
                  torch.empty(
                      (
+                         c,
                          self.max_loras_per_batch,
-                         hidden_dim_B * c,
+                         hidden_dim_B,
                          self.max_lora_dim,
                      ),
                      dtype=self.dtype,
@@ -263,7 +282,16 @@
              else:
                  lora_weight_name = self.get_weight_name(name, 1)
                  if lora_weight_name:
-                     self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
+                     c = self.loras[-1].get_stacked_multiply(lora_weight_name)
+                     if c > 1:
+                         for j in range(c):
+                             self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
+                                 weights[j]
+                             )
+                     else:
+                         self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
+                             weights
+                         )

      def prepare_lora_batch(self, forward_batch: ForwardBatch):
          # load active loras into lora memory pool
@@ -292,20 +320,30 @@
          if cur_uids == set([None]):
              return

-         # setup lora in forward modules
+         # set up batch info shared by all lora moruldes
          bs = forward_batch.batch_size
          seg_lens = (
              forward_batch.extend_seq_lens
              if forward_batch.forward_mode.is_extend()
              else torch.ones(bs, device="cuda")
          )
-         # FIXME: reuse the data rather than recompute
          seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
          seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+         max_len = int(torch.max(seg_lens))
          weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
          for i, lora_path in enumerate(forward_batch.lora_paths):
              weight_indices[i] = self.buffer_id[lora_path]

+         batch_info = LoraBatchInfo(
+             bs=bs,
+             seg_lens=seg_lens,
+             seg_indptr=seg_indptr,
+             max_len=max_len,
+             weight_indices=weight_indices,
+         )
+         self.lora_backend.set_batch_info(batch_info)
+
+         # call set_lora_info for each lora modules
          for module_name, module in self.lora_modules:
              layer_id = get_layer_id(module_name)

@@ -314,16 +352,10 @@
                  module.set_lora_info(
                      self.A_buffer[weight_name][layer_id],
                      self.B_buffer[weight_name][layer_id],
-                     bs,
-                     seg_indptr,
-                     weight_indices,
                  )
              else:
                  module.set_lora_info(
                      self.A_buffer["qkv_proj"][layer_id],
                      self.B_buffer["q_proj"][layer_id],
                      self.B_buffer["kv_proj"][layer_id],
-                     bs,
-                     seg_indptr,
-                     weight_indices,
                  )
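The batch metadata that prepare_lora_batch now hands to the backend is easy to trace by hand. Here is a small sketch with a hypothetical three-request extend batch (lengths and adapter slots are made up; the construction mirrors the code above and assumes a CUDA device):

import torch

seg_lens = torch.tensor([3, 1, 5], dtype=torch.int32, device="cuda")  # per-request extend lengths
bs = seg_lens.numel()

seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)  # [0, 3, 4, 9]: token ranges per request
max_len = int(torch.max(seg_lens))              # 5

# Adapter slot chosen for each request (values come from self.buffer_id in practice).
weight_indices = torch.tensor([0, 0, 1], dtype=torch.int64, device="cuda")

Packed into a LoraBatchInfo and set once per forward pass via set_batch_info, this replaces the bs/seg_indptr/weight_indices arguments that every set_lora_info call used to carry.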
sglang/srt/lora/triton_ops/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .qkv_lora_b import qkv_lora_b_fwd
+ from .sgemm_lora_a import sgemm_lora_a_fwd
+ from .sgemm_lora_b import sgemm_lora_b_fwd
+
+ __all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"]