sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Files changed (99)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +36 -2
  3. sglang/srt/conversation.py +56 -3
  4. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  5. sglang/srt/disaggregation/ascend/conn.py +44 -0
  6. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +50 -18
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  9. sglang/srt/disaggregation/utils.py +25 -3
  10. sglang/srt/entrypoints/engine.py +1 -1
  11. sglang/srt/entrypoints/http_server.py +1 -0
  12. sglang/srt/entrypoints/http_server_engine.py +1 -1
  13. sglang/srt/entrypoints/openai/protocol.py +11 -0
  14. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  15. sglang/srt/function_call/function_call_parser.py +2 -0
  16. sglang/srt/function_call/kimik2_detector.py +220 -0
  17. sglang/srt/hf_transformers_utils.py +18 -0
  18. sglang/srt/jinja_template_utils.py +8 -0
  19. sglang/srt/layers/communicator.py +20 -5
  20. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  21. sglang/srt/layers/layernorm.py +2 -2
  22. sglang/srt/layers/linear.py +12 -2
  23. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  24. sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
  25. sglang/srt/layers/moe/ep_moe/layer.py +141 -2
  26. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  29. sglang/srt/layers/moe/topk.py +8 -2
  30. sglang/srt/layers/parameter.py +19 -3
  31. sglang/srt/layers/quantization/__init__.py +2 -0
  32. sglang/srt/layers/quantization/fp8.py +28 -7
  33. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  35. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  36. sglang/srt/layers/quantization/w4afp8.py +264 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  38. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  39. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  40. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  41. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  42. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  43. sglang/srt/managers/cache_controller.py +41 -195
  44. sglang/srt/managers/io_struct.py +35 -3
  45. sglang/srt/managers/mm_utils.py +59 -96
  46. sglang/srt/managers/schedule_batch.py +17 -6
  47. sglang/srt/managers/scheduler.py +38 -6
  48. sglang/srt/managers/tokenizer_manager.py +16 -0
  49. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  50. sglang/srt/mem_cache/memory_pool.py +176 -101
  51. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  52. sglang/srt/mem_cache/radix_cache.py +8 -4
  53. sglang/srt/model_executor/forward_batch_info.py +13 -1
  54. sglang/srt/model_loader/loader.py +23 -12
  55. sglang/srt/models/deepseek_janus_pro.py +1 -1
  56. sglang/srt/models/deepseek_v2.py +78 -19
  57. sglang/srt/models/deepseek_vl2.py +1 -1
  58. sglang/srt/models/gemma3_mm.py +1 -1
  59. sglang/srt/models/gemma3n_mm.py +6 -3
  60. sglang/srt/models/internvl.py +8 -2
  61. sglang/srt/models/kimi_vl.py +8 -2
  62. sglang/srt/models/llama.py +2 -0
  63. sglang/srt/models/llava.py +3 -1
  64. sglang/srt/models/llavavid.py +1 -1
  65. sglang/srt/models/minicpmo.py +1 -2
  66. sglang/srt/models/minicpmv.py +1 -1
  67. sglang/srt/models/mixtral_quant.py +4 -0
  68. sglang/srt/models/mllama4.py +372 -82
  69. sglang/srt/models/phi4mm.py +8 -2
  70. sglang/srt/models/phimoe.py +553 -0
  71. sglang/srt/models/qwen2.py +2 -0
  72. sglang/srt/models/qwen2_5_vl.py +10 -7
  73. sglang/srt/models/qwen2_vl.py +12 -1
  74. sglang/srt/models/vila.py +8 -2
  75. sglang/srt/multimodal/mm_utils.py +2 -2
  76. sglang/srt/multimodal/processors/base_processor.py +197 -137
  77. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  78. sglang/srt/multimodal/processors/gemma3.py +4 -2
  79. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  80. sglang/srt/multimodal/processors/internvl.py +1 -1
  81. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  82. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  83. sglang/srt/multimodal/processors/minicpm.py +4 -3
  84. sglang/srt/multimodal/processors/mllama4.py +63 -61
  85. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  86. sglang/srt/multimodal/processors/pixtral.py +1 -1
  87. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  88. sglang/srt/multimodal/processors/vila.py +1 -1
  89. sglang/srt/server_args.py +26 -4
  90. sglang/srt/two_batch_overlap.py +3 -0
  91. sglang/srt/utils.py +191 -48
  92. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  93. sglang/utils.py +5 -5
  94. sglang/version.py +1 -1
  95. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
  96. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
  97. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/vocab_parallel_embedding.py

@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/layers/vocab_parallel_embedding.py
 
+import logging
 from dataclasses import dataclass
 from typing import List, Optional, Sequence, Tuple
 
@@ -28,6 +29,8 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
+logger = logging.getLogger(__name__)
+
 
 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
@@ -562,9 +565,12 @@ class ParallelLMHead(VocabParallelEmbedding):
         )
         self.quant_config = quant_config
 
-        # We only support pack LMHead if it's not quantized. For LMHead with quant_config, the weight_name will be "qweight"
-        if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
-            self.quant_method = PackWeightMethod(weight_names=["weight"])
+        # We only support pack LMHead if it's not quantized.
+        if _is_cpu and _is_cpu_amx_available:
+            if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
+                self.quant_method = PackWeightMethod(weight_names=["weight"])
+            else:
+                logger.warning("The weight of LmHead is not packed")
 
         if bias:
             self.bias = Parameter(
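Note on the ParallelLMHead change above: weight packing on CPU now keys off the weight dtype rather than the absence of a quant config, so it only applies when AMX is available and the LM head weight is bf16; anything else logs a warning and stays unpacked. A minimal standalone sketch of that gating, with a hypothetical helper name and arguments that are not part of sglang:

import logging

import torch

logger = logging.getLogger(__name__)


def should_pack_lm_head(weight: torch.Tensor, is_cpu: bool, amx_available: bool) -> bool:
    # Mirrors the new condition: pack only on CPU with AMX and a bf16 weight;
    # in sglang the True branch would install PackWeightMethod(weight_names=["weight"]).
    if is_cpu and amx_available:
        if weight.dtype == torch.bfloat16:
            return True
        logger.warning("The weight of LmHead is not packed")
    return False


print(should_pack_lm_head(torch.zeros(2, 2, dtype=torch.bfloat16), True, True))  # True
print(should_pack_lm_head(torch.zeros(2, 2, dtype=torch.float16), True, True))   # False, with a warning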
sglang/srt/lora/triton_ops/gate_up_lora_b.py

@@ -31,28 +31,44 @@ def _gate_up_lora_b_kernel(
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    # For fused output scaling and adding
-    fuse_scaling_add,
+    # For fused output scaling
     scalings,
 ):
-    # This kernel packs 2 sgemms (gate/up) into a single kernel.
-
-    # x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
-    # weights: (num_lora, 2 * output_dim, K)
-    # output: (s, 2 * output_dim)
+    """
+    This kernel packs 2 sgemms (gate/up) into a single kernel. The multiplication
+    results are accumulated into the output tensor.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (Tensor): The input tensor, which is the result of the LoRA A projection.
+            Shape: (s, 2 * K), where s is the sum of all sequence lengths in the
+            batch and K is the maximum LoRA rank.
+        weights (Tensor): The LoRA B weights for all adapters.
+            Shape: (num_lora, 2 * output_dim, K).
+        output (Tensor): The output tensor where the result is stored.
+            Shape: (s, 2 * output_dim).
+    """
     # output_dim >> K
 
     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len.
     # gate_up_id decides which of gate or up (0: gate, 1: up)
     batch_id = tl.program_id(axis=2)
+    w_index = tl.load(weight_indices + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if rank == 0:
+        return
+
     gate_up_id = tl.program_id(axis=1)
     pid = tl.program_id(axis=0)
     seg_len = tl.load(seg_lens + batch_id)
-    w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = gate_up_id * output_dim  # offset on output dim
-    rank = tl.load(lora_ranks + w_index)
     scaling = tl.load(scalings + w_index)
 
     # Adjust K (rank) according to the specific LoRA adapter
@@ -82,14 +98,13 @@ def _gate_up_lora_b_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
-            mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
             w_ptrs,
             mask=(k_offset[:, None] < K - k * BLOCK_K)
-            and (n_offset[None, :] < output_dim),
+            & (n_offset[None, :] < output_dim),
             other=0.0,
         )
         partial_sum += tl.dot(x_tile, w_tile)
@@ -103,9 +118,8 @@ def _gate_up_lora_b_kernel(
     output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
-    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
-    if fuse_scaling_add:
-        partial_sum += tl.load(output_ptr, mask=output_mask)
+    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < output_dim)
+    partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
 
@@ -143,11 +157,9 @@ def gate_up_lora_b_fwd(
     )
 
     if base_output is None:
-        output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
-        fuse_scaling_add = False
+        output = torch.zeros((s, 2 * output_dim), device=x.device, dtype=x.dtype)
     else:
         output = base_output
-        fuse_scaling_add = True
 
     _gate_up_lora_b_kernel[grid_b](
         x,
@@ -169,7 +181,6 @@ def gate_up_lora_b_fwd(
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
-        fuse_scaling_add,
         batch_info.scalings,
     )
 
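The torch.empty → torch.zeros switch above goes hand in hand with dropping fuse_scaling_add: the kernel now unconditionally accumulates partial_sum into output, and a rank-0 adapter returns before writing anything, so a freshly allocated buffer is only correct if it starts at zero. A small PyTorch sketch of that contract (the shapes below are made up for illustration):

import torch

s, output_dim = 4, 8

# New behavior: allocate zeros, then let the kernel accumulate into the buffer.
output = torch.zeros((s, 2 * output_dim))
delta = torch.randn(s, 2 * output_dim)  # stands in for the kernel's scaled partial_sum
output += delta                         # the kernel stores previous output + partial_sum
assert torch.equal(output, delta)

# With torch.empty, a skipped (rank-0) launch or a pure accumulation would leave
# uninitialized values in the buffer, which is why the flag-based path was removed.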
sglang/srt/lora/triton_ops/qkv_lora_b.py

@@ -33,29 +33,45 @@ def _qkv_lora_b_kernel(
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    # For fused output scaling and adding
-    fuse_scaling_add,
+    # For fused output scaling
     scalings,
 ):
-    # This kernel packs 3 sgemms (q/k/v) into a single kernel.
-
-    # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank
-    # weights: (num_lora, N_Q + 2 * N_KV, K)
-    # output: (s, N_Q + 2 * N_KV)
-    # N_Q >> K, N_KV >> K
+    """
+    This kernel packs 3 sgemms (q/k/v) into a single kernel. The multiplication
+    results are accumulated into the output tensor.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (Tensor): The input tensor, which is the result of the LoRA A projection.
+            Shape: (s, 3 * K), where s is the sum of all sequence lengths in the
+            batch and K is the maximum LoRA rank. The second dimension is partitioned
+            for Q, K, and V.
+        weights (Tensor): The LoRA B weights for all adapters.
+            Shape: (num_lora, N_Q + 2 * N_KV, K).
+        output (Tensor): The output tensor where the result is stored.
+            Shape: (s, N_Q + 2 * N_KV).
+    """
 
     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len.
     # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
     batch_id = tl.program_id(axis=2)
+    w_index = tl.load(weight_indices + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if rank == 0:
+        return
+
     qkv_id = tl.program_id(axis=1)
     pid = tl.program_id(axis=0)
     seg_len = tl.load(seg_lens + batch_id)
-    w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = tl.load(n_offs + qkv_id)
     n_size = tl.load(n_offs + qkv_id + 1) - n_start
-    rank = tl.load(lora_ranks + w_index)
     scaling = tl.load(scalings + w_index)
     # Adjust K (rank) according to the specific LoRA adapter
     K = tl.minimum(K, rank)
@@ -84,13 +100,12 @@ def _qkv_lora_b_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
-            mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
             w_ptrs,
-            mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
+            mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < n_size),
             other=0.0,
         )
         partial_sum += tl.dot(x_tile, w_tile)
@@ -105,8 +120,7 @@ def _qkv_lora_b_kernel(
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
     output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
-    if fuse_scaling_add:
-        partial_sum += tl.load(output_ptr, mask=output_mask)
+    partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
 
@@ -153,11 +167,9 @@ def qkv_lora_b_fwd(
     )
 
     if base_output is None:
-        output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
-        fuse_scaling_add = False
+        output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
     else:
         output = base_output
-        fuse_scaling_add = True
 
     _qkv_lora_b_kernel[grid_b](
         x,
@@ -180,7 +192,6 @@ def qkv_lora_b_fwd(
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
-        fuse_scaling_add,
         batch_info.scalings,
     )
 
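The mask expressions in these kernels now use the elementwise & instead of Python's short-circuit and. The distinction is easy to reproduce in plain PyTorch (an analogy only; the kernels themselves operate on Triton tensors, where & is likewise the elementwise operator):

import torch

a = torch.tensor([True, True, False])
b = torch.tensor([True, False, False])

print(a & b)  # tensor([ True, False, False]) -- elementwise combination of the two masks

try:
    a and b   # `and` asks for the truth value of a multi-element tensor
except RuntimeError as err:
    print(err)  # "Boolean value of Tensor with more than one element is ambiguous"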
sglang/srt/lora/triton_ops/sgemm_lora_a.py

@@ -33,19 +33,36 @@ def _sgemm_lora_a_kernel(
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
 ):
-
-    # x: (s, K), s is the sum of sequence lengths
-    # weights: (num_lora, N, K)
-    # output: (s, N)
+    """
+    Computes a segmented batched matrix multiplication for the LoRA A matrix.
+
+    The kernel ensures that output[seg_start:seg_start + seg_len, :rank * stack_num]
+    stores the product of the input `x` and the LoRA weights for the corresponding
+    sequence. This implies that when rank is 0, the kernel is essentially a no-op,
+    as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
+
+    Args:
+        x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
+            is the sum of all sequence lengths in the batch.
+        weights (torch.Tensor): The LoRA 'A' weights for all available adapters,
+            with shape `(num_lora, N, K)`.
+        output (torch.Tensor): The output tensor of shape `(s, N)`.
+    """
 
     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len
     batch_id = tl.program_id(axis=1)
-    pid = tl.program_id(axis=0)
-    seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
-    seg_start = tl.load(seg_indptr + batch_id)
     rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
+    if rank == 0:
+        return
+
+    pid = tl.program_id(axis=0)
+    seg_start = tl.load(seg_indptr + batch_id)
+    seg_len = tl.load(seg_lens + batch_id)
+
     # Adjust N (stack_num * max_rank) according to the specific LoRA adapter
     N = tl.minimum(N, rank * stack_num)
 
@@ -72,13 +89,12 @@ def _sgemm_lora_a_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
-            mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
             w_ptrs,
-            mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
+            mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < N),
             other=0.0,
         )
         partial_sum += tl.dot(x_tile, w_tile)
@@ -91,7 +107,7 @@ def _sgemm_lora_a_kernel(
     output_ptr = (output + seg_start * output_stride_0) + (
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
-    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
+    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < N)
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
 
sglang/srt/lora/triton_ops/sgemm_lora_b.py

@@ -31,22 +31,39 @@ def _sgemm_lora_b_kernel(
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
-    # For fused output scaling and adding
-    fuse_scaling_add,
+    # For fused output scaling
     scalings,
 ):
-    # x: (s, K), s is the sum of sequence lengths
-    # weights: (num_lora, N, K)
-    # output: (s, N)
+    """
+    Computes a segmented batched matrix multiplication for the LoRA B matrix
+    and adds the result to the output in-place.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (torch.Tensor): The intermediate tensor from the LoRA 'A' multiplication,
+            of shape `(s, K)`, where `s` is the total number of tokens.
+        weights (torch.Tensor): The LoRA 'B' weights for all available adapters,
+            with shape `(num_lora, N, K)`.
+        output (torch.Tensor): The output tensor of shape `(s, N)`. This can be
+            the base model's output for a fused add operation.
+    """
 
     # Current block computes sequence with batch_id,
     # which starts from row seg_start of x with length seg_len
     batch_id = tl.program_id(axis=1)
+    w_index = tl.load(weight_indices + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if rank == 0:
+        return
+
     pid = tl.program_id(axis=0)
     seg_len = tl.load(seg_lens + batch_id)
-    w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
-    rank = tl.load(lora_ranks + w_index)
     scaling = tl.load(scalings + w_index)
     # Adjust K (rank) according to the specific LoRA adapter
     K = tl.minimum(K, rank)
@@ -74,8 +91,7 @@ def _sgemm_lora_b_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
             x_ptrs,
-            mask=(s_offset[:, None] < seg_len)
-            and (k_offset[None, :] < K - k * BLOCK_K),
+            mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
             other=0.0,
         )
         w_tile = tl.load(
@@ -95,8 +111,7 @@ def _sgemm_lora_b_kernel(
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
     output_mask = s_offset[:, None] < seg_len
-    if fuse_scaling_add:
-        partial_sum += tl.load(output_ptr, mask=output_mask)
+    partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
 
@@ -132,11 +147,9 @@ def sgemm_lora_b_fwd(
     )
 
     if base_output is None:
-        output = torch.empty((S, N), device=x.device, dtype=x.dtype)
-        fuse_scaling_add = False
+        output = torch.zeros((S, N), device=x.device, dtype=x.dtype)
     else:
         output = base_output
-        fuse_scaling_add = True
 
     _sgemm_lora_b_kernel[grid](
         x,
@@ -158,7 +171,6 @@ def sgemm_lora_b_fwd(
         BLOCK_S,
         BLOCK_N,
         BLOCK_R,
-        fuse_scaling_add,
         batch_info.scalings,
     )
     return output
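The rank-0 convention cited in the new docstrings matches PyTorch's own semantics: the product of an (m, 0) matrix and a (0, n) matrix is an all-zero (m, n) matrix, so skipping the kernel for a rank-0 adapter and leaving a zero-initialized output behind gives exactly the expected result. A quick check in plain PyTorch:

import torch

m, n = 3, 5
a = torch.empty(m, 0)
b = torch.empty(0, n)

result = a @ b
print(result.shape)                            # torch.Size([3, 5])
print(torch.equal(result, torch.zeros(m, n)))  # True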