sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/lora/layers.py
CHANGED
@@ -1,5 +1,3 @@
-from typing import List, Tuple
-
 import torch
 from torch import nn
 
@@ -79,18 +77,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
-            lora_a_output,
-            self.B_buffer,
-            **backend_kwargs,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=lora_a_output,
+            weights=self.B_buffer,
+            base_output=base_output,
         )
+        return lora_output
 
     def forward(self, input_: torch.Tensor):
         # duplicate the logic in ColumnParallelLinear
@@ -135,37 +128,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     ):
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
-        if self.lora_backend.fuse_stacked_lora_b:
-            # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
-            if getattr(self, "B_buffer_gate_up", None) is None:
-                self.B_buffer_gate_up = torch.empty(
-                    (
-                        B_buffer[0].shape[0],
-                        2 * B_buffer[0].shape[1],
-                        B_buffer[0].shape[2],
-                    ),
-                    dtype=B_buffer[0].dtype,
-                    device=B_buffer[0].device,
-                )
-            self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
-            self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
-        else:
-            self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
+        self.B_buffer_gate_up = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
-
         lora_output = self.lora_backend.run_gate_up_lora(
-            x,
-            self.A_buffer_gate_up,
-            self.B_buffer_gate_up,
-            **backend_kwargs,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=x,
+            gate_up_lora_a=self.A_buffer_gate_up,
+            gate_up_lora_b=self.B_buffer_gate_up,
+            base_output=base_output,
         )
+        return lora_output
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
         return A
@@ -173,9 +145,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
         # Since the outputs for both gate and up are identical, we use a random one.
         shard_size = self.base_layer.output_partition_sizes[0]
+        gate_size = self.base_layer.output_sizes[0]
         start_idx = tp_rank * shard_size
         end_idx = (tp_rank + 1) * shard_size
-        return B[:, start_idx:end_idx, :]
+        return torch.concat(
+            (
+                B[start_idx:end_idx, :],
+                B[gate_size + start_idx : gate_size + end_idx],
+            ),
+            dim=0,
+        )
 
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@@ -185,86 +164,46 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         lora_backend: BaseLoRABackend,
     ) -> None:
         super().__init__(base_layer, lora_backend)
