sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/function_call_parser.py +96 -69
  5. sglang/srt/layers/activation.py +10 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  7. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  8. sglang/srt/layers/attention/triton_backend.py +124 -12
  9. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  12. sglang/srt/layers/layernorm.py +1 -5
  13. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -13
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  24. sglang/srt/layers/moe/topk.py +4 -0
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/fp8_kernel.py +173 -2
  48. sglang/srt/layers/rotary_embedding.py +1 -3
  49. sglang/srt/layers/sampler.py +4 -4
  50. sglang/srt/lora/backend/__init__.py +8 -0
  51. sglang/srt/lora/backend/base_backend.py +95 -0
  52. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  53. sglang/srt/lora/backend/triton_backend.py +61 -0
  54. sglang/srt/lora/lora.py +127 -112
  55. sglang/srt/lora/lora_manager.py +50 -18
  56. sglang/srt/lora/triton_ops/__init__.py +5 -0
  57. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  59. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  60. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  61. sglang/srt/model_executor/forward_batch_info.py +58 -59
  62. sglang/srt/model_executor/model_runner.py +2 -2
  63. sglang/srt/models/llama.py +8 -3
  64. sglang/srt/models/qwen2_vl.py +1 -1
  65. sglang/srt/server_args.py +13 -2
  66. sglang/srt/speculative/build_eagle_tree.py +486 -104
  67. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  68. sglang/srt/speculative/eagle_utils.py +420 -401
  69. sglang/srt/speculative/eagle_worker.py +177 -45
  70. sglang/srt/utils.py +7 -0
  71. sglang/test/runners.py +2 -0
  72. sglang/version.py +1 -1
  73. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +15 -6
  74. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +77 -38
  75. sglang/srt/layers/custom_op_util.py +0 -25
  76. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,182 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from sglang.srt.lora.lora import LoraBatchInfo
6
+
7
+
8
@triton.jit
def _qkv_lora_b_kernel(
    # Pointers to matrices
    x,
    weights,
    output,
    # Parameters of size
    K,  # K = R
    max_qkv_out_dim,  # max(output_q_dim, output_kv_dim)
    # Strides
    x_stride_0,
    x_stride_1,
    w_stride_0,
    w_stride_1,
    w_stride_2,
    output_stride_0,
    output_stride_1,
    # Information on sequence lengths and weight id
    seg_lens,
    seg_indptr,
    weight_indices,
    # Offsets of q/k/v slice on output dimension
    n_offs,
    # Meta parameters
    BLOCK_S: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # For fused output scaling and adding
    fuse_scaling_add,
    scaling,
):
    """Segmented LoRA-B matmul for q, k and v packed into one launch.

    Expected layouts (as stated by the launcher's comments):
      x:       (s, 3 * K), s is the sum of sequence lengths, K is the lora rank
      weights: (num_lora, N_Q + 2 * N_KV, K)
      output:  (s, N_Q + 2 * N_KV)

    Grid axes: axis 0 enumerates (row tile, column tile) pairs, axis 1 selects
    which of q/k/v is computed (0: q, 1: k, 2: v), axis 2 selects the sequence.
    Each program accumulates one (BLOCK_S, BLOCK_N) tile in float32, multiplies
    it by `scaling`, casts it to the element type of `x` and stores it; when
    `fuse_scaling_add` is true the tile is added to what `output` already holds.
    """
    tile_id = tl.program_id(axis=0)
    qkv_id = tl.program_id(axis=1)
    batch_id = tl.program_id(axis=2)

    # Per-sequence metadata: first row, number of rows, and adapter index.
    seg_start = tl.load(seg_indptr + batch_id)
    seg_len = tl.load(seg_lens + batch_id)
    w_index = tl.load(weight_indices + batch_id)

    # Column range [n_start, n_start + n_size) of the selected q/k/v slice.
    n_start = tl.load(n_offs + qkv_id)
    n_size = tl.load(n_offs + qkv_id + 1) - n_start

    # Decompose the flat tile id into (row tile, column tile). The column tile
    # count is derived from the widest slice, so narrower slices simply get
    # fully masked tiles at the end.
    tiles_per_row = tl.cdiv(max_qkv_out_dim, BLOCK_N)
    tile_s = tile_id // tiles_per_row
    tile_n = tile_id % tiles_per_row

    s_offset = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
    n_offset = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    k_offset = tl.arange(0, BLOCK_K)

    # Row / column validity does not change while walking along K.
    s_mask = s_offset[:, None] < seg_len
    n_mask = n_offset[None, :] < n_size

    # Pointers to the first K-block of x[seg rows, qkv_id * K:] and of
    # weights[w_index, n_start:, :] (the latter read transposed as (K, N)).
    x_base = x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1
    x_ptrs = x_base + (
        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
    )
    w_base = weights + w_index * w_stride_0 + n_start * w_stride_1
    w_ptrs = w_base + (
        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
    )

    # Iterate over K to compute the block in the output matrix.
    acc = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k_step in range(0, tl.cdiv(K, BLOCK_K)):
        k_left = K - k_step * BLOCK_K
        x_tile = tl.load(
            x_ptrs,
            mask=s_mask & (k_offset[None, :] < k_left),
            other=0.0,
        )
        w_tile = tl.load(
            w_ptrs,
            mask=(k_offset[:, None] < k_left) & n_mask,
            other=0.0,
        )
        acc += tl.dot(x_tile, w_tile)

        # Advance both operands to the next K-block.
        x_ptrs += BLOCK_K * x_stride_1
        w_ptrs += BLOCK_K * w_stride_2

    # Scale, cast and store the result to the output matrix.
    acc *= scaling
    result = acc.to(x.dtype.element_ty)
    out_base = output + seg_start * output_stride_0 + n_start * output_stride_1
    output_ptr = out_base + (
        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
    )
    output_mask = s_mask & n_mask
    if fuse_scaling_add:
        result += tl.load(output_ptr, mask=output_mask)
    tl.store(output_ptr, result, mask=output_mask)
