sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (47)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +12 -1
  3. sglang/srt/conversation.py +35 -1
  4. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  5. sglang/srt/entrypoints/http_server_engine.py +1 -1
  6. sglang/srt/layers/communicator.py +3 -1
  7. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  8. sglang/srt/layers/layernorm.py +2 -2
  9. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  10. sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
  11. sglang/srt/layers/moe/ep_moe/layer.py +140 -2
  12. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  13. sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
  14. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  15. sglang/srt/layers/quantization/__init__.py +2 -0
  16. sglang/srt/layers/quantization/fp8.py +28 -7
  17. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  18. sglang/srt/layers/quantization/w4afp8.py +264 -0
  19. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  20. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  21. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  22. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  23. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  24. sglang/srt/managers/cache_controller.py +41 -195
  25. sglang/srt/managers/io_struct.py +8 -1
  26. sglang/srt/managers/mm_utils.py +4 -2
  27. sglang/srt/managers/schedule_batch.py +1 -1
  28. sglang/srt/managers/scheduler.py +17 -5
  29. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  30. sglang/srt/mem_cache/memory_pool.py +113 -63
  31. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  32. sglang/srt/mem_cache/radix_cache.py +8 -4
  33. sglang/srt/models/deepseek_v2.py +16 -2
  34. sglang/srt/models/mllama4.py +360 -79
  35. sglang/srt/multimodal/mm_utils.py +2 -2
  36. sglang/srt/multimodal/processors/mllama4.py +62 -60
  37. sglang/srt/server_args.py +15 -0
  38. sglang/srt/two_batch_overlap.py +3 -0
  39. sglang/srt/utils.py +37 -17
  40. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  41. sglang/utils.py +5 -5
  42. sglang/version.py +1 -1
  43. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
  44. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
  45. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  46. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  47. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/w4afp8.py (new file)
@@ -0,0 +1,264 @@
+ import logging
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from torch.nn import Module
+ from torch.nn.parameter import Parameter
+
+ from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+ from sglang.srt.layers.quantization.base_config import (
+     QuantizationConfig,
+     QuantizeMethodBase,
+ )
+ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+ from sglang.srt.layers.quantization.utils import is_layer_skipped
+ from sglang.srt.utils import set_weight_attrs
+
+ ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+ logger = logging.getLogger(__name__)
+
+
+ class W4AFp8Config(QuantizationConfig):
+     """Config class for MIXED_PRECISION W4AFp8."""
+
+     def __init__(
+         self,
+         is_checkpoint_fp8_serialized: bool = True,
+         is_checkpoint_w4afp8_serialized: bool = True,
+         linear_activation_scheme: str = "dynamic",
+         moe_activation_scheme: str = "static",
+         ignored_layers: Optional[List[str]] = None,
+         weight_block_size: Optional[List[int]] = None,
+         group_size: int = 128,
+     ) -> None:
+         super().__init__()
+         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+         self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
+         if is_checkpoint_w4afp8_serialized:
+             logger.warning("Detected w4afp8 checkpoint. Please note that")
+         if moe_activation_scheme not in ACTIVATION_SCHEMES:
+             raise ValueError(f"Unsupported activation scheme {moe_activation_scheme}")
+         self.linear_activation_scheme = linear_activation_scheme
+         self.moe_activation_scheme = moe_activation_scheme
+         self.ignored_layers = ignored_layers or []
+         self.weight_block_size = [128, 128]
+         self.group_size = group_size
+
+     @classmethod
+     def get_name(cls) -> str:
+         return "w4afp8"
+
+     @classmethod
+     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+         return [torch.bfloat16, torch.float8_e4m3fn]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 90
+
+     @classmethod
+     def get_config_filenames(cls) -> List[str]:
+         return []
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
+         quant_method = cls.get_from_keys(config, ["quant_method"])
+         is_checkpoint_fp8_serialized = "fp8" in quant_method
+         is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
+         linear_activation_scheme = "dynamic"
+         moe_activation_scheme = "static"
+         weight_block_size = [128, 128]
+         return cls(
+             is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
+             is_checkpoint_w4afp8_serialized=is_checkpoint_w4afp8_serialized,
+             linear_activation_scheme=linear_activation_scheme,
+             moe_activation_scheme=moe_activation_scheme,
+             weight_block_size=weight_block_size,
+         )
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional["QuantizeMethodBase"]:
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+         if isinstance(layer, LinearBase):
+             if is_layer_skipped(prefix, self.ignored_layers):
+                 return UnquantizedLinearMethod()
+             return Fp8LinearMethod(self)
+         elif isinstance(layer, FusedMoE):
+             return W4AFp8MoEMethod(self)
+         return None
+
+     def get_scaled_act_names(self) -> List[str]:
+         return []
+
+
+ class W4AFp8MoEMethod:
+
+     def __init__(self, quant_config: W4AFp8Config):
+         self.quant_config = quant_config
+
+     def create_weights(
+         self,
+         layer: Module,
+         num_experts_per_partition: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         assert "weight_loader" in extra_weight_attrs
+
+         # Fused gate_up_proj (column parallel)
+         w13_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts_per_partition,
+                 intermediate_size * 2,
+                 hidden_size // 2,
+                 dtype=torch.int8,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight", w13_weight)
+         set_weight_attrs(w13_weight, extra_weight_attrs)
+
+         # down_proj (row parallel)
+         w2_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts_per_partition,
+                 hidden_size,
+                 intermediate_size // 2,
+                 dtype=torch.int8,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight", w2_weight)
+         set_weight_attrs(w2_weight, extra_weight_attrs)
+
+         w13_weight_scale = torch.nn.Parameter(
+             torch.zeros(
+                 num_experts_per_partition,
+                 2 * intermediate_size,
+                 hidden_size // self.quant_config.group_size,
+                 dtype=torch.float32,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+         w2_weight_scale = torch.nn.Parameter(
+             torch.zeros(
+                 num_experts_per_partition,
+                 hidden_size,
+                 intermediate_size // self.quant_config.group_size,
+                 dtype=torch.float32,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+         # Input scales
+         w13_input_scale = torch.nn.Parameter(
+             torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_input_scale", w13_input_scale)
+         set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+         w2_input_scale = torch.nn.Parameter(
+             torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_input_scale", w2_input_scale)
+         set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+         # Pre-populate the strides
+         device = layer.w13_weight.device
+
+         self.a_strides1 = torch.full(
+             (num_experts_per_partition, 3),
+             hidden_size,
+             device=device,
+             dtype=torch.int64,
+         )
+         self.c_strides1 = torch.full(
+             (num_experts_per_partition, 3),
+             2 * intermediate_size,
+             device=device,
+             dtype=torch.int64,
+         )
+         self.a_strides2 = torch.full(
+             (num_experts_per_partition, 3),
+             intermediate_size,
+             device=device,
+             dtype=torch.int64,
+         )
+         self.c_strides2 = torch.full(
+             (num_experts_per_partition, 3),
+             hidden_size,
+             device=device,
+             dtype=torch.int64,
+         )
+         self.b_strides1 = self.a_strides1
+         self.s_strides13 = self.c_strides1
+         self.b_strides2 = self.a_strides2
+         self.s_strides2 = self.c_strides2
+
+         self.expert_offsets = torch.empty(
+             (num_experts_per_partition + 1), dtype=torch.int32, device=device
+         )
+         self.problem_sizes1 = torch.empty(
+             (num_experts_per_partition, 3), dtype=torch.int32, device=device
+         )
+         self.problem_sizes2 = torch.empty(
+             (num_experts_per_partition, 3), dtype=torch.int32, device=device
+         )
+
+         return
+
+     def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
+         """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
+         s_shape = scales.shape
+         # Reshape to separate groups of 4
+         scales_interleaved = scales.reshape(
+             s_shape[0], s_shape[1], (s_shape[2] // 4), 4
+         )
+         # Permute dimensions to interleave
+         scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
+         # Reshape back to original dimensions but with interleaved values
+         scales_interleaved = scales_interleaved.reshape(
+             s_shape[0], s_shape[2] // 4, s_shape[1] * 4
+         )
+         return scales_interleaved.contiguous()
+
+     def process_weights_after_loading(self, layer: Module) -> None:
+         dtype = torch.bfloat16
+         device = layer.w2_weight.device
+
+         # Interleave w13_weight_scale (gate_up_proj)
+         w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
+         w13_weight_scale = self._interleave_scales(w13_weight_scale)
+         layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)
+
+         # Interleave w2_weight_scale (down_proj)
+         w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
+         w2_weight_scale = self._interleave_scales(w2_weight_scale)
+         layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)
+
+         # Process input scales
+         w13_input_scale_max = layer.w13_input_scale.max().to(dtype).item()
+         new_w13_input_scale = torch.tensor(
+             [w13_input_scale_max],
+             dtype=dtype,
+             device=device,
+         )
+         layer.w13_input_scale = Parameter(new_w13_input_scale, requires_grad=False)
+
+         w2_input_scale_max = layer.w2_input_scale.max().to(dtype).item()
+         new_w2_input_scale = torch.tensor(
+             [w2_input_scale_max], dtype=dtype, device=device
+         )
+         layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
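
A note on _interleave_scales above: it regroups the per-group (128-column) weight scales in blocks of four, mirroring the TRT-LLM-style layout mentioned in its docstring (presumably consumed by the new cutlass_w4a8_moe path added in this release). A minimal standalone sketch of the same reshape/permute, with made-up shapes (2 experts, 8 rows, 8 scale groups) rather than the package's real tensors:

import torch

def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
    # (E, N, G) -> (E, G // 4, N * 4): split the group axis into fours,
    # then interleave those fours across the row axis.
    e, n, g = scales.shape
    return (
        scales.reshape(e, n, g // 4, 4)
        .permute(0, 2, 1, 3)
        .reshape(e, g // 4, n * 4)
        .contiguous()
    )

s = torch.arange(2 * 8 * 8, dtype=torch.float32).reshape(2, 8, 8)
print(interleave_scales(s).shape)  # torch.Size([2, 2, 32])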

sglang/srt/layers/vocab_parallel_embedding.py
@@ -1,5 +1,6 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/layers/vocab_parallel_embedding.py

+ import logging
  from dataclasses import dataclass
  from typing import List, Optional, Sequence, Tuple

@@ -28,6 +29,8 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()

+ logger = logging.getLogger(__name__)
+

  class UnquantizedEmbeddingMethod(QuantizeMethodBase):
      """Unquantized method for embeddings."""
@@ -562,9 +565,12 @@ class ParallelLMHead(VocabParallelEmbedding):
          )
          self.quant_config = quant_config

-         # We only support pack LMHead if it's not quantized. For LMHead with quant_config, the weight_name will be "qweight"
-         if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
-             self.quant_method = PackWeightMethod(weight_names=["weight"])
+         # We only support pack LMHead if it's not quantized.
+         if _is_cpu and _is_cpu_amx_available:
+             if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
+                 self.quant_method = PackWeightMethod(weight_names=["weight"])
+             else:
+                 logger.warning("The weight of LmHead is not packed")

          if bias:
              self.bias = Parameter(

sglang/srt/lora/triton_ops/gate_up_lora_b.py
@@ -31,28 +31,44 @@ def _gate_up_lora_b_kernel(
      BLOCK_S: tl.constexpr,
      BLOCK_N: tl.constexpr,
      BLOCK_K: tl.constexpr,
-     # For fused output scaling and adding
-     fuse_scaling_add,
+     # For fused output scaling
      scalings,
  ):
-     # This kernel packs 2 sgemms (gate/up) into a single kernel.
-
-     # x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
-     # weights: (num_lora, 2 * output_dim, K)
-     # output: (s, 2 * output_dim)
+     """
+     This kernel packs 2 sgemms (gate/up) into a single kernel. The multiplication
+     results are accumulated into the output tensor.
+
+     When a sequence's rank is 0, the kernel is essentially a no-op, following
+     the convention in pytorch where the product of two matrices of shape (m, 0)
+     and (0, n) is an all-zero matrix of shape (m, n).
+
+     Args:
+         x (Tensor): The input tensor, which is the result of the LoRA A projection.
+             Shape: (s, 2 * K), where s is the sum of all sequence lengths in the
+             batch and K is the maximum LoRA rank.
+         weights (Tensor): The LoRA B weights for all adapters.
+             Shape: (num_lora, 2 * output_dim, K).
+         output (Tensor): The output tensor where the result is stored.
+             Shape: (s, 2 * output_dim).
+     """
      # output_dim >> K

      # Current block computes sequence with batch_id,
      # which starts from row seg_start of x with length seg_len.
      # gate_up_id decides which of gate or up (0: gate, 1: up)
      batch_id = tl.program_id(axis=2)
+     w_index = tl.load(weight_indices + batch_id)
+     rank = tl.load(lora_ranks + w_index)
+
+     # If rank is 0, this kernel is a no-op.
+     if rank == 0:
+         return
+
      gate_up_id = tl.program_id(axis=1)
      pid = tl.program_id(axis=0)
      seg_len = tl.load(seg_lens + batch_id)
-     w_index = tl.load(weight_indices + batch_id)
      seg_start = tl.load(seg_indptr + batch_id)
      n_start = gate_up_id * output_dim  # offset on output dim
-     rank = tl.load(lora_ranks + w_index)
      scaling = tl.load(scalings + w_index)

      # Adjust K (rank) according to the specific LoRA adapter
@@ -82,14 +98,13 @@ def _gate_up_lora_b_kernel(
      for k in range(0, tl.cdiv(K, BLOCK_K)):
          x_tile = tl.load(
              x_ptrs,
-             mask=(s_offset[:, None] < seg_len)
-             and (k_offset[None, :] < K - k * BLOCK_K),
+             mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
              other=0.0,
          )
          w_tile = tl.load(
              w_ptrs,
              mask=(k_offset[:, None] < K - k * BLOCK_K)
-             and (n_offset[None, :] < output_dim),
+             & (n_offset[None, :] < output_dim),
              other=0.0,
          )
          partial_sum += tl.dot(x_tile, w_tile)
@@ -103,9 +118,8 @@ def _gate_up_lora_b_kernel(
      output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
          s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
      )
-     output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
-     if fuse_scaling_add:
-         partial_sum += tl.load(output_ptr, mask=output_mask)
+     output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < output_dim)
+     partial_sum += tl.load(output_ptr, mask=output_mask)
      tl.store(output_ptr, partial_sum, mask=output_mask)


@@ -143,11 +157,9 @@ def gate_up_lora_b_fwd(
      )

      if base_output is None:
-         output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
-         fuse_scaling_add = False
+         output = torch.zeros((s, 2 * output_dim), device=x.device, dtype=x.dtype)
      else:
          output = base_output
-         fuse_scaling_add = True

      _gate_up_lora_b_kernel[grid_b](
          x,
@@ -169,7 +181,6 @@ def gate_up_lora_b_fwd(
          BLOCK_S,
          BLOCK_OUT,
          BLOCK_R,
-         fuse_scaling_add,
          batch_info.scalings,
      )

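The early return added to these kernels leans on the convention spelled out in the new docstrings: a product with a zero-sized inner dimension is an all-zero matrix, so a rank-0 adapter contributes nothing to the accumulated output. A quick plain-PyTorch check of that convention (not the Triton code path):

import torch

m, n = 4, 6
a = torch.empty(m, 0)   # LoRA A output for a rank-0 adapter
b = torch.empty(0, n)   # LoRA B weight for a rank-0 adapter
print(torch.allclose(a @ b, torch.zeros(m, n)))  # True

# Accumulating this all-zero product into the base output changes nothing,
# which is why the kernel can simply return when rank == 0.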

sglang/srt/lora/triton_ops/qkv_lora_b.py
@@ -33,29 +33,45 @@ def _qkv_lora_b_kernel(
      BLOCK_S: tl.constexpr,
      BLOCK_N: tl.constexpr,
      BLOCK_K: tl.constexpr,
-     # For fused output scaling and adding
-     fuse_scaling_add,
+     # For fused output scaling
      scalings,
  ):
-     # This kernel packs 3 sgemms (q/k/v) into a single kernel.
-
-     # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank
-     # weights: (num_lora, N_Q + 2 * N_KV, K)
-     # output: (s, N_Q + 2 * N_KV)
-     # N_Q >> K, N_KV >> K
+     """
+     This kernel packs 3 sgemms (q/k/v) into a single kernel. The multiplication
+     results are accumulated into the output tensor.
+
+     When a sequence's rank is 0, the kernel is essentially a no-op, following
+     the convention in pytorch where the product of two matrices of shape (m, 0)
+     and (0, n) is an all-zero matrix of shape (m, n).
+
+     Args:
+         x (Tensor): The input tensor, which is the result of the LoRA A projection.
+             Shape: (s, 3 * K), where s is the sum of all sequence lengths in the
+             batch and K is the maximum LoRA rank. The second dimension is partitioned
+             for Q, K, and V.
+         weights (Tensor): The LoRA B weights for all adapters.
+             Shape: (num_lora, N_Q + 2 * N_KV, K).
+         output (Tensor): The output tensor where the result is stored.
+             Shape: (s, N_Q + 2 * N_KV).
+     """

      # Current block computes sequence with batch_id,
      # which starts from row seg_start of x with length seg_len.
      # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
      batch_id = tl.program_id(axis=2)
+     w_index = tl.load(weight_indices + batch_id)
+     rank = tl.load(lora_ranks + w_index)
+
+     # If rank is 0, this kernel is a no-op.
+     if rank == 0:
+         return
+
      qkv_id = tl.program_id(axis=1)
      pid = tl.program_id(axis=0)
      seg_len = tl.load(seg_lens + batch_id)
-     w_index = tl.load(weight_indices + batch_id)
      seg_start = tl.load(seg_indptr + batch_id)
      n_start = tl.load(n_offs + qkv_id)
      n_size = tl.load(n_offs + qkv_id + 1) - n_start
-     rank = tl.load(lora_ranks + w_index)
      scaling = tl.load(scalings + w_index)
      # Adjust K (rank) according to the specific LoRA adapter
      K = tl.minimum(K, rank)
@@ -84,13 +100,12 @@ def _qkv_lora_b_kernel(
      for k in range(0, tl.cdiv(K, BLOCK_K)):
          x_tile = tl.load(
              x_ptrs,
-             mask=(s_offset[:, None] < seg_len)
-             and (k_offset[None, :] < K - k * BLOCK_K),
+             mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
              other=0.0,
          )
          w_tile = tl.load(
              w_ptrs,
-             mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
+             mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < n_size),
              other=0.0,
          )
          partial_sum += tl.dot(x_tile, w_tile)
@@ -105,8 +120,7 @@ def _qkv_lora_b_kernel(
          s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
      )
      output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
-     if fuse_scaling_add:
-         partial_sum += tl.load(output_ptr, mask=output_mask)
+     partial_sum += tl.load(output_ptr, mask=output_mask)
      tl.store(output_ptr, partial_sum, mask=output_mask)


@@ -153,11 +167,9 @@ def qkv_lora_b_fwd(
      )

      if base_output is None:
-         output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
-         fuse_scaling_add = False
+         output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
      else:
          output = base_output
-         fuse_scaling_add = True

      _qkv_lora_b_kernel[grid_b](
          x,
@@ -180,7 +192,6 @@ def qkv_lora_b_fwd(
          BLOCK_S,
          BLOCK_OUT,
          BLOCK_R,
-         fuse_scaling_add,
          batch_info.scalings,
      )

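Another recurring change in these LoRA kernels is replacing Python's short-circuiting `and` with the element-wise `&` when building load/store masks, so the two boundary conditions are combined per element of the tile. A small PyTorch analogue of the mask construction (illustrative values only, not the Triton code):

import torch

BLOCK_S, BLOCK_K = 4, 8
seg_len, K = 3, 5

s_offset = torch.arange(BLOCK_S)   # tile rows
k_offset = torch.arange(BLOCK_K)   # tile columns

# (BLOCK_S, BLOCK_K) boolean tile: True only where both row and column are in range.
mask = (s_offset[:, None] < seg_len) & (k_offset[None, :] < K)
print(mask.shape, int(mask.sum()))  # torch.Size([4, 8]) 15

# In PyTorch, `and` would try to collapse each operand to a single bool and raise
# an error for multi-element tensors; `&` is the element-wise operator masks need.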

sglang/srt/lora/triton_ops/sgemm_lora_a.py
@@ -33,19 +33,36 @@ def _sgemm_lora_a_kernel(
      BLOCK_N: tl.constexpr,
      BLOCK_K: tl.constexpr,
  ):
-
-     # x: (s, K), s is the sum of sequence lengths
-     # weights: (num_lora, N, K)
-     # output: (s, N)
+     """
+     Computes a segmented batched matrix multiplication for the LoRA A matrix.
+
+     The kernel ensures that output[seg_start:seg_start + seg_len, :rank * stack_num]
+     stores the product of the input `x` and the LoRA weights for the corresponding
+     sequence. This implies that when rank is 0, the kernel is essentially a no-op,
+     as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
+
+     Args:
+         x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
+             is the sum of all sequence lengths in the batch.
+         weights (torch.Tensor): The LoRA 'A' weights for all available adapters,
+             with shape `(num_lora, N, K)`.
+         output (torch.Tensor): The output tensor of shape `(s, N)`.
+     """

      # Current block computes sequence with batch_id,
      # which starts from row seg_start of x with length seg_len
      batch_id = tl.program_id(axis=1)
-     pid = tl.program_id(axis=0)
-     seg_len = tl.load(seg_lens + batch_id)
      w_index = tl.load(weight_indices + batch_id)
-     seg_start = tl.load(seg_indptr + batch_id)
      rank = tl.load(lora_ranks + w_index)
+
+     # If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
+     if rank == 0:
+         return
+
+     pid = tl.program_id(axis=0)
+     seg_start = tl.load(seg_indptr + batch_id)
+     seg_len = tl.load(seg_lens + batch_id)
+
      # Adjust N (stack_num * max_rank) according to the specific LoRA adapter
      N = tl.minimum(N, rank * stack_num)

@@ -72,13 +89,12 @@ def _sgemm_lora_a_kernel(
      for k in range(0, tl.cdiv(K, BLOCK_K)):
          x_tile = tl.load(
              x_ptrs,
-             mask=(s_offset[:, None] < seg_len)
-             and (k_offset[None, :] < K - k * BLOCK_K),
+             mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
              other=0.0,
          )
          w_tile = tl.load(
              w_ptrs,
-             mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
+             mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < N),
              other=0.0,
          )
          partial_sum += tl.dot(x_tile, w_tile)
@@ -91,7 +107,7 @@ def _sgemm_lora_a_kernel(
      output_ptr = (output + seg_start * output_stride_0) + (
          s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
      )
-     output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
+     output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < N)
      tl.store(output_ptr, partial_sum, mask=output_mask)



sglang/srt/lora/triton_ops/sgemm_lora_b.py
@@ -31,22 +31,39 @@ def _sgemm_lora_b_kernel(
      BLOCK_S: tl.constexpr,
      BLOCK_N: tl.constexpr,
      BLOCK_K: tl.constexpr,
-     # For fused output scaling and adding
-     fuse_scaling_add,
+     # For fused output scaling
      scalings,
  ):
-     # x: (s, K), s is the sum of sequence lengths
-     # weights: (num_lora, N, K)
-     # output: (s, N)
+     """
+     Computes a segmented batched matrix multiplication for the LoRA B matrix
+     and adds the result to the output in-place.
+
+     When a sequence's rank is 0, the kernel is essentially a no-op, following
+     the convention in pytorch where the product of two matrices of shape (m, 0)
+     and (0, n) is an all-zero matrix of shape (m, n).
+
+     Args:
+         x (torch.Tensor): The intermediate tensor from the LoRA 'A' multiplication,
+             of shape `(s, K)`, where `s` is the total number of tokens.
+         weights (torch.Tensor): The LoRA 'B' weights for all available adapters,
+             with shape `(num_lora, N, K)`.
+         output (torch.Tensor): The output tensor of shape `(s, N)`. This can be
+             the base model's output for a fused add operation.
+     """

      # Current block computes sequence with batch_id,
      # which starts from row seg_start of x with length seg_len
      batch_id = tl.program_id(axis=1)
+     w_index = tl.load(weight_indices + batch_id)
+     rank = tl.load(lora_ranks + w_index)
+
+     # If rank is 0, this kernel is a no-op.
+     if rank == 0:
+         return
+
      pid = tl.program_id(axis=0)
      seg_len = tl.load(seg_lens + batch_id)
-     w_index = tl.load(weight_indices + batch_id)
      seg_start = tl.load(seg_indptr + batch_id)
-     rank = tl.load(lora_ranks + w_index)
      scaling = tl.load(scalings + w_index)
      # Adjust K (rank) according to the specific LoRA adapter
      K = tl.minimum(K, rank)
@@ -74,8 +91,7 @@ def _sgemm_lora_b_kernel(
      for k in range(0, tl.cdiv(K, BLOCK_K)):
          x_tile = tl.load(
              x_ptrs,
-             mask=(s_offset[:, None] < seg_len)
-             and (k_offset[None, :] < K - k * BLOCK_K),
+             mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
              other=0.0,
          )
          w_tile = tl.load(
@@ -95,8 +111,7 @@ def _sgemm_lora_b_kernel(
          s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
      )
      output_mask = s_offset[:, None] < seg_len
-     if fuse_scaling_add:
-         partial_sum += tl.load(output_ptr, mask=output_mask)
+     partial_sum += tl.load(output_ptr, mask=output_mask)
      tl.store(output_ptr, partial_sum, mask=output_mask)


@@ -132,11 +147,9 @@ def sgemm_lora_b_fwd(
      )

      if base_output is None:
-         output = torch.empty((S, N), device=x.device, dtype=x.dtype)
-         fuse_scaling_add = False
+         output = torch.zeros((S, N), device=x.device, dtype=x.dtype)
      else:
          output = base_output
-         fuse_scaling_add = True

      _sgemm_lora_b_kernel[grid](
          x,
@@ -158,7 +171,6 @@ def sgemm_lora_b_fwd(
          BLOCK_S,
          BLOCK_N,
          BLOCK_R,
-         fuse_scaling_add,
          batch_info.scalings,
      )
      return output
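
The wrapper functions follow the same pattern in all four LoRA ops: `fuse_scaling_add` is gone, the output buffer starts from `torch.zeros` (or the caller's `base_output`), and the kernel now always accumulates into it. A toy sketch of why the old and new formulations agree numerically (hypothetical helper names, scaling omitted, plain PyTorch):

import torch

def lora_b_old(x, w, base_output=None):
    # Old scheme: compute the delta, add the existing output only when given.
    delta = x @ w.t()
    return delta if base_output is None else base_output + delta

def lora_b_new(x, w, base_output=None):
    # New scheme: start from zeros (or the base output) and always accumulate.
    out = torch.zeros(x.shape[0], w.shape[0]) if base_output is None else base_output.clone()
    out += x @ w.t()
    return out

x, w, base = torch.randn(5, 16), torch.randn(32, 16), torch.randn(5, 32)
print(torch.allclose(lora_b_old(x, w), lora_b_new(x, w)))              # True
print(torch.allclose(lora_b_old(x, w, base), lora_b_new(x, w, base)))  # True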