sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +72 -10
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +6 -16
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +582 -125
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +79 -6
- sglang/srt/layers/quantization/__init__.py +137 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +44 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -127
- sglang/srt/managers/scheduler.py +29 -23
- sglang/srt/managers/tokenizer_manager.py +1 -2
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +16 -13
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +64 -59
- sglang/srt/model_loader/loader.py +19 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +568 -0
- sglang/srt/models/deepseek_janus_pro.py +12 -17
- sglang/srt/models/deepseek_v2.py +339 -123
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +20 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +106 -93
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +120 -25
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +94 -25
- sglang/srt/utils.py +137 -51
- sglang/test/runners.py +27 -2
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +14 -27
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/layers.py
CHANGED
@@ -23,14 +23,10 @@ class BaseLayerWithLoRA(nn.Module):
     def __init__(
         self,
         base_layer: nn.Module,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ):
         super().__init__()
         self.base_layer: nn.Module = base_layer
-        self.lora_rank: int = lora_rank
-        self.scaling: float = scaling
         self.set_lora: bool = False
         self.lora_backend: BaseLoRABackend = lora_backend
 
@@ -59,11 +55,9 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: VocabParallelEmbedding,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
         self.weight = base_layer.weight
 
 
@@ -71,11 +65,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: ColumnParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -87,7 +79,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output
+        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             lora_a_output,
@@ -96,8 +88,8 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.
-            else base_output + lora_output
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def forward(self, input_: torch.Tensor):
@@ -132,11 +124,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
         self,
         base_layer: MergedColumnParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -155,7 +145,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output
+        backend_kwargs = {"base_output": base_output}
 
         lora_output = self.lora_backend.run_gate_up_lora(
             x,
@@ -165,8 +155,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.
-            else base_output + lora_output
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
@@ -184,11 +174,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
         self,
         base_layer: QKVParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -230,7 +218,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output
+        backend_kwargs = {"base_output": base_output}
         if self.lora_backend.fuse_stacked_lora_b:
             backend_kwargs["output_offset"] = self.output_offset
             backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
@@ -243,8 +231,8 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.
-            else base_output + lora_output
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
@@ -273,11 +261,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: RowParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
         self.set_lora = True
@@ -285,7 +271,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output
+        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             lora_a_output,
@@ -294,8 +280,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.
-            else base_output + lora_output
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def forward(self, input_: torch.Tensor):
@@ -344,7 +330,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
 
 
 def get_lora_layer(
-    layer: nn.Module,
+    layer: nn.Module, lora_backend: BaseLoRABackend
 ) -> BaseLayerWithLoRA:
     supported_layer_types = {
         # the order matters
@@ -356,6 +342,6 @@ def get_lora_layer(
     }
     for src_layer_type, lora_layer_type in supported_layer_types.items():
         if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer,
+            ret = lora_layer_type(layer, lora_backend)
             return ret
     raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
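Taken together, the layers.py changes drop the per-layer lora_rank/scaling constructor arguments: every wrapper is now built from (base_layer, lora_backend) only, and rank and scaling are resolved per adapter at batch time. A minimal, self-contained sketch of that calling convention; StubBackend, the shapes, and the fuse_output_add value are invented for illustration and are not the real sglang backend classes:

# Sketch of the post-change wrapper contract (illustrative stubs only).
import torch
import torch.nn as nn


class StubBackend:
    fuse_output_add = False          # real backends expose a flag like this

    def run_lora_a_sgemm(self, x, A):                      # x: (s, in), A: (r, in)
        return x @ A.t()

    def run_lora_b_sgemm(self, xa, B, base_output=None):   # B: (out, r)
        return xa @ B.t()


class LoRAWrapperSketch(nn.Module):
    # After the change: only (base_layer, lora_backend), no lora_rank/scaling.
    def __init__(self, base_layer: nn.Module, lora_backend):
        super().__init__()
        self.base_layer = base_layer
        self.lora_backend = lora_backend

    def apply_lora(self, base_output, x):
        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
        lora_output = self.lora_backend.run_lora_b_sgemm(lora_a_output, self.B_buffer)
        return (
            lora_output
            if self.lora_backend.fuse_output_add
            else base_output + lora_output
        )


layer = LoRAWrapperSketch(nn.Linear(16, 16), StubBackend())
layer.A_buffer = torch.randn(8, 16)   # a rank-8 adapter
layer.B_buffer = torch.randn(16, 8)
x = torch.randn(4, 16)
print(layer.apply_lora(layer.base_layer(x), x).shape)  # torch.Size([4, 16])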
sglang/srt/lora/lora_manager.py
CHANGED
@@ -103,11 +103,14 @@ class LoRAManager:
         self.loras[name] = lora_adapter
 
         # misc lora configs
-        # FIXME remove the restrictions after implementing unified paging
         self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-
-
-
+
+        if self.lora_backend == "flashinfer":
+            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
+            scaling = list(self.loras.values())[0].scaling
+            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
+            assert all(x.scaling == scaling for x in self.loras.values())
 
         # Convert original model layers to layers with LoRA
         self.convert_to_lora_layers()
@@ -148,8 +151,18 @@ class LoRAManager:
         seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         max_len = int(torch.max(seg_lens))
         weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+        lora_ranks = torch.empty(
+            (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+        )
+        scalings = torch.empty(
+            (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+        )
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+            lora = self.loras[lora_path]
+            lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+            scalings[weight_indices[i]] = lora.scaling
 
         batch_info = LoRABatchInfo(
             bs=bs,
@@ -157,6 +170,8 @@ class LoRAManager:
             seg_indptr=seg_indptr,
             max_len=max_len,
             weight_indices=weight_indices,
+            lora_ranks=lora_ranks,
+            scalings=scalings,
         )
         self.lora_backend.set_batch_info(batch_info)
 
@@ -189,9 +204,7 @@ class LoRAManager:
         )
 
     def set_lora_module(self, module_name, module):
-        lora_module = get_lora_layer(
-            module, self.max_lora_dim, self.scaling, self.lora_backend
-        )
+        lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
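A hedged sketch of the bookkeeping that prepare_lora_batch now adds: per-slot lora_ranks and scalings tensors, filled from whichever adapters appear in the batch and indexed by the same slot ids as weight_indices. The adapter table and slot numbers below are invented for illustration:

# Sketch (not the real LoRAManager): build per-slot rank/scaling tensors
# that the Triton kernels later index via weight_indices.
import torch

max_loras_per_batch = 4
# Hypothetical adapters: path -> (buffer slot, rank, scaling)
adapters = {"lora_a": (0, 8, 1.0), "lora_b": (1, 16, 0.5)}
lora_paths = ["lora_a", "lora_b", "lora_a"]          # one adapter per sequence

weight_indices = torch.empty(len(lora_paths), dtype=torch.int64)
lora_ranks = torch.zeros(max_loras_per_batch, dtype=torch.int64)
scalings = torch.zeros(max_loras_per_batch, dtype=torch.float)

for i, path in enumerate(lora_paths):
    slot, rank, scaling = adapters[path]
    weight_indices[i] = slot          # which pool slot serves sequence i
    lora_ranks[slot] = rank           # that slot's adapter rank
    scalings[slot] = scaling          # and its LoRA scaling factor

print(weight_indices.tolist(), lora_ranks.tolist(), scalings.tolist())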
sglang/srt/lora/mem_pool.py
CHANGED
@@ -163,10 +163,11 @@ class LoRAMemoryPool:
         if uid is None:
             for i in range(self.num_layer):
                 for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id]
+                    self.A_buffer[k][i][buffer_id] = 0
             return
 
         assert lora_adapter is not None
+        lora_rank = lora_adapter.config.hf_config["r"]
        for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, torch.Tensor] = {}
@@ -208,17 +209,22 @@ class LoRAMemoryPool:
             )
 
             for name, weights in temp_A_buffer.items():
-
+                c = get_stacked_multiply(name)
+                self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
+                    weights
+                )
 
             for name, weights in temp_B_buffer.items():
                 c = get_stacked_multiply(name)
                 if c > 1:
                     for stacked_id in range(c):
-                        self.B_buffer[name][layer_id][stacked_id][buffer_id]
-
-                        )
+                        self.B_buffer[name][layer_id][stacked_id][buffer_id][
+                            :, :lora_rank
+                        ].copy_(weights[stacked_id])
                 else:
-                    self.B_buffer[name][layer_id][0][buffer_id].copy_(
+                    self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
+                        weights
+                    )
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
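The memory pool keeps buffers sized for the maximum rank across loaded adapters; loading a lower-rank adapter now writes only the leading rank rows or columns (times the stack multiplier for fused gate_up/qkv weights). A small stand-alone sketch of that slicing, with made-up shapes:

# Sketch of rank-aware buffer filling (illustrative shapes, not the real pool).
import torch

max_rank, input_dim, output_dim = 16, 64, 64
stack = 2                    # e.g. gate_up A weights are stacked 2x on the rank dim

A_buffer = torch.zeros(stack * max_rank, input_dim)   # one pool slot, one layer
B_buffer = torch.zeros(output_dim, max_rank)

lora_rank = 8                                         # this adapter's rank
A_weights = torch.randn(stack * lora_rank, input_dim)
B_weights = torch.randn(output_dim, lora_rank)

# Only the leading rows/columns are overwritten; the padding stays zero, which
# is harmless because the kernels clamp K to the adapter's rank anyway.
A_buffer[: lora_rank * stack, :].copy_(A_weights)
B_buffer[:, :lora_rank].copy_(B_weights)
print(A_buffer[lora_rank * stack:].abs().sum().item())  # 0.0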
sglang/srt/lora/triton_ops/gate_up_lora_b.py
CHANGED
@@ -22,17 +22,18 @@ def _gate_up_lora_b_kernel(
     w_stride_2,
     output_stride_0,
     output_stride_1,
-    # Information on sequence lengths and weight id
+    # Information on sequence lengths,ranks and weight id
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-
+    scalings,
 ):
     # This kernel packs 2 sgemms (gate/up) into a single kernel.
 
@@ -51,6 +52,11 @@ def _gate_up_lora_b_kernel(
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = gate_up_id * output_dim  # offset on output dim
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(output_dim, BLOCK_N)
@@ -109,7 +115,6 @@ def gate_up_lora_b_fwd(
     batch_info: LoRABatchInfo,
     output_dim: int,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
 
     # x: (s, 2 * r)
@@ -160,11 +165,12 @@ def gate_up_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
         fuse_scaling_add,
-
+        batch_info.scalings,
     )
 
     return output
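The kernel-side change, repeated in the qkv and sgemm B kernels below, is to look up the adapter's rank and scaling from lora_ranks/scalings via w_index and clamp the reduction dimension to that rank, instead of receiving one global scaling argument. A plain-PyTorch reference of those semantics; this illustrates the math only, not the Triton kernel, and all shapes and values are assumptions:

# Reference semantics of the per-adapter rank clamp + scaling (illustrative).
import torch

s, max_r, output_dim, num_lora = 6, 16, 32, 2
x = torch.randn(s, max_r)                      # LoRA-A outputs, padded to max rank
weights = torch.randn(num_lora, output_dim, max_r)
seg_indptr = torch.tensor([0, 4, 6])           # two sequences: rows 0..3 and 4..5
weight_indices = torch.tensor([0, 1])          # adapter slot per sequence
lora_ranks = torch.tensor([8, 16])             # per-slot ranks
scalings = torch.tensor([1.0, 0.5])            # per-slot scalings

out = torch.zeros(s, output_dim)
for b in range(len(weight_indices)):
    start, end = seg_indptr[b].item(), seg_indptr[b + 1].item()
    w_index = weight_indices[b].item()
    K = min(max_r, lora_ranks[w_index].item())  # K = tl.minimum(K, rank)
    out[start:end] = scalings[w_index] * (
        x[start:end, :K] @ weights[w_index, :, :K].t()
    )
print(out.shape)  # torch.Size([6, 32])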
sglang/srt/lora/triton_ops/qkv_lora_b.py
CHANGED
@@ -26,6 +26,7 @@ def _qkv_lora_b_kernel(
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Offsets of q/k/v slice on output dimension
     n_offs,
     # Meta parameters
@@ -34,7 +35,7 @@ def _qkv_lora_b_kernel(
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-
+    scalings,
 ):
     # This kernel packs 3 sgemms (q/k/v) into a single kernel.
 
@@ -54,6 +55,10 @@ def _qkv_lora_b_kernel(
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = tl.load(n_offs + qkv_id)
     n_size = tl.load(n_offs + qkv_id + 1) - n_start
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
@@ -112,7 +117,6 @@ def qkv_lora_b_fwd(
     output_offset: torch.Tensor,
     max_qkv_out_dim: int,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
 
     # x: (s, 3 * r)
@@ -171,12 +175,13 @@ def qkv_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         output_offset,
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
         fuse_scaling_add,
-
+        batch_info.scalings,
     )
 
     return output
sglang/srt/lora/triton_ops/sgemm_lora_a.py
CHANGED
@@ -12,8 +12,9 @@ def _sgemm_lora_a_kernel(
     weights,
     output,
     # Matrix dimensions
-    N,  # r
+    N,  # stack_num * r
     K,  # input_dim
+    stack_num,
     # Strides
     x_stride_0,
     x_stride_1,
@@ -22,10 +23,11 @@ def _sgemm_lora_a_kernel(
     w_stride_2,
     output_stride_0,
     output_stride_1,
-    # Information on sequence lengths and weight id
+    # Information on sequence lengths,ranks and weight id
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -43,6 +45,9 @@ def _sgemm_lora_a_kernel(
     seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+    # Adjust N (stack_num * max_rank) according to the specific LoRA adapter
+    N = tl.minimum(N, rank * stack_num)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -91,11 +96,15 @@ def _sgemm_lora_a_kernel(
 
 
 def sgemm_lora_a_fwd(
-    x: torch.Tensor,
+    x: torch.Tensor,
+    weights: torch.Tensor,
+    batch_info: LoRABatchInfo,
+    stack_num: int = 1,
 ) -> torch.Tensor:
     # x: (s, input_dim)
-    # weights: (num_lora, r, input_dim)
-    # output: (s, r)
+    # weights: (num_lora, stack_num * r, input_dim)
+    # output: (s, stack_num * r)
+    # stack_num: run_qkv_lora: 3, run_gate_up_lora: 2
     # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
     # input_dim is much larger than r
 
@@ -126,6 +135,7 @@ def sgemm_lora_a_fwd(
         output,
         R,
         K,
+        stack_num,
         x.stride(0),
         x.stride(1),
         weights.stride(0),
@@ -136,6 +146,7 @@ def sgemm_lora_a_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_R,
         BLOCK_K,
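For the A-side GEMM, the output dimension is stack_num * r (3 for fused qkv, 2 for fused gate_up), so the clamp uses rank * stack_num rather than rank. A shape-level sketch of that clamp, with assumed values (this is not the Triton code):

# Shape sketch for the A-side clamp (illustrative values only).
import torch

input_dim, max_r, stack_num = 64, 16, 3        # qkv packs three A projections
weights = torch.randn(stack_num * max_r, input_dim)   # one adapter slot
x = torch.randn(5, input_dim)

rank = 8                                       # this adapter's actual rank
N = min(stack_num * max_r, rank * stack_num)   # N = tl.minimum(N, rank * stack_num)
out = x @ weights[:N, :].t()                   # only the used rows participate
print(out.shape)                               # torch.Size([5, 24])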
sglang/srt/lora/triton_ops/sgemm_lora_b.py
CHANGED
@@ -26,13 +26,14 @@ def _sgemm_lora_b_kernel(
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-
+    scalings,
 ):
     # x: (s, K), s is the sum of sequence lengths
     # weights: (num_lora, N, K)
@@ -45,6 +46,10 @@ def _sgemm_lora_b_kernel(
     seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -100,12 +105,11 @@ def sgemm_lora_b_fwd(
     weights: torch.Tensor,
     batch_info: LoRABatchInfo,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
-    # x: (s,
-    # weights: (num_lora, output_dim,
+    # x: (s, max_r)
+    # weights: (num_lora, output_dim, max_r)
     # output: (s, output_dim)
-    # output_dim is much larger than
+    # output_dim is much larger than max_r
 
     assert x.is_contiguous()
     assert weights.is_contiguous()
@@ -150,10 +154,11 @@ def sgemm_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_N,
         BLOCK_R,
         fuse_scaling_add,
-
+        batch_info.scalings,
     )
     return output
sglang/srt/lora/utils.py
CHANGED
@@ -25,6 +25,12 @@ class LoRABatchInfo:
     # The index of lora adapter used by each sequence, in shape (bs,)
     weight_indices: torch.Tensor
 
+    # ranks of each lora adapter, in shape (lora_num,)
+    lora_ranks: torch.Tensor
+
+    # scaling of each lora adapter, in shape (lora_num,)
+    scalings: torch.Tensor
+
 
 class LoRAType(Enum):
     LORA_A = 0
sglang/srt/managers/cache_controller.py
CHANGED
@@ -149,6 +149,7 @@ class HiCacheController:
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
+        page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
@@ -156,6 +157,7 @@ class HiCacheController:
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
+        self.page_size = page_size
 
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -184,7 +186,12 @@ class HiCacheController:
         self.load_stream = torch.cuda.Stream()
 
         self.write_thread = threading.Thread(
-            target=
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -205,7 +212,12 @@ class HiCacheController:
         self.ack_load_queue.queue.clear()
 
         self.write_thread = threading.Thread(
-            target=
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -260,10 +272,12 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-
-                operation.
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
                 )
-                self.
+                self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
                 for node_id in operation.node_ids:
                     if node_id != 0:
@@ -320,12 +334,21 @@ class HiCacheController:
 
             self.layer_done_counter.reset()
             for i in range(self.mem_pool_host.layer_num):
-
-
-
-
-
-
+                if self.page_size == 1:
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                else:
+                    self.mem_pool_host.load_page_per_layer(
+                        batch_operation.host_indices,
+                        batch_operation.device_indices,
+                        self.mem_pool_device,
+                        i,
+                    )
+                self.load_stream.synchronize()
                 self.layer_done_counter.increment()
 
             self.mem_pool_host.complete_io(batch_operation.host_indices)
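HiCacheController now receives page_size and uses it to pick its write path: the buffered, token-granularity path when page_size == 1, the direct page-wise path otherwise. A minimal sketch of that selection with placeholder thread functions (these are not the real controller methods):

# Sketch of the page_size-dependent thread target (placeholder functions only).
import threading

def write_thread_func_buffer():
    print("buffered, token-granularity writes (page_size == 1)")

def write_thread_func_direct():
    print("direct, page-granularity writes (page_size > 1)")

page_size = 16
write_thread = threading.Thread(
    target=(write_thread_func_buffer if page_size == 1 else write_thread_func_direct),
    daemon=True,
)
write_thread.start()
write_thread.join()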
sglang/srt/managers/io_struct.py
CHANGED
@@ -20,7 +20,7 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -650,7 +650,7 @@ class ProfileReqInput:
     # If it is set, profiling is automatically stopped after this step, and
     # the caller doesn't need to run stop_profile.
     num_steps: Optional[int] = None
-    activities: Optional[List[
+    activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
 
 
 class ProfileReqType(Enum):
@@ -675,6 +675,8 @@ class ProfileReq:
     output_dir: Optional[str] = None
     num_steps: Optional[int] = None
    activities: Optional[List[str]] = None
+    with_stack: Optional[bool] = None
+    record_shapes: Optional[bool] = None
 
 
 @dataclass