106
+
107
+
108
def qkv_lora_b_fwd(
    x: torch.Tensor,
    qkv_lora_b: torch.Tensor,
    batch_info: LoraBatchInfo,
    output_offset: torch.Tensor,
    max_qkv_out_dim: int,
    base_output: torch.Tensor = None,
    scaling: float = 1.0,
) -> torch.Tensor:
    """Launch the fused q/k/v LoRA-B kernel.

    Args:
        x: (s, 3 * r) input; columns [i * r, (i + 1) * r) feed slice i
            (0: q, 1: k, 2: v).
        qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r) weights.
        batch_info: supplies `bs`, `max_len`, `seg_lens`, `seg_indptr` and
            `weight_indices` for the segmented launch.
        output_offset: 4-element tensor
            [0, output_dim_q, output_dim_q + output_dim_kv,
             output_dim_q + 2 * output_dim_kv].
        max_qkv_out_dim: max(output_dim_q, output_dim_kv).
        base_output: if given, the scaled result is accumulated into this
            tensor in place and it is returned; otherwise a new tensor is
            allocated.
        scaling: factor applied to the product before it is stored or added.

    Returns:
        Tensor of shape (s, output_dim_q + 2 * output_dim_kv).
    """
    # For each slice i in {q, k, v} and each sequence using adapter w:
    #   output[rows, output_offset[i]:output_offset[i + 1]]
    #       = x[rows, i * r:(i + 1) * r]
    #         @ qkv_lora_b[w, output_offset[i]:output_offset[i + 1], :].T

    # Get dims
    num_tokens, input_dim = x.shape[0], x.shape[1]
    r = qkv_lora_b.shape[-1]
    output_dim = qkv_lora_b.shape[-2]
    assert input_dim == 3 * r
    assert output_offset.shape[0] == 4

    # Tile sizes: rows per tile, rank elements per step, output columns per tile.
    BLOCK_S = 16
    BLOCK_R = 16
    BLOCK_OUT = 64

    row_tiles = triton.cdiv(batch_info.max_len, BLOCK_S)
    col_tiles = triton.cdiv(max_qkv_out_dim, BLOCK_OUT)
    grid_b = (
        row_tiles * col_tiles,
        3,  # this dimension decides current block computes on q, k or v
        batch_info.bs,
    )

    # Accumulate into the caller's buffer when one is supplied.
    fuse_scaling_add = base_output is not None
    if fuse_scaling_add:
        output = base_output
    else:
        output = torch.empty((num_tokens, output_dim), device=x.device, dtype=x.dtype)

    _qkv_lora_b_kernel[grid_b](
        # matrices
        x,
        qkv_lora_b,
        output,
        # sizes
        r,
        max_qkv_out_dim,
        # strides
        x.stride(0),
        x.stride(1),
        qkv_lora_b.stride(0),
        qkv_lora_b.stride(1),
        qkv_lora_b.stride(2),
        output.stride(0),
        output.stride(1),
        # segment metadata
        batch_info.seg_lens,
        batch_info.seg_indptr,
        batch_info.weight_indices,
        output_offset,
        # meta parameters
        BLOCK_S,
        BLOCK_OUT,
        BLOCK_R,
        # fused epilogue
        fuse_scaling_add,
        scaling,
    )

    return output