+        q_proj_shard_size = self.base_layer.q_proj_shard_size
+        kv_proj_shard_size = self.base_layer.kv_proj_shard_size
+        self.output_offset = torch.tensor(
+            [
+                0,
+                q_proj_shard_size,
+                q_proj_shard_size + kv_proj_shard_size,
+                q_proj_shard_size + 2 * kv_proj_shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
+
+        # For computing number of launched blocks
+        self.max_qkv_out_dim = max(q_proj_shard_size, kv_proj_shard_size)
 
     def set_lora_info(
         self,
         A_buffer_qkv: torch.Tensor,
-        B_buffer_q: torch.Tensor,
-        B_buffer_kv: torch.Tensor,
+        B_buffer_qkv: torch.Tensor,
     ):
         self.set_lora = True
         self.A_buffer_qkv = A_buffer_qkv
-
-        if self.lora_backend.fuse_stacked_lora_b:
-            assert (
-                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
-            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
-            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
-
-            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-            if getattr(self, "B_buffer_qkv", None) is None:
-                self.B_buffer_qkv = torch.empty(
-                    (
-                        B_buffer_q[0].shape[0],
-                        output_dim_q + 2 * output_dim_kv,
-                        B_buffer_q[0].shape[2],
-                    ),
-                    dtype=B_buffer_q[0].dtype,
-                    device=B_buffer_q[0].device,
-                )
-            self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
-            self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
-                B_buffer_kv[0]
-            )
-            self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
-                B_buffer_kv[1]
-            )
-
-            # Offsets of q/k/v in output dimension
-            if getattr(self, "output_offset", None) is None:
-                self.output_offset = torch.tensor(
-                    [
-                        0,
-                        output_dim_q,
-                        output_dim_q + output_dim_kv,
-                        output_dim_q + 2 * output_dim_kv,
-                    ],
-                    dtype=torch.int32,
-                    device=B_buffer_q.device,
-                )
-            # For computing number of launched blocks
-            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
-        else:
-            self.B_buffer_qkv = (
-                B_buffer_q,
-                B_buffer_kv,
-            )
+        self.B_buffer_qkv = B_buffer_qkv
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
-        if self.lora_backend.fuse_stacked_lora_b:
-            backend_kwargs["output_offset"] = self.output_offset
-            backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
-
         lora_output = self.lora_backend.run_qkv_lora(
-            x,
-            self.A_buffer_qkv,
-            self.B_buffer_qkv,
-            **backend_kwargs,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=x,
+            qkv_lora_a=self.A_buffer_qkv,
+            qkv_lora_b=self.B_buffer_qkv,
+            base_output=base_output,
+            output_offset=self.output_offset,
+            max_qkv_out_dim=self.max_qkv_out_dim,
        )
+        return lora_output
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
         return A
 
-    def slice_lora_b_weights(
-        self, B: List[torch.Tensor], tp_rank: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        B_q, B_kv = B
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int) -> torch.Tensor:
         base_layer = self.base_layer
         q_proj_shard_size = base_layer.q_proj_shard_size
         kv_proj_shard_size = base_layer.kv_proj_shard_size
@@ -277,7 +216,19 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         kv_start_idx = kv_proj_shard_size * kv_shard_id
         kv_end_idx = kv_start_idx + kv_proj_shard_size
 
-        return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
+        q_size, k_size, _ = base_layer.output_sizes
+        B_q_shard = B[q_start_idx:q_end_idx, :]
+        B_k_shard = B[q_size + kv_start_idx : q_size + kv_end_idx, :]
+        B_v_shard = B[q_size + k_size + kv_start_idx : q_size + k_size + kv_end_idx, :]
+
+        return torch.concat(
+            (
+                B_q_shard,
+                B_k_shard,
+                B_v_shard,
+            ),
+            dim=0,
+        )
 
 
 class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@@ -294,18 +245,13 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
-            lora_a_output,
-            self.B_buffer,
-            **backend_kwargs,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=lora_a_output,
+            weights=self.B_buffer,
+            base_output=base_output,
         )
+        return lora_output
 
     def forward(self, input_: torch.Tensor):
         # duplicate the logic in RowParallelLinear
sglang/srt/lora/lora.py
CHANGED
@@ -117,7 +117,6 @@ class LoRAAdapter(nn.Module):
                 q_name = weight_name
                 k_name = weight_name.replace("q_proj", "k_proj")
                 v_name = weight_name.replace("q_proj", "v_proj")
-                kv_name = weight_name.replace("q_proj", "kv_proj")
                 qkv_name = weight_name.replace("q_proj", "qkv_proj")
 
                 # If k_proj doesn't have lora, initialize it to zero
@@ -126,57 +125,27 @@ class LoRAAdapter(nn.Module):
                     if "k_proj" in target_module
                     else torch.zeros_like(weights[v_name])
                 )
-                if "lora_A" in weight_name:
-                    weights[qkv_name] = torch.cat(
-                        (
-                            weights[q_name],
-                            k_proj_weight,
-                            weights[v_name],
-                        ),
-                        0,
-                    )
-                    weights.pop(q_name)
-                    if "k_proj" in target_module:
-                        weights.pop(k_name)
-                    weights.pop(v_name)
-                else:
-                    weights[kv_name] = torch.stack(
-                        [
-                            k_proj_weight,
-                            weights[v_name],
-                        ],
-                        dim=0,
-                    )
-                    if "k_proj" in target_module:
-                        weights.pop(k_name)
-                    weights.pop(v_name)
+                weights[qkv_name] = torch.cat(
+                    (
+                        weights[q_name],
+                        k_proj_weight,
+                        weights[v_name],
+                    ),
+                    0,
+                )
+                weights.pop(q_name)
+                if "k_proj" in target_module:
+                    weights.pop(k_name)
+                weights.pop(v_name)
             elif "qkv_proj" in weight_name:
                 # If qkv_proj is already stacked, we normalize it following the SGL convention.
                 qkv_name = weight_name
                 q_name = weight_name.replace("qkv_proj", "q_proj")
                 k_name = weight_name.replace("qkv_proj", "k_proj")
                 v_name = weight_name.replace("qkv_proj", "v_proj")
