sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +12 -1
- sglang/srt/conversation.py +35 -1
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/layers/communicator.py +3 -1
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
- sglang/srt/layers/moe/ep_moe/layer.py +140 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +8 -1
- sglang/srt/managers/mm_utils.py +4 -2
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +17 -5
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +113 -63
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/models/deepseek_v2.py +16 -2
- sglang/srt/models/mllama4.py +360 -79
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +62 -60
- sglang/srt/server_args.py +15 -0
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +37 -17
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w4afp8.py
@@ -0,0 +1,264 @@
+import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+from sglang.srt.layers.quantization.utils import is_layer_skipped
+from sglang.srt.utils import set_weight_attrs
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+logger = logging.getLogger(__name__)
+
+
+class W4AFp8Config(QuantizationConfig):
+    """Config class for MIXED_PRECISION W4AFp8."""
+
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = True,
+        is_checkpoint_w4afp8_serialized: bool = True,
+        linear_activation_scheme: str = "dynamic",
+        moe_activation_scheme: str = "static",
+        ignored_layers: Optional[List[str]] = None,
+        weight_block_size: Optional[List[int]] = None,
+        group_size: int = 128,
+    ) -> None:
+        super().__init__()
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
+        if is_checkpoint_w4afp8_serialized:
+            logger.warning("Detected w4afp8 checkpoint. Please note that")
+        if moe_activation_scheme not in ACTIVATION_SCHEMES:
+            raise ValueError(f"Unsupported activation scheme {moe_activation_scheme}")
+        self.linear_activation_scheme = linear_activation_scheme
+        self.moe_activation_scheme = moe_activation_scheme
+        self.ignored_layers = ignored_layers or []
+        self.weight_block_size = [128, 128]
+        self.group_size = group_size
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "w4afp8"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.float8_e4m3fn]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 90
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = "fp8" in quant_method
+        is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
+        linear_activation_scheme = "dynamic"
+        moe_activation_scheme = "static"
+        weight_block_size = [128, 128]
+        return cls(
+            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
+            is_checkpoint_w4afp8_serialized=is_checkpoint_w4afp8_serialized,
+            linear_activation_scheme=linear_activation_scheme,
+            moe_activation_scheme=moe_activation_scheme,
+            weight_block_size=weight_block_size,
+        )
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return UnquantizedLinearMethod()
+            return Fp8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W4AFp8MoEMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class W4AFp8MoEMethod:
+
+    def __init__(self, quant_config: W4AFp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: Module,
+        num_experts_per_partition: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        assert "weight_loader" in extra_weight_attrs
+
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts_per_partition,
+                intermediate_size * 2,
+                hidden_size // 2,
+                dtype=torch.int8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts_per_partition,
+                hidden_size,
+                intermediate_size // 2,
+                dtype=torch.int8,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w13_weight_scale = torch.nn.Parameter(
+            torch.zeros(
+                num_experts_per_partition,
+                2 * intermediate_size,
+                hidden_size // self.quant_config.group_size,
+                dtype=torch.float32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+        w2_weight_scale = torch.nn.Parameter(
+            torch.zeros(
+                num_experts_per_partition,
+                hidden_size,
+                intermediate_size // self.quant_config.group_size,
+                dtype=torch.float32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # Input scales
+        w13_input_scale = torch.nn.Parameter(
+            torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_input_scale", w13_input_scale)
+        set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+        w2_input_scale = torch.nn.Parameter(
+            torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_input_scale", w2_input_scale)
+        set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+        # Pre-populate the strides
+        device = layer.w13_weight.device
+
+        self.a_strides1 = torch.full(
+            (num_experts_per_partition, 3),
+            hidden_size,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.c_strides1 = torch.full(
+            (num_experts_per_partition, 3),
+            2 * intermediate_size,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.a_strides2 = torch.full(
+            (num_experts_per_partition, 3),
+            intermediate_size,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.c_strides2 = torch.full(
+            (num_experts_per_partition, 3),
+            hidden_size,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.b_strides1 = self.a_strides1
+        self.s_strides13 = self.c_strides1
+        self.b_strides2 = self.a_strides2
+        self.s_strides2 = self.c_strides2
+
+        self.expert_offsets = torch.empty(
+            (num_experts_per_partition + 1), dtype=torch.int32, device=device
+        )
+        self.problem_sizes1 = torch.empty(
+            (num_experts_per_partition, 3), dtype=torch.int32, device=device
+        )
+        self.problem_sizes2 = torch.empty(
+            (num_experts_per_partition, 3), dtype=torch.int32, device=device
+        )
+
+        return
+
+    def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
+        """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
+        s_shape = scales.shape
+        # Reshape to separate groups of 4
+        scales_interleaved = scales.reshape(
+            s_shape[0], s_shape[1], (s_shape[2] // 4), 4
+        )
+        # Permute dimensions to interleave
+        scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
+        # Reshape back to original dimensions but with interleaved values
+        scales_interleaved = scales_interleaved.reshape(
+            s_shape[0], s_shape[2] // 4, s_shape[1] * 4
+        )
+        return scales_interleaved.contiguous()
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        dtype = torch.bfloat16
+        device = layer.w2_weight.device
+
+        # Interleave w13_weight_scale (gate_up_proj)
+        w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
+        w13_weight_scale = self._interleave_scales(w13_weight_scale)
+        layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)
+
+        # Interleave w2_weight_scale (down_proj)
+        w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
+        w2_weight_scale = self._interleave_scales(w2_weight_scale)
+        layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)
+
+        # Process input scales
+        w13_input_scale_max = layer.w13_input_scale.max().to(dtype).item()
+        new_w13_input_scale = torch.tensor(
+            [w13_input_scale_max],
+            dtype=dtype,
+            device=device,
+        )
+        layer.w13_input_scale = Parameter(new_w13_input_scale, requires_grad=False)
+
+        w2_input_scale_max = layer.w2_input_scale.max().to(dtype).item()
+        new_w2_input_scale = torch.tensor(
+            [w2_input_scale_max], dtype=dtype, device=device
+        )
+        layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
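(Aside, not part of the diff: a minimal standalone sketch of the scale-interleaving transform that W4AFp8MoEMethod._interleave_scales applies above. The toy tensor sizes are made up purely to show the resulting shape.)

import torch

def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
    # Same reshape/permute/reshape as _interleave_scales: split the group axis
    # into blocks of 4 and swap it with the channel axis, so the scales for 4
    # consecutive k-groups of each output channel become contiguous.
    e, n, g = scales.shape  # (experts, output channels, k-groups)
    s = scales.reshape(e, n, g // 4, 4)
    s = s.permute(0, 2, 1, 3)  # (e, g // 4, n, 4)
    return s.reshape(e, g // 4, n * 4).contiguous()

x = torch.arange(2 * 8 * 8, dtype=torch.float32).reshape(2, 8, 8)
print(interleave_scales(x).shape)  # torch.Size([2, 2, 32])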
sglang/srt/layers/vocab_parallel_embedding.py
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/layers/vocab_parallel_embedding.py
 
+import logging
 from dataclasses import dataclass
 from typing import List, Optional, Sequence, Tuple
 
@@ -28,6 +29,8 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
+logger = logging.getLogger(__name__)
+
 
 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
@@ -562,9 +565,12 @@ class ParallelLMHead(VocabParallelEmbedding):
         )
         self.quant_config = quant_config
 
-        # We only support pack LMHead if it's not quantized.
-        if
-        self
+        # We only support pack LMHead if it's not quantized.
+        if _is_cpu and _is_cpu_amx_available:
+            if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
+                self.quant_method = PackWeightMethod(weight_names=["weight"])
+            else:
+                logger.warning("The weight of LmHead is not packed")
 
         if bias:
             self.bias = Parameter(
sglang/srt/lora/triton_ops/gate_up_lora_b.py
@@ -31,28 +31,44 @@ def _gate_up_lora_b_kernel(
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    # For fused output scaling
-    fuse_scaling_add,
+    # For fused output scaling
     scalings,
 ):
-
-
-
-
-
+    """
+    This kernel packs 2 sgemms (gate/up) into a single kernel. The multiplication
+    results are accumulated into the output tensor.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (Tensor): The input tensor, which is the result of the LoRA A projection.
+            Shape: (s, 2 * K), where s is the sum of all sequence lengths in the
+            batch and K is the maximum LoRA rank.
+        weights (Tensor): The LoRA B weights for all adapters.
+            Shape: (num_lora, 2 * output_dim, K).
+        output (Tensor): The output tensor where the result is stored.
+            Shape: (s, 2 * output_dim).
+    """
     # output_dim >> K
 
     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len.
     # gate_up_id decides which of gate or up (0: gate, 1: up)
     batch_id = tl.program_id(axis=2)
+    w_index = tl.load(weight_indices + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if rank == 0:
+        return
+
     gate_up_id = tl.program_id(axis=1)
     pid = tl.program_id(axis=0)
     seg_len = tl.load(seg_lens + batch_id)
-    w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = gate_up_id * output_dim  # offset on output dim
-    rank = tl.load(lora_ranks + w_index)
     scaling = tl.load(scalings + w_index)
 
     # Adjust K (rank) according to the specific LoRA adapter
@@ -82,14 +98,13 @@ def _gate_up_lora_b_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
-            mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
             w_ptrs,
             mask=(k_offset[:, None] < K - k * BLOCK_K)
-
+            & (n_offset[None, :] < output_dim),
             other=0.0,
         )
         partial_sum += tl.dot(x_tile, w_tile)
@@ -103,9 +118,8 @@ def _gate_up_lora_b_kernel(
     output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
-    output_mask = (s_offset[:, None] < seg_len)
-
-    partial_sum += tl.load(output_ptr, mask=output_mask)
+    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < output_dim)
+    partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
 
@@ -143,11 +157,9 @@ def gate_up_lora_b_fwd(
     )
 
     if base_output is None:
-        output = torch.
-        fuse_scaling_add = False
+        output = torch.zeros((s, 2 * output_dim), device=x.device, dtype=x.dtype)
     else:
         output = base_output
-        fuse_scaling_add = True
 
     _gate_up_lora_b_kernel[grid_b](
         x,
@@ -169,7 +181,6 @@ def gate_up_lora_b_fwd(
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
-        fuse_scaling_add,
         batch_info.scalings,
     )
 
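(Aside, not part of the diff: the hunks above drop the fuse_scaling_add flag by always zero-initializing the output when no base_output is given, and add an early return for rank-0 adapters. A small plain-PyTorch illustration, not the Triton kernel itself, of the two conventions the new docstring relies on.)

import torch

# Rank-0 convention: a (m, 0) @ (0, n) matmul yields an all-zero (m, n) matrix,
# so skipping the kernel for a rank-0 adapter leaves a pre-zeroed output correct.
m, n = 4, 6
assert torch.equal(torch.empty(m, 0) @ torch.empty(0, n), torch.zeros(m, n))

# Dropping fuse_scaling_add: accumulating into a zero-initialized buffer is the
# same as writing the result directly, so one accumulate path covers both the
# "fresh output" and the "add to base_output" cases.
x, w = torch.randn(m, 3), torch.randn(3, n)
out = torch.zeros(m, n)  # stands in for the zero-initialized output buffer
out += x @ w
assert torch.allclose(out, x @ w)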
sglang/srt/lora/triton_ops/qkv_lora_b.py
@@ -33,29 +33,45 @@ def _qkv_lora_b_kernel(
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    # For fused output scaling
-    fuse_scaling_add,
+    # For fused output scaling
     scalings,
 ):
-
-
-
-
-
-
+    """
+    This kernel packs 3 sgemms (q/k/v) into a single kernel. The multiplication
+    results are accumulated into the output tensor.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (Tensor): The input tensor, which is the result of the LoRA A projection.
+            Shape: (s, 3 * K), where s is the sum of all sequence lengths in the
+            batch and K is the maximum LoRA rank. The second dimension is partitioned
+            for Q, K, and V.
+        weights (Tensor): The LoRA B weights for all adapters.
+            Shape: (num_lora, N_Q + 2 * N_KV, K).
+        output (Tensor): The output tensor where the result is stored.
+            Shape: (s, N_Q + 2 * N_KV).
+    """
 
     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len.
     # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
     batch_id = tl.program_id(axis=2)
+    w_index = tl.load(weight_indices + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if rank == 0:
+        return
+
     qkv_id = tl.program_id(axis=1)
     pid = tl.program_id(axis=0)
     seg_len = tl.load(seg_lens + batch_id)
-    w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = tl.load(n_offs + qkv_id)
     n_size = tl.load(n_offs + qkv_id + 1) - n_start
-    rank = tl.load(lora_ranks + w_index)
     scaling = tl.load(scalings + w_index)
     # Adjust K (rank) according to the specific LoRA adapter
     K = tl.minimum(K, rank)
@@ -84,13 +100,12 @@ def _qkv_lora_b_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
-            mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
             w_ptrs,
-            mask=(k_offset[:, None] < K - k * BLOCK_K)
+            mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < n_size),
             other=0.0,
         )
         partial_sum += tl.dot(x_tile, w_tile)
@@ -105,8 +120,7 @@ def _qkv_lora_b_kernel(
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
     output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
-
-    partial_sum += tl.load(output_ptr, mask=output_mask)
+    partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
 
@@ -153,11 +167,9 @@ def qkv_lora_b_fwd(
     )
 
     if base_output is None:
-        output = torch.
-        fuse_scaling_add = False
+        output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
    else:
         output = base_output
-        fuse_scaling_add = True
 
     _qkv_lora_b_kernel[grid_b](
         x,
@@ -180,7 +192,6 @@ def qkv_lora_b_fwd(
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
-        fuse_scaling_add,
        batch_info.scalings,
     )
 
sglang/srt/lora/triton_ops/sgemm_lora_a.py
@@ -33,19 +33,36 @@ def _sgemm_lora_a_kernel(
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
 ):
-
-
-
-
+    """
+    Computes a segmented batched matrix multiplication for the LoRA A matrix.
+
+    The kernel ensures that output[seg_start:seg_start + seg_len, :rank * stack_num]
+    stores the product of the input `x` and the LoRA weights for the corresponding
+    sequence. This implies that when rank is 0, the kernel is essentially a no-op,
+    as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
+
+    Args:
+        x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
+            is the sum of all sequence lengths in the batch.
+        weights (torch.Tensor): The LoRA 'A' weights for all available adapters,
+            with shape `(num_lora, N, K)`.
+        output (torch.Tensor): The output tensor of shape `(s, N)`.
+    """
 
     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len
     batch_id = tl.program_id(axis=1)
-    pid = tl.program_id(axis=0)
-    seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
-    seg_start = tl.load(seg_indptr + batch_id)
     rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
+    if rank == 0:
+        return
+
+    pid = tl.program_id(axis=0)
+    seg_start = tl.load(seg_indptr + batch_id)
+    seg_len = tl.load(seg_lens + batch_id)
+
     # Adjust N (stack_num * max_rank) according to the specific LoRA adapter
     N = tl.minimum(N, rank * stack_num)
 
@@ -72,13 +89,12 @@ def _sgemm_lora_a_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
-            mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
             w_ptrs,
-            mask=(k_offset[:, None] < K - k * BLOCK_K)
+            mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < N),
             other=0.0,
         )
         partial_sum += tl.dot(x_tile, w_tile)
@@ -91,7 +107,7 @@ def _sgemm_lora_a_kernel(
     output_ptr = (output + seg_start * output_stride_0) + (
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
-    output_mask = (s_offset[:, None] < seg_len)
+    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < N)
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
 
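(Aside, not part of the diff: the recurring mask change in these kernels replaces Python's `and` with the elementwise `&` operator and adds the missing column bound to the load/store masks. A plain-tensor sketch, with assumed toy sizes, of how the two broadcasted bounds compose.)

import torch

BLOCK_S, BLOCK_N = 8, 8
seg_len, N = 5, 3
s_offset = torch.arange(BLOCK_S)
n_offset = torch.arange(BLOCK_N)

# `&` broadcasts the (8, 1) row bound against the (1, 8) column bound into a
# full (8, 8) mask; Python's `and` cannot combine multi-element tensors elementwise.
mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < N)
print(mask.sum().item())  # 15 in-bounds positions (5 rows x 3 columns)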
sglang/srt/lora/triton_ops/sgemm_lora_b.py
@@ -31,22 +31,39 @@ def _sgemm_lora_b_kernel(
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    # For fused output scaling
-    fuse_scaling_add,
+    # For fused output scaling
     scalings,
 ):
-
-
-
+    """
+    Computes a segmented batched matrix multiplication for the LoRA B matrix
+    and adds the result to the output in-place.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (torch.Tensor): The intermediate tensor from the LoRA 'A' multiplication,
+            of shape `(s, K)`, where `s` is the total number of tokens.
+        weights (torch.Tensor): The LoRA 'B' weights for all available adapters,
+            with shape `(num_lora, N, K)`.
+        output (torch.Tensor): The output tensor of shape `(s, N)`. This can be
+            the base model's output for a fused add operation.
+    """
 
     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len
     batch_id = tl.program_id(axis=1)
+    w_index = tl.load(weight_indices + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if rank == 0:
+        return
+
     pid = tl.program_id(axis=0)
     seg_len = tl.load(seg_lens + batch_id)
-    w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
-    rank = tl.load(lora_ranks + w_index)
     scaling = tl.load(scalings + w_index)
     # Adjust K (rank) according to the specific LoRA adapter
     K = tl.minimum(K, rank)
@@ -74,8 +91,7 @@ def _sgemm_lora_b_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
-            mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
@@ -95,8 +111,7 @@ def _sgemm_lora_b_kernel(
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
     output_mask = s_offset[:, None] < seg_len
-
-    partial_sum += tl.load(output_ptr, mask=output_mask)
+    partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
 
@@ -132,11 +147,9 @@ def sgemm_lora_b_fwd(
     )
 
     if base_output is None:
-        output = torch.
-        fuse_scaling_add = False
+        output = torch.zeros((S, N), device=x.device, dtype=x.dtype)
     else:
         output = base_output
-        fuse_scaling_add = True
 
     _sgemm_lora_b_kernel[grid](
         x,
@@ -158,7 +171,6 @@ def sgemm_lora_b_fwd(
         BLOCK_S,
         BLOCK_N,
         BLOCK_R,
-        fuse_scaling_add,
         batch_info.scalings,
     )
     return output