sglang 0.4.0__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in the public registry.
- sglang/__init__.py +1 -1
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +20 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/managers/schedule_batch.py +1 -0
- sglang/srt/managers/scheduler.py +69 -65
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +15 -1
- sglang/srt/model_executor/model_runner.py +11 -4
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/grok.py +0 -5
- sglang/srt/models/llama.py +0 -5
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +3 -3
- sglang/srt/server_args.py +43 -4
- sglang/srt/utils.py +50 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/RECORD +43 -38
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
sglang/srt/constrained/outlines_backend.py
CHANGED
@@ -42,6 +42,7 @@ class OutlinesGrammar(BaseGrammarObject):
         self.guide = guide
         self.jump_forward_map = jump_forward_map
         self.state = 0
+        self.finished = False
 
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)
@@ -84,6 +85,10 @@ class OutlinesGrammar(BaseGrammarObject):
     ) -> torch.Tensor:
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
 
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         tokens = torch.tensor(
             self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
sglang/srt/constrained/xgrammar_backend.py
CHANGED
@@ -45,6 +45,7 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
+        self.finished = False
 
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
@@ -85,12 +86,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)
 
     @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        # vocab_mask must then be on the same device as logits
-        # when applying the token bitmask, so we check and move if needed
-        vocab_mask = vocab_mask.to(logits.device)
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask.to(device, non_blocking=True)
 
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
         apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
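The two grammar backends above now split mask handling into an explicit move step: move_vocab_mask relocates the bitmask to the logits device (a non-blocking copy for xgrammar, a no-op for outlines), and apply_vocab_mask no longer performs the transfer itself. The snippet below is a self-contained illustration of that contract, not sglang's sampler code; the ToyGrammar class, the boolean "True = blocked" mask convention, and the masked_fill_ stand-in for apply_token_bitmask_inplace are assumptions made for the example.

import torch

class ToyGrammar:
    # Mirrors the two-method contract introduced in the diff (illustrative only).
    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        # Non-blocking copy to the logits device; a no-op if already there.
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        # Stand-in for the real bitmask application: True means "token blocked".
        logits.masked_fill_(vocab_mask, float("-inf"))

vocab_size = 8
logits = torch.randn(1, vocab_size)
mask = torch.zeros(1, vocab_size, dtype=torch.bool)
mask[0, 4:] = True  # toy grammar: only token ids 0-3 remain legal
mask = ToyGrammar.move_vocab_mask(mask, logits.device)
ToyGrammar.apply_vocab_mask(logits, mask)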
sglang/srt/layers/attention/__init__.py
CHANGED
@@ -52,12 +52,13 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(q, k, v, layer, forward_batch)
+            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
         else:
-            return self.forward_extend(q, k, v, layer, forward_batch)
+            return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)
 
     def forward_decode(
         self,
@@ -66,6 +67,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for decode."""
         raise NotImplementedError()
@@ -77,6 +79,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for extend."""
         raise NotImplementedError()
sglang/srt/layers/attention/double_sparsity_backend.py
CHANGED
@@ -165,7 +165,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return 1
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -181,9 +187,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )
 
         (
             start_loc,
@@ -212,7 +219,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return o
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -242,9 +255,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )
 
         # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
         # and set a minimum value for sparse_decode
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -221,7 +221,13 @@ class FlashInferAttnBackend(AttentionBackend):
         return 0
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         prefill_wrapper_paged = self.prefill_wrappers_paged[
             self._get_wrapper_idx(layer)
@@ -237,7 +243,8 @@ class FlashInferAttnBackend(AttentionBackend):
         if not use_ragged:
             if k is not None:
                 assert v is not None
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -270,12 +277,19 @@ class FlashInferAttnBackend(AttentionBackend):
 
             o, _ = merge_state(o1, s1, o2, s2)
 
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
         cache_loc = (
@@ -286,7 +300,8 @@ class FlashInferAttnBackend(AttentionBackend):
 
         if k is not None:
             assert v is not None
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
sglang/srt/layers/attention/torch_native_backend.py
CHANGED
@@ -216,16 +216,23 @@ class TorchNativeAttnBackend(AttentionBackend):
         return output
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
             o = torch.empty_like(q)
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
 
@@ -249,7 +256,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         return o
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -260,9 +273,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
 
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -114,7 +114,13 @@ class TritonAttnBackend(AttentionBackend):
         return 1
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -122,9 +128,10 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
@@ -146,7 +153,13 @@ class TritonAttnBackend(AttentionBackend):
         return o
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -160,9 +173,10 @@ class TritonAttnBackend(AttentionBackend):
 
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         self.decode_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
sglang/srt/layers/ep_moe/__init__.py
ADDED
File without changes
sglang/srt/layers/ep_moe/kernels.py
ADDED
@@ -0,0 +1,349 @@
+import logging
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+logger = logging.getLogger(__name__)
+
+
+@triton.jit
+def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
+    expert = tl.program_id(0)
+    low = 0
+    high = num_toks - 1
+    target_location = -1
+    while low <= high:
+        mid = (low + high) // 2
+
+        if tl.load(reorder_topk_ids + mid) > expert:
+            high = mid - 1
+        else:
+            low = mid + 1
+            target_location = mid
+    tl.store(seg_indptr + expert + 1, target_location + 1)
+
+
+@triton.jit
+def compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
+
+    compute_seg_indptr_triton_kernel[(num_experts,)](
+        reorder_topk_ids, seg_indptr, topk_ids.numel()
+    )
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
+    compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
+    )
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
+@triton.jit
+def pre_reorder_triton_kernel(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    start_expert_id,
+    end_expert_id,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+            if a1_scales_ptr is not None:
+                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
+            else:
+                scale = 1.0
+
+            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + tl.arange(0, BLOCK_SIZE)
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
+                out_data = (in_data * scale).to(OutDtype)
+                tl.store(dst_ptr + offset, out_data, mask=mask)
+
+
+@triton.jit
+def silu_and_mul_triton_kernel(
+    gateup_output,
+    down_input,
+    hidden_size,
+    reorder_topk_ids,
+    scales,
+    start_expert_id,
+    end_expert_id,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = gateup_output.dtype.element_ty
+    OutDtype = down_input.dtype.element_ty
+
+    half_hidden_size = hidden_size // 2
+
+    pid = tl.program_id(0)
+    expert_id = tl.load(reorder_topk_ids + pid)
+    if expert_id >= start_expert_id and expert_id <= end_expert_id:
+        gateup_output_ptr = gateup_output + pid * hidden_size
+        gate_output_ptr = gateup_output_ptr
+        up_output_ptr = gateup_output_ptr + half_hidden_size
+        down_input_ptr = down_input + pid * half_hidden_size
+
+        if scales is not None:
+            scale = tl.load(scales + expert_id - start_expert_id)
+            scale = (1 / scale).to(InDtype)
+        else:
+            scale = 1
+
+        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
+            offset = start_offset + tl.arange(0, BLOCK_SIZE)
+            mask = offset < half_hidden_size
+
+            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
+            up_output = tl.load(up_output_ptr + offset, mask=mask)
+
+            # silu & mul & quantize
+            gate_output = gate_output * tl.sigmoid(gate_output)
+            gate_output = gate_output.to(InDtype)
+
+            silu_mul_output = gate_output * up_output * scale
+            silu_mul_output = silu_mul_output.to(OutDtype)
+            tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
+
+
+@triton.jit
+def post_reorder_triton_kernel(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    start_expert_id,
+    end_expert_id,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    computed = False
+    store_ptr = output_ptr + src_idx * hidden_size
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + tl.arange(0, BLOCK_SIZE)
+        mask = offset < hidden_size
+
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            expert_id = tl.load(topk_ids_ptr + idx)
+            if expert_id >= start_expert_id and expert_id <= end_expert_id:
+                computed = True
+                dst_idx = tl.load(src2dst_ptr + idx)
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+    if computed == False:
+        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+            offset = start_offset + tl.arange(0, BLOCK_SIZE)
+            mask = offset < hidden_size
+            tl.store(
+                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
+            )
+
+
+@triton.jit
+def compute_m_range(
+    pid,
+    batch_size,
+    seg_indptr,
+    weight_indices,
+    m_num_tiles_indptr,
+    BLOCK_SIZE_M: tl.constexpr,
+):
+    idx = 0
+    for bs in range(batch_size):
+        tiles = tl.load(m_num_tiles_indptr + bs)
+        if pid >= tiles:
+            idx = bs
+
+    idx_start = tl.load(m_num_tiles_indptr + idx)
+
+    m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
+    m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
+    expert_id = tl.load(weight_indices + idx)
+    return m_range_start, m_range_end, expert_id
+
+
+@triton.jit
+def grouped_gemm_triton_kernel(
+    a,
+    b,
+    c,
+    batch_size,
+    N,
+    K,
+    seg_indptr,
+    weight_indices,
+    m_num_tiles_indptr,
+    use_fp8_w8a8,
+    scale_a,
+    scale_b,
+    a_stride_0: tl.constexpr,
+    b_stride_0: tl.constexpr,
+    b_stride_1: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    c_dtype = c.dtype.element_ty
+
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
+    if pid_m >= total_m_block:
+        return
+
+    m_range_start, m_range_end, expert_id = compute_m_range(
+        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
+    )
+    if m_range_end - m_range_start == 0:
+        return
+
+    n_range_start = pid_n * BLOCK_SIZE_N
+    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)
+
+    offs_am = tl.arange(0, BLOCK_SIZE_M)
+    offs_bn = tl.arange(0, BLOCK_SIZE_N)
+
+    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
+    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
+    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
+    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+
+    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
+    b_ptr = b + (
+        (expert_id * b_stride_0)
+        + (n_range_start + offs_bn[:, None]) * b_stride_1
+        + offs_k[None, :]
+    )
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a_tile = tl.load(
+            a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
+        )
+        b_tile = tl.load(
+            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
+        )
+        accumulator = tl.dot(a_tile, b_tile.T, accumulator)
+        a_ptr += BLOCK_SIZE_K
+        b_ptr += BLOCK_SIZE_K
+
+    if use_fp8_w8a8:
+        scale_a_value = tl.load(scale_a + expert_id)
+        scale_b_value = tl.load(scale_b + expert_id)
+        accumulator *= scale_a_value * scale_b_value
+    c_tile = accumulator.to(c_dtype)
+
+    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
+    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
+    tl.store(c_ptr, c_tile, mask=c_mask)
+
+
+@triton.jit
+def compute_m_num_tiles_indptr(
+    m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
+):
+    for bs in range(batch_size):
+        m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
+        cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
+        pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
+        tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
+
+
+def grouped_gemm_triton(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    c: torch.Tensor,
+    batch_size: int,
+    weight_column_major: bool,
+    seg_indptr: Optional[torch.Tensor] = None,
+    weight_indices: Optional[torch.Tensor] = None,
+    use_fp8_w8a8: bool = False,
+    scale_a: torch.Tensor = None,
+    scale_b: torch.Tensor = None,
+):
+    assert weight_column_major == True  # TODO: more
+    if use_fp8_w8a8:
+        assert scale_a is not None and scale_b is not None
+
+    config = {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+    }
+
+    m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
+    compute_m_num_tiles_indptr[(1,)](
+        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
+    )
+
+    grid = lambda META: (
+        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
+        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
+    )
+
+    grouped_gemm_triton_kernel[grid](
+        a,
+        b,
+        c,
+        batch_size,
+        b.size(1),
+        b.size(2),
+        seg_indptr,
+        weight_indices,
+        m_num_tiles_indptr,
+        use_fp8_w8a8,
+        scale_a,
+        scale_b,
+        a.stride(0),
+        b.stride(0),
+        b.stride(1),
+        **config,
+    )
+    return c
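Taken together, the new kernels form the expert-parallel MoE data path: run_moe_ep_preproess sorts token-to-expert assignments so each expert's rows are contiguous and builds seg_indptr/src2dst, pre_reorder_triton_kernel and post_reorder_triton_kernel scatter and gather activations around the expert computation, silu_and_mul_triton_kernel fuses the gated activation, and grouped_gemm_triton runs one GEMM per expert segment. A minimal sketch of the preprocessing plus a grouped GEMM call, assuming a CUDA device and random data in place of the real reordered activations and expert weights:

import torch

from sglang.srt.layers.ep_moe.kernels import grouped_gemm_triton, run_moe_ep_preproess

num_experts, topk, num_tokens, hidden, inter = 4, 2, 8, 64, 128
topk_ids = torch.randint(0, num_experts, (num_tokens, topk), device="cuda")

# seg_indptr[e]:seg_indptr[e+1] is the row range owned by expert e after sorting;
# src2dst maps each original (token, slot) pair to its reordered row.
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)

# In the real layer `a` would come from pre_reorder_triton_kernel; random data is
# enough to exercise the grouped GEMM. Weights are column-major per expert: (E, N, K).
a = torch.randn(num_tokens * topk, hidden, device="cuda", dtype=torch.float16)
b = torch.randn(num_experts, inter, hidden, device="cuda", dtype=torch.float16)
c = torch.empty(num_tokens * topk, inter, device="cuda", dtype=torch.float16)
weight_indices = torch.arange(num_experts, device="cuda", dtype=torch.int64)

grouped_gemm_triton(
    a, b, c,
    batch_size=num_experts,
    weight_column_major=True,
    seg_indptr=seg_indptr,
    weight_indices=weight_indices,
)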