-                kv_name = weight_name.replace("qkv_proj", "kv_proj")
                 if "lora_A" in weight_name:
                     weights[qkv_name] = weights[qkv_name].repeat(3, 1)
-                else:
-                    head_size = (
-                        self.base_hf_config.hidden_size
-                        // self.base_hf_config.num_attention_heads
-                    )
-                    weights[q_name], k_proj_weight, v_proj_weight = torch.split(
-                        weights[qkv_name],
-                        [
-                            head_size * self.base_hf_config.num_attention_heads,
-                            head_size * self.base_hf_config.num_key_value_heads,
-                            head_size * self.base_hf_config.num_key_value_heads,
-                        ],
-                        dim=0,
-                    )
-                    weights[kv_name] = torch.stack(
-                        [k_proj_weight, v_proj_weight],
-                        dim=0,
-                    )
+                # else: no-op as LoRA B weight is already stacked.
 
     def normalize_gate_up_proj(
         self, weight_names: List[str], weights: Dict[str, torch.Tensor]
@@ -187,20 +156,14 @@ class LoRAAdapter(nn.Module):
                 gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                 if up_name not in weights:
                     weights[up_name] = torch.zeros_like(weights[weight_name])
-                    # FIXME: Add gate-only support for flashinfer in future implementations
                     assert self.lora_backend.name == "triton", (
                         f"LoRA weight initialization currently only supported for 'triton' backend. "
                         f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                         f"or consider implementing custom initialization logic for other backends."
                     )
-                if "lora_A" in weight_name:
-                    weights[gate_up_name] = torch.cat(
-                        (weights[weight_name], weights[up_name]), 0
-                    )
-                else:
-                    weights[gate_up_name] = torch.stack(
-                        [weights[weight_name], weights[up_name]], dim=0
-                    )
+                weights[gate_up_name] = torch.cat(
+                    (weights[weight_name], weights[up_name]), 0
+                )
                 weights.pop(weight_name)
                 if up_name in weights:
                     weights.pop(up_name)
@@ -209,12 +172,4 @@ class LoRAAdapter(nn.Module):
                 gate_up_name = weight_name
                 if "lora_A" in weight_name:
                     weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
-                else:
-                    output_dim = weights[gate_up_name].shape[0] // 2
-                    weights[gate_up_name] = torch.stack(
-                        [
-                            weights[gate_up_name][:output_dim, :],
-                            weights[gate_up_name][output_dim:, :],
-                        ],
-                        dim=0,
-                    )
+                # else: no-op as LoRA B weight is already stacked.
sglang/srt/lora/lora_manager.py
CHANGED
@@ -31,7 +31,6 @@ from sglang.srt.lora.mem_pool import LoRAMemoryPool
 from sglang.srt.lora.utils import (
     LoRABatchInfo,
     LoRAType,
-    get_customized_names_from_hf_names,
     get_layer_id,
     get_normalized_lora_weight_names,
     get_weight_name,
@@ -144,6 +143,7 @@ class LoRAManager:
 
             # keep metadata for displayed messages
             self.lora_refs[lora_ref.lora_id] = lora_ref
+            self.num_pinned_loras += int(lora_ref.pinned)
         except Exception as e:
             return self.create_lora_update_result(
                 success=False,
@@ -157,13 +157,22 @@
         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
         """
 
+        # Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
         memory_pool = getattr(self, "memory_pool", None)
         incompatible = memory_pool and not memory_pool.can_support(lora_config)
         if incompatible:
             raise ValueError(
-                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
-                "Please ensure that the LoRA adapter's rank is within the configured `--max-lora-rank` and that the target modules are "
-                "included in `--lora-target-modules`."
+                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current "
+                "LoRA memory pool configuration. Please ensure that the LoRA adapter's rank is within the configured "
+                "`--max-lora-rank` and that the target modules are included in `--lora-target-modules`."
+            )
+
+        # Ensure pinned LoRA adapters does not exceed maximal limit or cause starvation.
+        if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
+            raise ValueError(
+                f"Failed to load LoRA adapter {lora_ref.lora_name} as a pinned adapter. It is not allowed to pin all slots "
+                "in the LoRA memory pool to avoid starvation for unpinned adapters and base models. Please increase your "
+                "`--max-loras-per-batch` or load it as unpinned LoRA adapters."
             )
 
     def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
@@ -172,15 +181,17 @@
         delete the corresponding LoRA modules.
         """
 
-        adapter = self.configs.get(lora_ref.lora_id, None)
+        adapter = self.configs.get(lora_ref.lora_id)
+        lora_ref = self.lora_refs.get(lora_ref.lora_id)
         assert (
-            adapter is not None
+            adapter is not None and lora_ref is not None
         ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
 
         try:
             del self.configs[lora_ref.lora_id]
             del self.loras[lora_ref.lora_id]
             del self.lora_refs[lora_ref.lora_id]
+            self.num_pinned_loras -= int(lora_ref.pinned)
         except Exception as e:
             return self.create_lora_update_result(
                 success=False,
@@ -189,15 +200,49 @@
 
         return self.create_lora_update_result(success=True)
 
+    def validate_lora_batch(self, lora_ids: set[str]) -> bool:
+        """
+        Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool.
+        """
+        if len(lora_ids) > self.max_loras_per_batch:
+            return False
+
+        # skip pinned LoRA check if no pinned LoRA adapters are loaded.
+        if self.num_pinned_loras == 0:
+            return True
+
+        # counting the number of pinned LoRA adapters in the batch.
+        pinned_loras_in_batch = 0
+        for lora_id in lora_ids:
+            if lora_id is not None:
+                lora_ref = self.lora_refs.get(lora_id)
+                assert (
+                    lora_ref is not None
+                ), f"LoRA ID {lora_id} not found in lora_refs."
+                pinned_loras_in_batch += int(lora_ref.pinned)
+
+        assert pinned_loras_in_batch <= self.num_pinned_loras, (
+            f"Number of pinned LoRA adapters in the batch ({pinned_loras_in_batch}) exceeds the total number of pinned adapters "
+            f"({self.num_pinned_loras}). This indicates a bug in the LoRA loading logic."
+        )
+
+        required_slots = len(lora_ids) - pinned_loras_in_batch
+        mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras
+
+        return required_slots <= mem_pool_vacancy
+
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
+
         # Load active loras into lora memory pool
-
-
-        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
-        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
-        cur_uids = set(forward_batch.lora_paths)
+        cur_uids = set(forward_batch.lora_ids)
+
         assert len(cur_uids) <= self.max_loras_per_batch
-        self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
+        self.memory_pool.prepare_lora_batch(
+            cur_uids=cur_uids,
+            lora_adapters=self.loras,
+            lora_modules=self.lora_modules,
+            lora_refs=self.lora_refs.copy(),  # copy snapshot of current lora_refs to avoid mutation during the batch preparation.
+        )
 
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
@@ -211,10 +256,10 @@
         Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
         to device (CUDA) asynchronously.
         """
-        weight_indices = [0] * len(forward_batch.lora_paths)
+        weight_indices = [0] * len(forward_batch.lora_ids)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch
-        for i, uid in enumerate(forward_batch.lora_paths):
+        for i, uid in enumerate(forward_batch.lora_ids):
             weight_indices[i] = self.memory_pool.get_buffer_id(uid)
             if uid is not None:
                 lora = self.loras[uid]
@@ -299,40 +344,19 @@
         )
         self.lora_backend.set_batch_info(batch_info)
 
-        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
-        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
-        self.update_lora_info()
-
     def update_lora_info(self):
         """
         Update all LoRA modules to associate them with the latest memory buffer.
         """
         for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
-                if "qkv_proj" in module_name:
-                    module.set_lora_info(
-                        self.memory_pool.get_tensor(
-                            "qkv_proj", layer_id, LoRAType.LORA_A
-                        ),
-                        self.memory_pool.get_tensor(
-                            "q_proj", layer_id, LoRAType.LORA_B
-                        ),
-                        self.memory_pool.get_tensor(
-                            "kv_proj", layer_id, LoRAType.LORA_B
-                        ),
-                    )
-                else:
-                    weight_name = get_weight_name(
-                        module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
-                    )
-                    module.set_lora_info(
-                        self.memory_pool.get_tensor(
-                            weight_name, layer_id, LoRAType.LORA_A
-                        ),
-                        self.memory_pool.get_tensor(
-                            weight_name, layer_id, LoRAType.LORA_B
-                        ),
-                    )
+                weight_name = get_weight_name(
+                    module_name, self.memory_pool.lora_weight_names
+                )
+                module.set_lora_info(
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
+                )
 
     def init_state(
         self,
@@ -359,6 +383,7 @@ class LoRAManager:
         self.init_lora_weight_names()
         self.init_lora_modules()
         self.init_memory_pool()
+        self.update_lora_info()
 
     def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
         # Configs of all active LoRA adapters, indexed by LoRA ID.
@@ -370,6 +395,9 @@
         # Mapping from LoRA ID to LoRARef object.
         self.lora_refs: Dict[str, LoRARef] = {}
 
+        # Count of pinned LoRA adapters.
+        self.num_pinned_loras: int = 0
+
         if lora_paths:
             for lora_ref in lora_paths.values():
                 result = self.load_lora_adapter(lora_ref)
@@ -390,13 +418,20 @@
         else:
             self.target_modules = set()
             for config in self.configs.values():
+                if not isinstance(config.target_modules, list):
+                    raise ValueError(
+                        f"SGLang currently only supports inferring LoRA target modules when a list of "
+                        "suffixes is provided in `target_modules` field of PEFT config. Please explicitly "
+                        "specify `--lora-target-modules` during server startup. You can specify `all` to "
+                        "enable all support modules types. "
+                    )
                 self.target_modules.update(config.target_modules)
 
         if max_lora_rank is not None:
             self.max_lora_rank = max_lora_rank
         else:
             self.max_lora_rank = max(
-                [x.hf_config["r"] for x in self.configs.values()],
+                [x.r for x in self.configs.values()],
                 default=0,
             )
 
@@ -405,9 +440,9 @@
         Add new LoRA weight names if needed based on the current `self.configs`.
         """
 
-
-
-
+        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
+            self.target_modules
+        )
 
     def load_lora_weights(self, lora_ref: LoRARef):
         """
@@ -423,15 +458,6 @@
         lora_adapter.initialize_weights()
         self.loras[lora_ref.lora_id] = lora_adapter
 
-        # Additional checks for flashinfer backend
-        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
-        if self.lora_backend == "flashinfer":
-            lora_dims = set(x.r for x in self.configs.values())
-            scalings = set(x.scaling for x in self.loras.values())
-            assert (
-                len(lora_dims) == 1 and len(scalings) == 1
-            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
-
     def init_memory_pool(self):
         """(Re)initialize the LoRA memory pool based on the current configurations."""
         self.memory_pool = LoRAMemoryPool(
@@ -456,12 +482,6 @@
             {} for _ in range(self.base_hf_config.num_hidden_layers)
         ]
 
-        # Target module names of customized layers defined in python/sglang/srt/layers
-        # e.g., {"qkv_proj", "o_proj"}
-        customized_target_names = get_customized_names_from_hf_names(
-            self.target_modules, self.base_model
-        )
-
         for module_name, module in self.base_model.named_modules():
             # TODO (lifuhuang): in the future, we should consider generalizing the
             # should_apply_lora function to support mapping by full module name instead
@@ -474,7 +494,7 @@
                continue
 
             # The module should be converted if it is included in target_names
-            if module_name.split(".")[-1] in customized_target_names:
+            if module_name.split(".")[-1] in self.lora_weight_names:
                 layer_id = get_layer_id(module_name)
                 self.lora_modules[layer_id][module_name] = self.set_lora_module(
                     module_name, module