sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/function_call_parser.py +96 -69
  5. sglang/srt/layers/activation.py +10 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  7. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  8. sglang/srt/layers/attention/triton_backend.py +124 -12
  9. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  12. sglang/srt/layers/layernorm.py +1 -5
  13. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -13
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  24. sglang/srt/layers/moe/topk.py +4 -0
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/fp8_kernel.py +173 -2
  48. sglang/srt/layers/rotary_embedding.py +1 -3
  49. sglang/srt/layers/sampler.py +4 -4
  50. sglang/srt/lora/backend/__init__.py +8 -0
  51. sglang/srt/lora/backend/base_backend.py +95 -0
  52. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  53. sglang/srt/lora/backend/triton_backend.py +61 -0
  54. sglang/srt/lora/lora.py +127 -112
  55. sglang/srt/lora/lora_manager.py +50 -18
  56. sglang/srt/lora/triton_ops/__init__.py +5 -0
  57. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  59. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  60. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  61. sglang/srt/model_executor/forward_batch_info.py +58 -59
  62. sglang/srt/model_executor/model_runner.py +2 -2
  63. sglang/srt/models/llama.py +8 -3
  64. sglang/srt/models/qwen2_vl.py +1 -1
  65. sglang/srt/server_args.py +13 -2
  66. sglang/srt/speculative/build_eagle_tree.py +486 -104
  67. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  68. sglang/srt/speculative/eagle_utils.py +420 -401
  69. sglang/srt/speculative/eagle_worker.py +177 -45
  70. sglang/srt/utils.py +7 -0
  71. sglang/test/runners.py +2 -0
  72. sglang/version.py +1 -1
  73. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +15 -6
  74. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +77 -38
  75. sglang/srt/layers/custom_op_util.py +0 -25
  76. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,164 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 4,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "2": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "4": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "8": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 4,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "16": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "24": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "32": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "48": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "64": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "96": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 16,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "128": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "256": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "512": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 4,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 4,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2,
+         "waves_per_eu": 0
+     }
+ }
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 5
+     },
+     "4": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "8": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "32": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 5
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "96": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 5
+     },
+     "128": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     }
+ }
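The two JSON hunks above add tuned Triton launch parameters keyed by batch size M (the top-level "1" through "4096" keys) for the block-FP8 fused-MoE kernels, one file per (E, N, device, dtype, block_shape) combination. As a rough illustration of how such a table is consumed, the sketch below loads a config file and picks the entry benchmarked at the M closest to the runtime M; the helper names and the nearest-key rule are assumptions modeled on the fused_moe_triton loader, not a verbatim copy of sglang's code.

```python
# Illustrative sketch only: how a tuned-config JSON like the ones added above
# might be consulted at runtime. Path and selection rule are assumptions.
import json
from typing import Any, Dict


def load_tuned_configs(path: str) -> Dict[int, Dict[str, Any]]:
    """Load {batch_size: kernel_config} entries, converting keys to int."""
    with open(path) as f:
        raw = json.load(f)
    return {int(m): cfg for m, cfg in raw.items()}


def pick_config(configs: Dict[int, Dict[str, Any]], m: int) -> Dict[str, Any]:
    """Pick the entry whose benchmarked M is closest to the runtime M."""
    best_m = min(configs.keys(), key=lambda bench_m: abs(bench_m - m))
    return configs[best_m]


if __name__ == "__main__":
    # Hypothetical path; the real files live under
    # sglang/srt/layers/moe/fused_moe_triton/configs/.
    cfgs = load_tuned_configs(
        "E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json"
    )
    print(pick_config(cfgs, m=200))  # falls back to the "256" entry
```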
@@ -22,7 +22,7 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.utils import get_device_name, is_hip
+ from sglang.srt.utils import get_device_core_count, get_device_name, is_hip

  is_hip_ = is_hip()
  fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
@@ -220,6 +220,165 @@ def _w8a8_block_fp8_matmul(
      tl.store(c_ptrs, c, mask=c_mask)


+ @triton.jit
+ def _w8a8_block_fp8_matmul_unrolledx4(
+     # Pointers to inputs and output
+     A,
+     B,
+     C,
+     As,
+     Bs,
+     # Shape for matmul
+     M,
+     N,
+     K,
+     # Block size for block-wise quantization
+     group_n,
+     group_k,
+     # Stride for inputs and output
+     stride_am,
+     stride_ak,
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     stride_As_m,
+     stride_As_k,
+     stride_Bs_k,
+     stride_Bs_n,
+     # Meta-parameters
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+ ):
+     """Triton-accelerated function used to perform linear operations (dot
+     product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+     tensor `C`.
+     """
+
+     pid = tl.program_id(axis=0)
+     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+     first_pid_m = group_id * GROUP_SIZE_M
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + (pid % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+
+     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+     b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+     As_ptrs = As + offs_am * stride_As_m
+     offs_bsn = offs_bn // group_n
+     Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+     # manually unroll to 4 iterations
+     UNROLL_FACTOR = 4
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR)):
+         # 1st iteration
+         a = tl.load(
+             a_ptrs,
+             mask=offs_k[None, :] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K,
+             other=0.0,
+         )
+         b = tl.load(
+             b_ptrs,
+             mask=offs_k[:, None] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K,
+             other=0.0,
+         )
+
+         k_start = (k * UNROLL_FACTOR) * BLOCK_SIZE_K
+         offs_ks = k_start // group_k
+         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+         accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+         # 2nd iteration
+         a = tl.load(
+             a_ptrs,
+             mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K,
+             other=0.0,
+         )
+         b = tl.load(
+             b_ptrs,
+             mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K,
+             other=0.0,
+         )
+
+         k_start = k_start + BLOCK_SIZE_K
+         offs_ks = k_start // group_k
+         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+         accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+         # 3rd iteration
+         a = tl.load(
+             a_ptrs,
+             mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K,
+             other=0.0,
+         )
+         b = tl.load(
+             b_ptrs,
+             mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K,
+             other=0.0,
+         )
+
+         k_start = k_start + BLOCK_SIZE_K
+         offs_ks = k_start // group_k
+         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+         accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+         # 4th iteration
+         a = tl.load(
+             a_ptrs,
+             mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K,
+             other=0.0,
+         )
+         b = tl.load(
+             b_ptrs,
+             mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K,
+             other=0.0,
+         )
+
+         k_start = k_start + BLOCK_SIZE_K
+         offs_ks = k_start // group_k
+         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+         accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     if C.dtype.element_ty == tl.bfloat16:
+         c = accumulator.to(tl.bfloat16)
+     elif C.dtype.element_ty == tl.float16:
+         c = accumulator.to(tl.float16)
+     else:
+         c = accumulator.to(tl.float32)
+
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
  @functools.lru_cache
  def get_w8a8_block_fp8_configs(
      N: int, K: int, block_n: int, block_k: int
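The `_w8a8_block_fp8_matmul_unrolledx4` kernel added above is the existing block-wise FP8 matmul with its K loop manually unrolled four times: every K block of the product is rescaled by a per-row activation scale from `As` and a per-(K, N)-block weight scale from `Bs`. A plain-PyTorch sketch of that dequantization semantics follows for readability; shapes and dtypes here are assumptions for illustration (float32 stands in for the FP8 tensors).

```python
# Reference semantics of the block-wise FP8 matmul above, in plain PyTorch.
# Assumes K % group_k == 0 and N % group_n == 0.
import torch


def w8a8_block_matmul_reference(
    A: torch.Tensor,   # (M, K) quantized activations
    B: torch.Tensor,   # (K, N) quantized weights
    As: torch.Tensor,  # (M, K // group_k) per-row, per-K-group activation scales
    Bs: torch.Tensor,  # (K // group_k, N // group_n) per-block weight scales
    group_k: int,
    group_n: int,
) -> torch.Tensor:
    M, K = A.shape
    _, N = B.shape
    out = torch.zeros(M, N, dtype=torch.float32)
    for k0 in range(0, K, group_k):
        ks = k0 // group_k
        partial = A[:, k0 : k0 + group_k].float() @ B[k0 : k0 + group_k, :].float()
        # Mirrors `tl.dot(a, b) * a_s[:, None] * b_s[None, :]` in the kernel:
        a_s = As[:, ks].unsqueeze(1)             # (M, 1), broadcast over columns
        b_s = Bs[ks].repeat_interleave(group_n)  # (N,), broadcast over rows
        out += partial * a_s * b_s
    return out
```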
@@ -324,7 +483,19 @@ def w8a8_block_fp8_matmul(
              triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
          )

-     _w8a8_block_fp8_matmul[grid](
+     # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
+     # Empirical testing shows the sweet spot lies when it's less than the # of
+     # compute units available on the device.
+     num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
+         N, config["BLOCK_SIZE_N"]
+     )
+     kernel = (
+         _w8a8_block_fp8_matmul_unrolledx4
+         if (is_hip_ == True and num_workgroups <= get_device_core_count())
+         else _w8a8_block_fp8_matmul
+     )
+
+     kernel[grid](
          A,
          B,
          C,
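The dispatch above prefers the unrolled kernel on ROCm whenever the launch needs no more workgroups than the GPU has compute units, using the new `get_device_core_count` helper from `sglang/srt/utils.py` (+7 lines in this diff). A minimal stand-in for such a helper, using only public torch APIs, might look like the following; the real implementation may differ.

```python
# Hedged stand-in for a compute-unit count helper; the actual
# sglang.srt.utils.get_device_core_count may be implemented differently.
import torch


def device_core_count(device_id: int = 0) -> int:
    """Return the number of SMs (CUDA) or CUs (ROCm) on the given device."""
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.get_device_properties(device_id).multi_processor_count
```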
@@ -7,9 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
  import torch
  import torch.nn as nn
  from vllm import _custom_ops as ops
- from vllm.model_executor.custom_op import CustomOp

- from sglang.srt.layers.custom_op_util import register_custom_op
+ from sglang.srt.custom_op import CustomOp
  from sglang.srt.utils import is_cuda_available

  _is_cuda_available = is_cuda_available()
@@ -59,7 +58,6 @@ def _apply_rotary_emb(
      return torch.stack((o1, o2), dim=-1).flatten(-2)


- @register_custom_op("sglang_rotary_embedding")
  class RotaryEmbedding(CustomOp):
      """Original rotary positional embedding."""

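This hunk, like the matching ones in activation.py and layernorm.py, drops vLLM's `CustomOp` and the `register_custom_op` shim in favor of the new in-tree `sglang/srt/custom_op.py` (+40 lines in this diff, not shown here). A plausible shape for such a base class, which picks a per-platform forward at construction time, is sketched below purely as an assumption about what the new module provides.

```python
# Sketch of a platform-dispatching CustomOp base class, assumed to resemble
# the new sglang/srt/custom_op.py; method names beyond forward_native /
# forward_cuda / forward_hip are guesses.
import torch
from torch import nn


class CustomOp(nn.Module):
    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # Most ops can reuse the CUDA path on ROCm builds of PyTorch.
        return self.forward_cuda(*args, **kwargs)

    def dispatch_forward(self):
        if torch.version.hip is not None:
            return self.forward_hip
        if torch.cuda.is_available():
            return self.forward_cuda
        return self.forward_native
```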
@@ -85,7 +85,7 @@ class Sampler(nn.Module):
              if sampling_info.need_min_p_sampling:
                  probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                  probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                 batch_next_token_ids, success = min_p_sampling_from_probs(
+                 batch_next_token_ids = min_p_sampling_from_probs(
                      probs, uniform_samples, sampling_info.min_ps
                  )
              else:
@@ -97,9 +97,9 @@ class Sampler(nn.Module):
                      filter_apply_order="joint",
                  )

-             if self.use_nan_detectioin and not torch.all(success):
-                 logger.warning("Detected errors during sampling!")
-                 batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+                 if self.use_nan_detectioin and not torch.all(success):
+                     logger.warning("Detected errors during sampling!")
+                     batch_next_token_ids = torch.zeros_like(batch_next_token_ids)

          elif global_server_args_dict["sampling_backend"] == "pytorch":
              # A slower fallback implementation with torch native operations.
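For context on the sampler change: min-p sampling keeps only tokens whose probability is at least `min_p` times the top token's probability before renormalizing and sampling, and flashinfer's fused `min_p_sampling_from_probs` now returns just the token ids (hence the dropped `success` flag and the NaN check remaining only on the top-k/top-p path). A plain-PyTorch sketch of the min-p filter, for illustration only:

```python
# Plain-PyTorch illustration of min-p sampling; flashinfer's
# min_p_sampling_from_probs fuses this on GPU.
import torch


def min_p_sample(probs: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    """probs: (batch, vocab) probabilities; min_p: (batch,) thresholds in [0, 1]."""
    top_prob = probs.max(dim=-1, keepdim=True).values            # (batch, 1)
    keep = probs >= min_p.unsqueeze(-1) * top_prob                # (batch, vocab)
    filtered = torch.where(keep, probs, torch.zeros_like(probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)      # renormalize
    return torch.multinomial(filtered, num_samples=1).squeeze(-1)
```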
@@ -0,0 +1,8 @@
+ from .base_backend import BaseLoraBackend
+ from .flashinfer_backend import FlashInferLoraBackend
+ from .triton_backend import TritonLoraBackend
+
+ __all__ = [
+     "FlashInferLoraBackend",
+     "TritonLoraBackend",
+ ]
@@ -0,0 +1,95 @@
+ from typing import Tuple, Union
+
+ import torch
+
+ from sglang.srt.lora.lora import LoraBatchInfo
+
+
+ def get_fuse_output_scaling_add_from_name(name: str) -> bool:
+     mapping = {
+         "triton": True,
+         "flashinfer": False,
+     }
+     return mapping.get(name, False)
+
+
+ def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
+     mapping = {
+         "triton": True,
+         "flashinfer": False,
+     }
+     return mapping.get(name, False)
+
+
+ class BaseLoraBackend:
+     """Base class for different Lora backends.
+     Each backend has its own implementation of Lora kernels.
+
+     Args:
+         name: name of backend
+         batch_info: information of current batch for use
+         fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
+             and the operation of scaling and adding will be fused into kernel
+     """
+
+     def __init__(self, name: str, batch_info: LoraBatchInfo = None):
+         self.name = name
+         self.batch_info = batch_info
+         self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
+         self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
+
+     def run_lora_a_sgemm(
+         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+     ) -> torch.Tensor:
+         """Run segment Gemm of lora a modules with current backend.
+         The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
+
+         Args:
+             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+             weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
+                 usually input_dim is much larger than r
+         Returns:
+             result with shape (s, r)
+         """
+         pass
+
+     def run_lora_b_sgemm(
+         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+     ) -> torch.Tensor:
+         """Run segment Gemm of lora b modules with current backend.
+         The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
+
+         Args:
+             x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
+             weights: a set of lora weights with shape (num_lora, output_dim, r)
+                 usually output_dim is much larger than r
+         Returns:
+             result with shape (s, output_dim)
+         """
+         pass
+
+     def run_qkv_lora(
+         self,
+         x: torch.Tensor,
+         qkv_lora_a: torch.Tensor,
+         qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
+         *args,
+         **kwargs
+     ) -> torch.Tensor:
+         """Run the lora pass for QKV Layer.
+
+         Args:
+             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
+             qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
+             qkv_lora_b: lora_b module for qkv.
+                 If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r)
+                 If passed in as a tuple of two tensors containing:
+                     a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
+                     and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
+         Returns:
+             result with shape (s, output_dim_q + 2 * output_dim_kv)
+         """
+         pass
+
+     def set_batch_info(self, batch_info: LoraBatchInfo):
+         self.batch_info = batch_info
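The segment GEMM that `run_lora_a_sgemm` describes multiplies each sequence's rows of `x` by the LoRA-A matrix of the adapter assigned to that sequence. A naive PyTorch reference of that contract is sketched below; the Triton and flashinfer backends implement it without the Python loop, and the `seg_lens` / `weight_indices` names are assumptions about what `LoraBatchInfo` carries.

```python
# Naive reference for the segment GEMM contract of run_lora_a_sgemm.
# Field names mirroring LoraBatchInfo are assumptions, not sglang's API.
import torch


def lora_a_sgemm_reference(
    x: torch.Tensor,               # (s, input_dim), s = sum of sequence lengths
    weights: torch.Tensor,         # (num_lora, r, input_dim)
    seg_lens: torch.Tensor,        # (bs,) tokens per sequence
    weight_indices: torch.Tensor,  # (bs,) adapter id per sequence
) -> torch.Tensor:
    out = x.new_empty(x.shape[0], weights.shape[1])  # (s, r)
    start = 0
    for seg_len, idx in zip(seg_lens.tolist(), weight_indices.tolist()):
        end = start + seg_len
        # Each segment uses its own adapter's lora_a matrix: (seg, d) @ (d, r)
        out[start:end] = x[start:end] @ weights[idx].t()
        start = end
    return out
```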