sglang 0.4.2.post3__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/check_env.py +1 -0
  2. sglang/global_config.py +2 -0
  3. sglang/srt/constrained/outlines_backend.py +4 -1
  4. sglang/srt/entrypoints/engine.py +2 -2
  5. sglang/srt/layers/attention/flashinfer_backend.py +265 -147
  6. sglang/srt/layers/attention/triton_backend.py +358 -72
  7. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  8. sglang/srt/layers/linear.py +12 -5
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
  18. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -5
  19. sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
  20. sglang/srt/layers/moe/topk.py +1 -1
  21. sglang/srt/layers/quantization/__init__.py +51 -5
  22. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
  32. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
  35. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  37. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
  39. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  41. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
  49. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  51. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
  53. sglang/srt/layers/quantization/fp8_kernel.py +123 -17
  54. sglang/srt/layers/quantization/fp8_utils.py +33 -4
  55. sglang/srt/lora/backend/__init__.py +25 -5
  56. sglang/srt/lora/backend/base_backend.py +31 -9
  57. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  58. sglang/srt/lora/backend/triton_backend.py +34 -4
  59. sglang/srt/lora/layers.py +293 -0
  60. sglang/srt/lora/lora.py +101 -326
  61. sglang/srt/lora/lora_manager.py +101 -269
  62. sglang/srt/lora/mem_pool.py +174 -0
  63. sglang/srt/lora/triton_ops/__init__.py +7 -1
  64. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  65. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  66. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  67. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  68. sglang/srt/lora/utils.py +141 -0
  69. sglang/srt/managers/detokenizer_manager.py +1 -0
  70. sglang/srt/managers/io_struct.py +4 -0
  71. sglang/srt/managers/schedule_batch.py +16 -3
  72. sglang/srt/managers/scheduler.py +29 -0
  73. sglang/srt/managers/tokenizer_manager.py +6 -0
  74. sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
  75. sglang/srt/model_executor/cuda_graph_runner.py +16 -1
  76. sglang/srt/model_executor/model_runner.py +12 -2
  77. sglang/srt/models/deepseek_v2.py +17 -7
  78. sglang/srt/server_args.py +20 -1
  79. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  80. sglang/srt/speculative/eagle_utils.py +64 -21
  81. sglang/srt/speculative/eagle_worker.py +29 -8
  82. sglang/srt/utils.py +7 -0
  83. sglang/version.py +1 -1
  84. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/METADATA +6 -5
  85. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/RECORD +88 -55
  86. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/LICENSE +0 -0
  87. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/WHEEL +0 -0
  88. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_backend.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Optional
 
 import torch
+import triton
 
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -18,7 +19,12 @@ if TYPE_CHECKING:
 
 
 class TritonAttnBackend(AttentionBackend):
-    def __init__(self, model_runner: ModelRunner):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.decode_attention import (
             decode_attention_fwd,
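
Note: the new skip_prefill / kv_indptr_buf arguments (consumed in the hunks below) let a caller pre-allocate the kv_indptr buffer and hand one row of it to each backend instance instead of every instance allocating its own. A minimal sketch, assuming an already-initialized model_runner and a made-up num_steps; this previews how TritonMultiStepDraftBackend (added later in this diff) wires things up:

    # Hedged sketch: share one pre-allocated kv_indptr buffer across per-step backends.
    import torch

    num_steps = 4  # hypothetical number of draft decoding steps
    max_bs = model_runner.req_to_token_pool.size
    kv_indptr = torch.zeros(
        (num_steps, max_bs + 1), dtype=torch.int32, device=model_runner.device
    )
    backends = [
        TritonAttnBackend(
            model_runner,
            skip_prefill=True,           # draft steps never take the prefill path
            kv_indptr_buf=kv_indptr[i],  # each step reuses one row of the shared buffer
        )
        for i in range(num_steps)
    ]
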
@@ -32,14 +38,29 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
+        self.skip_prefill = skip_prefill
+
         max_bs = model_runner.req_to_token_pool.size
-        self.kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
+
+        if kv_indptr_buf is None:
+            self.kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            self.kv_indptr = kv_indptr_buf
+
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.qo_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
+
+        if not self.skip_prefill:
+            self.qo_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+
+            self.mask_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
+            )
+
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
 
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -50,7 +71,7 @@ class TritonAttnBackend(AttentionBackend):
 
         self.forward_metadata = None
 
-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+        self.max_context_len = model_runner.model_config.context_len
 
         self.device = model_runner.device
 
@@ -59,11 +80,31 @@ class TritonAttnBackend(AttentionBackend):
 
         bs = forward_batch.batch_size
         kv_indptr = self.kv_indptr
-
-        if forward_batch.forward_mode.is_decode():
-            attn_logits = torch.empty(
+        spec_info = forward_batch.spec_info
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.zeros(
+                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+                bs = kv_indptr.shape[0] - 1
+
+            attn_logits = torch.zeros(
                 (
-                    forward_batch.batch_size,
+                    bs,
                     self.num_head,
                     self.num_kv_splits,
                     self.v_head_dim + 1,
@@ -72,12 +113,24 @@ class TritonAttnBackend(AttentionBackend):
                 device=self.device,
             )
 
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
             max_extend_len = None
-
+        elif forward_batch.forward_mode.is_target_verify():
+            bs = len(forward_batch.req_pool_indices)
+            qo_indptr = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            # Different with flashinfer kv_indptr and kv_indices construction
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+            kv_indices = torch.zeros(
+                kv_indptr[-1], dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -89,15 +142,32 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )
 
-            qo_indptr = None
-            custom_mask = None
-            mask_offsets = None
+            custom_mask = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (
+                forward_batch.seq_lens + self.num_draft_tokens
+            )
+            mask_indptr = self.mask_indptr
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
+            mask_indptr = mask_indptr[: bs + 1]
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        elif forward_batch.forward_mode.is_draft_extend():
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    self.req_to_token,
+                )
+            )
+            mask_indptr = None
+            max_extend_len = torch.max(spec_info.accept_length).item()
+            attn_logits = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
             )
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
+            kv_indices = torch.zeros(
                 forward_batch.extend_prefix_lens.sum().item(),
                 dtype=torch.int32,
                 device=self.device,
@@ -116,8 +186,7 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
-            mask_offsets = None
-
+            mask_indptr = None
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
@@ -128,25 +197,32 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             qo_indptr,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
         )
 
-    def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
-
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device=self.device
-        )
-        self.cuda_graph_attn_logits = torch.empty(
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
+        self.cuda_graph_attn_logits = torch.zeros(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.cuda_graph_max_seq_len),
-            dtype=torch.int32,
-            device=self.device,
-        )
+        if kv_indices_buf is None:
+            self.cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_kv_indices = kv_indices_buf
+
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device=self.device,
+            )
 
     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -159,31 +235,71 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         assert encoder_lens is None, "Not supported"
-        assert forward_mode.is_decode(), "Not supported"
-        assert spec_info is None, "Not supported"
 
-        kv_indptr = self.kv_indptr
-        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = self.cuda_graph_kv_indices
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            seq_lens,
-            kv_indptr,
-            None,
-            kv_indices,
-            self.req_to_token.stride(0),
-        )
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = self.cuda_graph_kv_indices
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+
+            attn_logits = self.cuda_graph_attn_logits
+            max_extend_len = None
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
+        elif forward_mode.is_target_verify():
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            custom_mask = self.cuda_graph_custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
+            )
 
         self.forward_metadata = (
-            self.cuda_graph_attn_logits,
-            None,
+            attn_logits,
+            max_extend_len,
             kv_indptr,
             kv_indices,
-            None,
-            None,
-            None,
+            qo_indptr,
+            custom_mask,
+            mask_indptr,
         )
 
     def init_forward_metadata_replay_cuda_graph(
@@ -197,22 +313,57 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-        kv_indptr = self.kv_indptr
-        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = self.cuda_graph_kv_indices
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            kv_indptr,
-            None,
-            kv_indices,
-            self.req_to_token.stride(0),
-        )
+        if forward_mode.is_decode_or_idle():
+            # Update kv_indptr, kv_indices
+            kv_indptr = self.kv_indptr
+            kv_indices = self.cuda_graph_kv_indices
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices[:bs],
+                    seq_lens[:bs],
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
+                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+        elif forward_mode.is_target_verify():
+            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
+            )
 
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
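
Note: the mask_offsets → mask_indptr rename throughout this file reflects that the buffer holds cumulative start offsets into the flattened per-request custom mask, mirroring the kv_indptr/qo_indptr naming. A rough, self-contained illustration of the layout implied by the seq_mask_len / cumsum code above, with made-up values:

    # Hedged sketch of the mask_indptr layout. Each request contributes
    # num_draft_tokens * (seq_len + num_draft_tokens) mask entries, flattened into
    # one custom_mask tensor; mask_indptr[i] is where request i's slice starts.
    import torch

    num_draft_tokens = 4
    seq_lens = torch.tensor([7, 3, 12])  # made-up sequence lengths
    seq_mask_len = num_draft_tokens * (seq_lens + num_draft_tokens)

    mask_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int64)
    mask_indptr[1:] = torch.cumsum(seq_mask_len, dim=0)
    # mask_indptr -> tensor([0, 44, 72, 136]); request i owns
    # custom_mask[mask_indptr[i] : mask_indptr[i + 1]]
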
@@ -244,8 +395,9 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             qo_indptr,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
         ) = self.forward_metadata
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -257,7 +409,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr,
             kv_indices,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -303,3 +455,137 @@ class TritonAttnBackend(AttentionBackend):
             layer.logit_cap,
         )
         return o
+
+
+class TritonMultiStepDraftBackend:
+    """
+    Wrap multiple triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                TritonAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        self.device = model_runner.device
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            num_seqs,
+            self.topk,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+        )
+
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device=self.device,
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_bs * self.max_context_len),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
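
Note: the new class above fans every metadata call out to one TritonAttnBackend per draft step over shared kv_indptr / kv_indices buffers. A minimal construction sketch, assuming an already-initialized model_runner and forward_batch; the topk and speculative_num_steps values are made up and would normally come from server_args (wired up in eagle_worker.py elsewhere in this release):

    # Hedged sketch: constructing and driving the multi-step draft backend.
    draft_backend = TritonMultiStepDraftBackend(
        model_runner,               # assumed: an initialized ModelRunner
        topk=4,                     # hypothetical speculative top-k
        speculative_num_steps=5,    # hypothetical number of draft steps
    )
    draft_backend.init_cuda_graph_state(max_bs=64)  # allocates per-step kv-indices buffers
    # One call per forward pass; internally loops over the per-step backends:
    draft_backend.init_forward_metadata(forward_batch)
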
sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -50,7 +50,7 @@ def _fwd_kernel(
     kv_indptr,
     kv_indices,
     mask_ptr,
-    mask_offsets,
+    mask_indptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -87,7 +87,7 @@ def _fwd_kernel(
     cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
 
     if USE_CUSTOM_MASK:
-        cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)
+        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
@@ -288,7 +288,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
-    mask_offsets,
+    mask_indptr,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -364,7 +364,7 @@ def extend_attention_fwd(
         kv_indptr,
         kv_indices,
         custom_mask,
-        mask_offsets,
+        mask_indptr,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
sglang/srt/layers/linear.py

@@ -421,11 +421,18 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(
-            loaded_weight,
-            tp_rank=self.tp_rank,
-            use_presharded_weights=self.use_presharded_weights,
-        )
+
+        from sglang.srt.layers.parameter import _ColumnvLLMParameter
+
+        if isinstance(param, _ColumnvLLMParameter):
+            # FIXME: why would we need this special case?
+            param.load_column_parallel_weight(
+                loaded_weight,
+                tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
+            )
+        else:
+            param.load_column_parallel_weight(loaded_weight)
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -72,10 +72,10 @@
         "waves_per_eu": 0
     },
     "64": {
-        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_M": 32,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 1,
+        "GROUP_SIZE_M": 4,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0