@@ -0,0 +1,143 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from sglang.srt.lora.lora import LoraBatchInfo
6
+
7
+
8
@triton.jit
def _sgemm_lora_a_kernel(
    # Pointers to matrices
    x,
    weights,
    output,
    # Matrix dimensions
    N,  # r
    K,  # input_dim
    # Strides
    x_stride_0,
    x_stride_1,
    w_stride_0,
    w_stride_1,
    w_stride_2,
    output_stride_0,
    output_stride_1,
    # Information on sequence lengths and weight id
    seg_lens,
    seg_indptr,
    weight_indices,
    # Meta parameters
    BLOCK_S: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Segmented LoRA-A matmul: one (BLOCK_S, BLOCK_N) output tile per program.

    Expected layouts (as stated by the launcher's comments):
      x:       (s, K), s is the sum of sequence lengths
      weights: (num_lora, N, K)
      output:  (s, N)

    Grid axis 0 enumerates (row tile, column tile) pairs and axis 1 selects
    the sequence. The tile is accumulated in float32 and cast to the element
    type of `x` before it is stored.
    """
    tile_id = tl.program_id(axis=0)
    batch_id = tl.program_id(axis=1)

    # Per-sequence metadata: first row, number of rows, and adapter index.
    seg_start = tl.load(seg_indptr + batch_id)
    seg_len = tl.load(seg_lens + batch_id)
    w_index = tl.load(weight_indices + batch_id)

    # Decompose the flat tile id into (row tile, column tile).
    tiles_per_row = tl.cdiv(N, BLOCK_N)
    tile_s = tile_id // tiles_per_row
    tile_n = tile_id % tiles_per_row

    s_offset = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
    n_offset = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    k_offset = tl.arange(0, BLOCK_K)

    # Row / column validity does not change while walking along K.
    s_mask = s_offset[:, None] < seg_len
    n_mask = n_offset[None, :] < N

    # Pointers to the first K-block of this sequence's rows of x and of
    # weights[w_index] (read transposed as (K, N)).
    x_base = x + seg_start * x_stride_0
    x_ptrs = x_base + (
        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
    )
    w_base = weights + w_index * w_stride_0
    w_ptrs = w_base + (
        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
    )

    # Iterate over K to compute the block in the output matrix.
    acc = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k_step in range(0, tl.cdiv(K, BLOCK_K)):
        k_left = K - k_step * BLOCK_K
        x_tile = tl.load(
            x_ptrs,
            mask=s_mask & (k_offset[None, :] < k_left),
            other=0.0,
        )
        w_tile = tl.load(
            w_ptrs,
            mask=(k_offset[:, None] < k_left) & n_mask,
            other=0.0,
        )
        acc += tl.dot(x_tile, w_tile)

        # Advance both operands to the next K-block.
        x_ptrs += BLOCK_K * x_stride_1
        w_ptrs += BLOCK_K * w_stride_2

    # Cast and store the result to the output matrix.
    result = acc.to(x.dtype.element_ty)
    out_base = output + seg_start * output_stride_0
    output_ptr = out_base + (
        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
    )
    tl.store(output_ptr, result, mask=s_mask & n_mask)
91
+
92
+
93
def sgemm_lora_a_fwd(
    x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
) -> torch.Tensor:
    """Launch the segmented LoRA-A matmul.

    Args:
        x: contiguous (s, input_dim) input.
        weights: contiguous (num_lora, r, input_dim) weights; when called by
            run_qkv_lora, weights.shape[-2] will be 3 * r.
        batch_info: supplies `bs`, `max_len`, `seg_lens`, `seg_indptr` and
            `weight_indices` for the segmented launch.

    Returns:
        A newly allocated tensor of shape (s, weights.shape[-2]) with the
        device and dtype of `x`.
    """
    # input_dim is much larger than r, hence the wide K tile below.
    assert x.is_contiguous()
    assert weights.is_contiguous()
    assert x.dim() == 2
    assert weights.dim() == 3

    S = x.shape[0]
    R, K = weights.shape[-2], weights.shape[-1]
    assert x.shape[-1] == K

    # Block shapes
    BLOCK_S = 16
    BLOCK_K = 256
    BLOCK_R = 16

    row_tiles = triton.cdiv(batch_info.max_len, BLOCK_S)
    col_tiles = triton.cdiv(R, BLOCK_R)
    grid = (row_tiles * col_tiles, batch_info.bs)

    output = torch.empty((S, R), device=x.device, dtype=x.dtype)
    _sgemm_lora_a_kernel[grid](
        # matrices
        x,
        weights,
        output,
        # dimensions
        R,
        K,
        # strides
        x.stride(0),
        x.stride(1),
        weights.stride(0),
        weights.stride(1),
        weights.stride(2),
        output.stride(0),
        output.stride(1),
        # segment metadata
        batch_info.seg_lens,
        batch_info.seg_indptr,
        batch_info.weight_indices,
        # meta parameters
        BLOCK_S,
        BLOCK_R,
        BLOCK_K,
    )
    return output
@@ -0,0 +1,159 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from sglang.srt.lora.lora import LoraBatchInfo
6
+
7
+
8
@triton.jit
def _sgemm_lora_b_kernel(
    # Pointers to matrices
    x,
    weights,
    output,
    # Matrix dimensions
    N,  # output_dim
    K,  # r
    # Strides
    x_stride_0,
    x_stride_1,
    w_stride_0,
    w_stride_1,
    w_stride_2,
    output_stride_0,
    output_stride_1,
    # Information on sequence lengths and weight id
    seg_lens,
    seg_indptr,
    weight_indices,
    # Meta parameters
    BLOCK_S: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # For fused output scaling and adding
    fuse_scaling_add,
    scaling,
):
    """Segmented LoRA-B matmul: one (BLOCK_S, BLOCK_N) output tile per program.

    Expected layouts (as stated by the launcher's comments):
      x:       (s, K), s is the sum of sequence lengths
      weights: (num_lora, N, K)
      output:  (s, N)

    Grid axis 0 enumerates (row tile, column tile) pairs and axis 1 selects
    the sequence. The tile is accumulated in float32, multiplied by `scaling`
    and cast to the element type of `x`; when `fuse_scaling_add` is true it is
    added to what `output` already holds before being stored.
    """
    # Current block computes sequence with batch_id,
    # which starts from row seg_start of x with length seg_len
    batch_id = tl.program_id(axis=1)
    pid = tl.program_id(axis=0)
    seg_len = tl.load(seg_lens + batch_id)
    w_index = tl.load(weight_indices + batch_id)
    seg_start = tl.load(seg_indptr + batch_id)

    # The tile in output matrix will have (pid_s, pid_n) as id
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_s = pid // num_pid_n
    pid_n = pid % num_pid_n

    # Create pointers for the first block of x and weights[w_index]
    # The pointers will be advanced as we move in the K direction
    # and accumulate
    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    k_offset = tl.arange(0, BLOCK_K)
    x_ptrs = (x + seg_start * x_stride_0) + (
        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
    )
    w_ptrs = (weights + w_index * w_stride_0) + (
        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
    )

    # The column tile count is cdiv(N, BLOCK_N), so the last tile overhangs
    # whenever N is not a multiple of BLOCK_N. Those columns must be masked on
    # every weight read and every output read/write; otherwise the kernel
    # reads weights out of bounds and writes past the end of each output row.
    s_mask = s_offset[:, None] < seg_len
    n_mask = n_offset[None, :] < N

    # Iterate to compute the block in output matrix
    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        k_remaining = K - k * BLOCK_K
        x_tile = tl.load(
            x_ptrs,
            mask=s_mask & (k_offset[None, :] < k_remaining),
            other=0.0,
        )
        w_tile = tl.load(
            w_ptrs,
            mask=(k_offset[:, None] < k_remaining) & n_mask,
            other=0.0,
        )
        partial_sum += tl.dot(x_tile, w_tile)

        x_ptrs += BLOCK_K * x_stride_1
        w_ptrs += BLOCK_K * w_stride_2

    # Store result to output matrix
    partial_sum *= scaling
    partial_sum = partial_sum.to(x.dtype.element_ty)
    output_ptr = (output + seg_start * output_stride_0) + (
        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
    )
    output_mask = s_mask & n_mask
    if fuse_scaling_add:
        partial_sum += tl.load(output_ptr, mask=output_mask)
    tl.store(output_ptr, partial_sum, mask=output_mask)
96
+
97
+
98
def sgemm_lora_b_fwd(
    x: torch.Tensor,
    weights: torch.Tensor,
    batch_info: LoraBatchInfo,
    base_output: torch.Tensor = None,
    scaling: float = 1.0,
) -> torch.Tensor:
    """Launch the segmented LoRA-B matmul.

    Args:
        x: contiguous (s, r) input.
        weights: contiguous (num_lora, output_dim, r) weights.
        batch_info: supplies `bs`, `max_len`, `seg_lens`, `seg_indptr` and
            `weight_indices` for the segmented launch.
        base_output: if given, the scaled result is accumulated into this
            tensor in place and it is returned; otherwise a new tensor is
            allocated.
        scaling: factor applied to the product before it is stored or added.

    Returns:
        Tensor of shape (s, output_dim).
    """
    # output_dim is much larger than r, hence the wide N tile below.
    assert x.is_contiguous()
    assert weights.is_contiguous()
    assert x.dim() == 2
    assert weights.dim() == 3

    S = x.shape[0]
    N, R = weights.shape[-2], weights.shape[-1]
    assert x.shape[-1] == R

    # Block shapes
    BLOCK_S = 16
    BLOCK_R = 16
    BLOCK_N = 256

    row_tiles = triton.cdiv(batch_info.max_len, BLOCK_S)
    col_tiles = triton.cdiv(N, BLOCK_N)
    grid = (row_tiles * col_tiles, batch_info.bs)

    # Accumulate into the caller's buffer when one is supplied.
    fuse_scaling_add = base_output is not None
    if fuse_scaling_add:
        output = base_output
    else:
        output = torch.empty((S, N), device=x.device, dtype=x.dtype)

    _sgemm_lora_b_kernel[grid](
        # matrices
        x,
        weights,
        output,
        # dimensions
        N,
        R,
        # strides
        x.stride(0),
        x.stride(1),
        weights.stride(0),
        weights.stride(1),
        weights.stride(2),
        output.stride(0),
        output.stride(1),
        # segment metadata
        batch_info.seg_lens,
        batch_info.seg_indptr,
        batch_info.weight_indices,
        # meta parameters
        BLOCK_S,
        BLOCK_N,
        BLOCK_R,
        # fused epilogue
        fuse_scaling_add,
        scaling,
    )
    return output