sglang 0.4.0__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. sglang/__init__.py +1 -1
  2. sglang/srt/constrained/outlines_backend.py +5 -0
  3. sglang/srt/constrained/xgrammar_backend.py +5 -5
  4. sglang/srt/layers/attention/__init__.py +5 -2
  5. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  6. sglang/srt/layers/attention/flashinfer_backend.py +20 -5
  7. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  8. sglang/srt/layers/attention/triton_backend.py +22 -8
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  10. sglang/srt/layers/ep_moe/__init__.py +0 -0
  11. sglang/srt/layers/ep_moe/kernels.py +349 -0
  12. sglang/srt/layers/ep_moe/layer.py +661 -0
  13. sglang/srt/layers/quantization/__init__.py +2 -2
  14. sglang/srt/layers/quantization/fp8.py +559 -0
  15. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  16. sglang/srt/layers/radix_attention.py +4 -2
  17. sglang/srt/layers/sampler.py +2 -0
  18. sglang/srt/layers/torchao_utils.py +23 -45
  19. sglang/srt/managers/schedule_batch.py +1 -0
  20. sglang/srt/managers/scheduler.py +69 -65
  21. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  22. sglang/srt/mem_cache/memory_pool.py +5 -1
  23. sglang/srt/model_executor/cuda_graph_runner.py +15 -1
  24. sglang/srt/model_executor/model_runner.py +11 -4
  25. sglang/srt/model_parallel.py +1 -5
  26. sglang/srt/models/commandr.py +2 -2
  27. sglang/srt/models/deepseek_v2.py +87 -7
  28. sglang/srt/models/grok.py +0 -5
  29. sglang/srt/models/llama.py +0 -5
  30. sglang/srt/models/mixtral.py +12 -9
  31. sglang/srt/models/phi3_small.py +0 -5
  32. sglang/srt/models/qwen2_moe.py +0 -5
  33. sglang/srt/models/torch_native_llama.py +0 -5
  34. sglang/srt/sampling/sampling_batch_info.py +9 -8
  35. sglang/srt/server.py +3 -3
  36. sglang/srt/server_args.py +43 -4
  37. sglang/srt/utils.py +50 -0
  38. sglang/version.py +1 -1
  39. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  40. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/RECORD +43 -38
  41. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  42. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  43. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -66,7 +66,7 @@ from sglang.version import __version__
 
 __all__ += ["__version__"]
 
-# SGL Backends
+# SGLang Backends
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import LazyImport
 
sglang/srt/constrained/outlines_backend.py CHANGED
@@ -42,6 +42,7 @@ class OutlinesGrammar(BaseGrammarObject):
         self.guide = guide
         self.jump_forward_map = jump_forward_map
         self.state = 0
+        self.finished = False
 
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)
@@ -84,6 +85,10 @@ class OutlinesGrammar(BaseGrammarObject):
     ) -> torch.Tensor:
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
 
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         tokens = torch.tensor(
             self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -45,6 +45,7 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
+        self.finished = False
 
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
@@ -85,12 +86,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)
 
     @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if vocab_mask.device.type != logits.device.type:
-            # vocab_mask must then be on the same device as logits
-            # when applying the token bitmask, so we check and move if needed
-            vocab_mask = vocab_mask.to(logits.device)
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask.to(device, non_blocking=True)
 
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
         apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
sglang/srt/layers/attention/__init__.py CHANGED
@@ -52,12 +52,13 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(q, k, v, layer, forward_batch)
+            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
         else:
-            return self.forward_extend(q, k, v, layer, forward_batch)
+            return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)
 
     def forward_decode(
         self,
@@ -66,6 +67,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for decode."""
         raise NotImplementedError()
@@ -77,6 +79,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for extend."""
         raise NotImplementedError()
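Every backend below receives the same save_kv_cache keyword, defaulting to True so existing callers keep their behavior. A hypothetical call-site fragment (not code from this diff) showing what the flag enables:

    # `attn_backend`, `layer`, and `forward_batch` are assumed to come from the
    # surrounding model runner. Passing save_kv_cache=False runs attention without
    # re-writing k/v into the KV pool, e.g. when the cache entry is already populated.
    o = attn_backend.forward_extend(q, k, v, layer, forward_batch, save_kv_cache=False)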
sglang/srt/layers/attention/double_sparsity_backend.py CHANGED
@@ -165,7 +165,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return 1
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -181,9 +187,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )
 
         (
             start_loc,
@@ -212,7 +219,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return o
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -242,9 +255,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )
 
         # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
         # and set a minimum value for sparse_decode
sglang/srt/layers/attention/flashinfer_backend.py CHANGED
@@ -221,7 +221,13 @@ class FlashInferAttnBackend(AttentionBackend):
         return 0
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         prefill_wrapper_paged = self.prefill_wrappers_paged[
             self._get_wrapper_idx(layer)
@@ -237,7 +243,8 @@ class FlashInferAttnBackend(AttentionBackend):
         if not use_ragged:
             if k is not None:
                 assert v is not None
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -270,12 +277,19 @@ class FlashInferAttnBackend(AttentionBackend):
 
                 o, _ = merge_state(o1, s1, o2, s2)
 
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
         cache_loc = (
@@ -286,7 +300,8 @@ class FlashInferAttnBackend(AttentionBackend):
 
         if k is not None:
             assert v is not None
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
sglang/srt/layers/attention/torch_native_backend.py CHANGED
@@ -216,16 +216,23 @@ class TorchNativeAttnBackend(AttentionBackend):
         return output
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
             o = torch.empty_like(q)
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
 
@@ -249,7 +256,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         return o
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -260,9 +273,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
 
sglang/srt/layers/attention/triton_backend.py CHANGED
@@ -114,7 +114,13 @@ class TritonAttnBackend(AttentionBackend):
         return 1
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
@@ -122,9 +128,10 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
@@ -146,7 +153,13 @@ class TritonAttnBackend(AttentionBackend):
         return o
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -160,9 +173,10 @@ class TritonAttnBackend(AttentionBackend):
 
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         self.decode_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
sglang/srt/layers/attention/triton_ops/extend_attention.py CHANGED
@@ -284,6 +284,9 @@ def extend_attention_fwd(
     elif Lq == 288:
         BLOCK_DMODEL = 256
         BLOCK_DPE = 32
+    elif Lq == 192:
+        BLOCK_DMODEL = 128
+        BLOCK_DPE = 64
     else:
         BLOCK_DMODEL = triton.next_power_of_2(Lq)
         BLOCK_DPE = 0
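The new Lq == 192 branch follows the same pattern as the existing Lq == 288 case: the query head dim is covered by a power-of-two main block (BLOCK_DMODEL) plus a smaller power-of-two tail (BLOCK_DPE). A tiny self-check sketch (illustrative only, not part of the package):

    # Both special-cased head dims split into two power-of-two blocks.
    for lq, (dmodel, dpe) in {288: (256, 32), 192: (128, 64)}.items():
        assert dmodel + dpe == lq
        assert dmodel & (dmodel - 1) == 0 and dpe & (dpe - 1) == 0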
sglang/srt/layers/ep_moe/__init__.py ADDED
File without changes
sglang/srt/layers/ep_moe/kernels.py ADDED
@@ -0,0 +1,349 @@
+import logging
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+logger = logging.getLogger(__name__)
+
+
+@triton.jit
+def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
+    expert = tl.program_id(0)
+    low = 0
+    high = num_toks - 1
+    target_location = -1
+    while low <= high:
+        mid = (low + high) // 2
+
+        if tl.load(reorder_topk_ids + mid) > expert:
+            high = mid - 1
+        else:
+            low = mid + 1
+            target_location = mid
+    tl.store(seg_indptr + expert + 1, target_location + 1)
+
+
+@triton.jit
+def compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
+
+    compute_seg_indptr_triton_kernel[(num_experts,)](
+        reorder_topk_ids, seg_indptr, topk_ids.numel()
+    )
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
+    compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
+    )
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
+@triton.jit
+def pre_reorder_triton_kernel(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    start_expert_id,
+    end_expert_id,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+            if a1_scales_ptr is not None:
+                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
+            else:
+                scale = 1.0
+
+            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + tl.arange(0, BLOCK_SIZE)
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
+                out_data = (in_data * scale).to(OutDtype)
+                tl.store(dst_ptr + offset, out_data, mask=mask)
+
+
+@triton.jit
+def silu_and_mul_triton_kernel(
+    gateup_output,
+    down_input,
+    hidden_size,
+    reorder_topk_ids,
+    scales,
+    start_expert_id,
+    end_expert_id,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = gateup_output.dtype.element_ty
+    OutDtype = down_input.dtype.element_ty
+
+    half_hidden_size = hidden_size // 2
+
+    pid = tl.program_id(0)
+    expert_id = tl.load(reorder_topk_ids + pid)
+    if expert_id >= start_expert_id and expert_id <= end_expert_id:
+        gateup_output_ptr = gateup_output + pid * hidden_size
+        gate_output_ptr = gateup_output_ptr
+        up_output_ptr = gateup_output_ptr + half_hidden_size
+        down_input_ptr = down_input + pid * half_hidden_size
+
+        if scales is not None:
+            scale = tl.load(scales + expert_id - start_expert_id)
+            scale = (1 / scale).to(InDtype)
+        else:
+            scale = 1
+
+        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
+            offset = start_offset + tl.arange(0, BLOCK_SIZE)
+            mask = offset < half_hidden_size
+
+            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
+            up_output = tl.load(up_output_ptr + offset, mask=mask)
+
+            # silu & mul & quantize
+            gate_output = gate_output * tl.sigmoid(gate_output)
+            gate_output = gate_output.to(InDtype)
+
+            silu_mul_output = gate_output * up_output * scale
+            silu_mul_output = silu_mul_output.to(OutDtype)
+            tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
+
+
+@triton.jit
+def post_reorder_triton_kernel(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    start_expert_id,
+    end_expert_id,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    computed = False
+    store_ptr = output_ptr + src_idx * hidden_size
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + tl.arange(0, BLOCK_SIZE)
+        mask = offset < hidden_size
+
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            expert_id = tl.load(topk_ids_ptr + idx)
+            if expert_id >= start_expert_id and expert_id <= end_expert_id:
+                computed = True
+                dst_idx = tl.load(src2dst_ptr + idx)
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+    if computed == False:
+        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+            offset = start_offset + tl.arange(0, BLOCK_SIZE)
+            mask = offset < hidden_size
+            tl.store(
+                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
+            )
+
+
+@triton.jit
+def compute_m_range(
+    pid,
+    batch_size,
+    seg_indptr,
+    weight_indices,
+    m_num_tiles_indptr,
+    BLOCK_SIZE_M: tl.constexpr,
+):
+    idx = 0
+    for bs in range(batch_size):
+        tiles = tl.load(m_num_tiles_indptr + bs)
+        if pid >= tiles:
+            idx = bs
+
+    idx_start = tl.load(m_num_tiles_indptr + idx)
+
+    m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
+    m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
+    expert_id = tl.load(weight_indices + idx)
+    return m_range_start, m_range_end, expert_id
+
+
+@triton.jit
+def grouped_gemm_triton_kernel(
+    a,
+    b,
+    c,
+    batch_size,
+    N,
+    K,
+    seg_indptr,
+    weight_indices,
+    m_num_tiles_indptr,
+    use_fp8_w8a8,
+    scale_a,
+    scale_b,
+    a_stride_0: tl.constexpr,
+    b_stride_0: tl.constexpr,
+    b_stride_1: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    c_dtype = c.dtype.element_ty
+
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
+    if pid_m >= total_m_block:
+        return
+
+    m_range_start, m_range_end, expert_id = compute_m_range(
+        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
+    )
+    if m_range_end - m_range_start == 0:
+        return
+
+    n_range_start = pid_n * BLOCK_SIZE_N
+    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)
+
+    offs_am = tl.arange(0, BLOCK_SIZE_M)
+    offs_bn = tl.arange(0, BLOCK_SIZE_N)
+
+    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
+    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
+    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
+    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+
+    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
+    b_ptr = b + (
+        (expert_id * b_stride_0)
+        + (n_range_start + offs_bn[:, None]) * b_stride_1
+        + offs_k[None, :]
+    )
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a_tile = tl.load(
+            a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
+        )
+        b_tile = tl.load(
+            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
+        )
+        accumulator = tl.dot(a_tile, b_tile.T, accumulator)
+        a_ptr += BLOCK_SIZE_K
+        b_ptr += BLOCK_SIZE_K
+
+    if use_fp8_w8a8:
+        scale_a_value = tl.load(scale_a + expert_id)
+        scale_b_value = tl.load(scale_b + expert_id)
+        accumulator *= scale_a_value * scale_b_value
+    c_tile = accumulator.to(c_dtype)
+
+    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
+    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
+    tl.store(c_ptr, c_tile, mask=c_mask)
+
+
+@triton.jit
+def compute_m_num_tiles_indptr(
+    m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
+):
+    for bs in range(batch_size):
+        m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
+        cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
+        pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
+        tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
+
+
+def grouped_gemm_triton(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    c: torch.Tensor,
+    batch_size: int,
+    weight_column_major: bool,
+    seg_indptr: Optional[torch.Tensor] = None,
+    weight_indices: Optional[torch.Tensor] = None,
+    use_fp8_w8a8: bool = False,
+    scale_a: torch.Tensor = None,
+    scale_b: torch.Tensor = None,
+):
+    assert weight_column_major == True  # TODO: more
+    if use_fp8_w8a8:
+        assert scale_a is not None and scale_b is not None
+
+    config = {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+    }
+
+    m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
+    compute_m_num_tiles_indptr[(1,)](
+        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
+    )
+
+    grid = lambda META: (
+        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
+        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
+    )
+
+    grouped_gemm_triton_kernel[grid](
+        a,
+        b,
+        c,
+        batch_size,
+        b.size(1),
+        b.size(2),
+        seg_indptr,
+        weight_indices,
+        m_num_tiles_indptr,
+        use_fp8_w8a8,
+        scale_a,
+        scale_b,
+        a.stride(0),
+        b.stride(0),
+        b.stride(1),
+        **config,
+    )
+    return c
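A minimal usage sketch for the routing helper added above (illustrative only, not part of the package; assumes a CUDA device with Triton available and this sglang version installed):

    import torch
    from sglang.srt.layers.ep_moe.kernels import run_moe_ep_preproess

    # 3 tokens, each routed to its top-2 of 4 experts.
    topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1]], device="cuda")
    reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts=4)
    # seg_indptr[e] : seg_indptr[e + 1] bounds the slots assigned to expert e in the
    # sorted layout; src2dst maps each flattened (token, slot) index to its new row.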