sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -130,28 +130,30 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
 
 @triton.jit
 def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
-    expert = tl.program_id(0)
+    expert_id_minus_1 = tl.program_id(0) - 1
     low = 0
     high = num_toks - 1
     target_location = -1
     while low <= high:
         mid = (low + high) // 2
 
-        if tl.load(reorder_topk_ids + mid) > expert:
+        if tl.load(reorder_topk_ids + mid) > expert_id_minus_1:
            high = mid - 1
         else:
             low = mid + 1
             target_location = mid
-    tl.store(seg_indptr + expert + 1, target_location + 1)
+    tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1)
 
 
-def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
+def run_moe_ep_preproess(topk_ids: torch.Tensor, num_local_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
 
-    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    seg_indptr = torch.zeros(
+        num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
+    )
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
 
-    compute_seg_indptr_triton_kernel[(num_experts,)](
+    compute_seg_indptr_triton_kernel[(num_local_experts,)](
         reorder_topk_ids, seg_indptr, topk_ids.numel()
     )
 
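For orientation: after the stable sort of the flattened topk_ids, seg_indptr[e] counts the assignments whose expert id is below e, so [seg_indptr[e], seg_indptr[e + 1]) is expert e's contiguous segment; ids for assignments routed to other ranks are presumably marked negative, so they sort first and land before seg_indptr[0]. A minimal PyTorch sketch of that contract (the helper name is illustrative, not part of the package):

import torch

def seg_indptr_reference(topk_ids: torch.Tensor, num_local_experts: int) -> torch.Tensor:
    # seg_indptr[e] = number of sorted assignments with expert id < e.
    flat = topk_ids.view(-1).sort(stable=True).values
    seg_indptr = torch.zeros(num_local_experts + 1, dtype=torch.int64)
    for e in range(num_local_experts + 1):
        seg_indptr[e] = int((flat < e).sum())
    return seg_indptr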
@@ -164,25 +166,6 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     return reorder_topk_ids, src2dst, seg_indptr
 
 
-def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
-    reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
-
-    seg_indptr = torch.zeros(
-        local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
-    )
-    src2dst = torch.empty(
-        local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
-    )
-
-    BLOCK_SIZE = 512
-    grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
-    compute_src2dst_triton_kernel[grid](
-        reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
-    )
-
-    return reorder_topk_ids, src2dst, seg_indptr
-
-
 @triton.jit
 def pre_reorder_triton_kernel_for_cutlass_moe(
     input_ptr,
@@ -190,52 +173,13 @@ def pre_reorder_triton_kernel_for_cutlass_moe(
     src2dst_ptr,
     topk_ids_ptr,
     a1_scales_ptr,
-    num_experts,
+    num_local_experts,
     topk,
     hidden_size,
     BLOCK_SIZE: tl.constexpr,
 ):
     OutDtype = gateup_input_ptr.dtype.element_ty
 
-    src_idx = tl.program_id(0)
-    src2dst_ptr = src2dst_ptr + src_idx * topk
-    topk_ids_ptr = topk_ids_ptr + src_idx * topk
-
-    src_ptr = input_ptr + src_idx * hidden_size
-    for idx in range(topk):
-        expert_id = tl.load(topk_ids_ptr + idx)
-        if expert_id != num_experts:
-            if a1_scales_ptr is not None:
-                scale = 1.0 / tl.load(a1_scales_ptr)
-            else:
-                scale = 1.0
-
-            dst_idx = tl.load(src2dst_ptr + idx)
-            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
-            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
-                offset = start_offset + tl.arange(0, BLOCK_SIZE)
-                mask = offset < hidden_size
-                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
-                out_data = (in_data * scale).to(OutDtype)
-                tl.store(dst_ptr + offset, out_data, mask=mask)
-
-
-@triton.jit
-def pre_reorder_triton_kernel(
-    input_ptr,
-    gateup_input_ptr,
-    src2dst_ptr,
-    topk_ids_ptr,
-    a1_scales_ptr,
-    start_expert_id,
-    end_expert_id,
-    topk,
-    hidden_size,
-    BLOCK_SIZE: tl.constexpr,
-    use_per_token_if_dynamic: tl.constexpr,
-):
-    OutDtype = gateup_input_ptr.dtype.element_ty
-
     src_idx_int32 = tl.program_id(0)
     src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
@@ -244,15 +188,11 @@ def pre_reorder_triton_kernel(
 
     vec = tl.arange(0, BLOCK_SIZE)
 
-    if a1_scales_ptr is not None and use_per_token_if_dynamic:
-        scale = 1.0 / tl.load(a1_scales_ptr + src_idx)
-
     for idx in range(topk):
         expert_id = tl.load(topk_ids_ptr + idx)
-        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+        if expert_id != num_local_experts:
             if a1_scales_ptr is not None:
-                if not use_per_token_if_dynamic:
-                    scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
+                scale = 1.0 / tl.load(a1_scales_ptr)
             else:
                 scale = 1.0
 
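Taken together, the two hunks above fold the old range-based (start_expert_id/end_expert_id) variant into a single kernel that skips any assignment whose id equals num_local_experts. A loop-based sketch of the resulting scatter, assuming that sentinel convention (names here are illustrative):

import torch

def pre_reorder_reference(x, topk_ids, src2dst, num_local_experts, a1_scale=None):
    # x: [num_tokens, hidden]; topk_ids, src2dst: [num_tokens, topk]
    out = torch.zeros(int(src2dst.max()) + 1, x.size(1), dtype=x.dtype)
    scale = 1.0 / a1_scale if a1_scale is not None else 1.0
    for src in range(topk_ids.size(0)):
        for k in range(topk_ids.size(1)):
            if int(topk_ids[src, k]) != num_local_experts:  # sentinel: not owned here
                out[int(src2dst[src, k])] = (x[src].float() * scale).to(x.dtype)
    return out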
@@ -267,52 +207,6 @@ def pre_reorder_triton_kernel(
                 tl.store(dst_ptr + offset, out_data, mask=mask)
 
 
-@triton.jit
-def silu_and_mul_triton_kernel(
-    gateup_output,
-    down_input,
-    hidden_size,
-    reorder_topk_ids,
-    scales,
-    start_expert_id,
-    end_expert_id,
-    BLOCK_SIZE: tl.constexpr,
-):
-    InDtype = gateup_output.dtype.element_ty
-    OutDtype = down_input.dtype.element_ty
-
-    half_hidden_size = hidden_size // 2
-
-    pid = tl.program_id(0)
-    expert_id = tl.load(reorder_topk_ids + pid)
-    if expert_id >= start_expert_id and expert_id <= end_expert_id:
-        gateup_output_ptr = gateup_output + pid * hidden_size
-        gate_output_ptr = gateup_output_ptr
-        up_output_ptr = gateup_output_ptr + half_hidden_size
-        down_input_ptr = down_input + pid * half_hidden_size
-
-        if scales is not None:
-            scale = tl.load(scales + expert_id - start_expert_id)
-            scale = (1 / scale).to(InDtype)
-        else:
-            scale = 1
-
-        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
-            offset = start_offset + tl.arange(0, BLOCK_SIZE)
-            mask = offset < half_hidden_size
-
-            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
-            up_output = tl.load(up_output_ptr + offset, mask=mask)
-
-            # silu & mul & quantize
-            gate_output = gate_output * tl.sigmoid(gate_output)
-            gate_output = gate_output.to(InDtype)
-
-            silu_mul_output = gate_output * up_output * scale
-            silu_mul_output = silu_mul_output.to(OutDtype)
-            tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
-
-
 # copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
 @triton.jit
 def _silu_and_mul_post_quant_kernel(
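The deleted silu_and_mul_triton_kernel fused SiLU(gate) * up with optional per-expert dequantization; the retained _silu_and_mul_post_quant_kernel below presumably covers the quantized path. In plain PyTorch the activation step itself reduces to (a sketch, ignoring scales):

import torch
import torch.nn.functional as F

gateup = torch.randn(4, 8)          # [tokens, 2 * half_hidden]
gate, up = gateup.chunk(2, dim=-1)  # first half gates the second half
down_input = F.silu(gate) * up      # silu(x) = x * sigmoid(x)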
@@ -461,84 +355,15 @@ def silu_and_mul_masked_post_quant_fwd(
 
 
 @triton.jit
-def tanh(x):
-    return 2 * tl.sigmoid(2 * x) - 1
-
-
-@triton.jit
-def gelu_and_mul_triton_kernel(
-    gateup_output,
-    down_input,
-    hidden_size,
-    reorder_topk_ids,
-    scales,
-    start_expert_id,
-    end_expert_id,
-    BLOCK_SIZE: tl.constexpr,
-):
-    InDtype = gateup_output.dtype.element_ty
-    OutDtype = down_input.dtype.element_ty
-
-    half_hidden_size = hidden_size // 2
-
-    pid = tl.program_id(0)
-    expert_id = tl.load(reorder_topk_ids + pid)
-    if expert_id >= start_expert_id and expert_id <= end_expert_id:
-        gateup_output_ptr = gateup_output + pid * hidden_size
-        gate_output_ptr = gateup_output_ptr
-        up_output_ptr = gateup_output_ptr + half_hidden_size
-        down_input_ptr = down_input + pid * half_hidden_size
-
-        if scales is not None:
-            scale = tl.load(scales + expert_id - start_expert_id)
-            scale = (1 / scale).to(InDtype)
-        else:
-            scale = 1
-
-        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
-            offset = start_offset + tl.arange(0, BLOCK_SIZE)
-            mask = offset < half_hidden_size
-
-            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
-            up_output = tl.load(up_output_ptr + offset, mask=mask)
-
-            # gelu & mul & quantize
-            # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
-            # sqrt(2/pi)
-            kAlpha = 0.7978845608028654
-            gate_output = (
-                0.5
-                * gate_output
-                * (
-                    1
-                    + tanh(
-                        kAlpha
-                        * (
-                            gate_output
-                            + 0.044715 * gate_output * gate_output * gate_output
-                        )
-                    )
-                )
-            )
-            gate_output = gate_output.to(InDtype)
-
-            gelu_mul_output = gate_output * up_output * scale
-            gelu_mul_output = gelu_mul_output.to(OutDtype)
-            tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
-
-
-@triton.jit
-def post_reorder_triton_kernel(
+def post_reorder_triton_kernel_for_cutlass_moe(
     down_output_ptr,
     output_ptr,
     src2dst_ptr,
     topk_ids_ptr,
     topk_weights_ptr,
-    start_expert_id,
-    end_expert_id,
     topk,
+    num_local_experts,
     hidden_size,
-    dst_start,
     BLOCK_SIZE: tl.constexpr,
 ):
     InDtype = down_output_ptr.dtype.element_ty
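The deleted gelu_and_mul_triton_kernel (and its tanh helper) implemented the tanh approximation of GELU; the constant 0.7978845608028654 is sqrt(2/pi). A quick sanity check against PyTorch's built-in approximate GELU:

import torch

x = torch.randn(8)
k_alpha = 0.7978845608028654  # sqrt(2 / pi)
manual = 0.5 * x * (1 + torch.tanh(k_alpha * (x + 0.044715 * x**3)))
assert torch.allclose(manual, torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-5)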
@@ -549,7 +374,6 @@ def post_reorder_triton_kernel(
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     topk_weights_ptr = topk_weights_ptr + src_idx * topk
 
-    computed = False
     store_ptr = output_ptr + src_idx * hidden_size
 
     vec = tl.arange(0, BLOCK_SIZE)
@@ -561,37 +385,25 @@ def post_reorder_triton_kernel(
         sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
         for idx in range(topk):
             expert_id = tl.load(topk_ids_ptr + idx)
-            if expert_id >= start_expert_id and expert_id <= end_expert_id:
-                computed = True
+            if expert_id != num_local_experts:
                 dst_idx_int32 = tl.load(src2dst_ptr + idx)
                 dst_idx = dst_idx_int32.to(tl.int64)
-                dst_idx = dst_idx - dst_start
                 weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                 load_ptr = down_output_ptr + dst_idx * hidden_size
                 in_data = tl.load(load_ptr + offset, mask=mask)
                 sum_vec += in_data * weigh_scale
         tl.store(store_ptr + offset, sum_vec, mask=mask)
 
-    if computed == False:
-        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
-            offset = start_offset + vec
-            mask = offset < hidden_size
-            tl.store(
-                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
-            )
-
 
 @triton.jit
-def post_reorder_triton_kernel_for_cutlass_moe(
+def post_reorder_triton_kernel(
     down_output_ptr,
     output_ptr,
     src2dst_ptr,
     topk_ids_ptr,
     topk_weights_ptr,
-    num_experts,
     topk,
     hidden_size,
-    dst_start,
     BLOCK_SIZE: tl.constexpr,
 ):
     InDtype = down_output_ptr.dtype.element_ty
@@ -613,10 +425,9 @@ def post_reorder_triton_kernel_for_cutlass_moe(
         sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
         for idx in range(topk):
             expert_id = tl.load(topk_ids_ptr + idx)
-            if expert_id != num_experts:
+            if expert_id > 0:
                 dst_idx_int32 = tl.load(src2dst_ptr + idx)
                 dst_idx = dst_idx_int32.to(tl.int64)
-                dst_idx = dst_idx - dst_start
                 weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                 load_ptr = down_output_ptr + dst_idx * hidden_size
                 in_data = tl.load(load_ptr + offset, mask=mask)
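Both post-reorder kernels share the same gather pattern: for every token, sum the expert outputs selected by topk_ids, weighted by topk_weights, skipping assignments this rank does not own (the two kernels use different sentinel checks for that). A loop-based sketch with the ownership test abstracted out (names are illustrative):

import torch

def post_reorder_reference(down_output, topk_ids, topk_weights, src2dst, is_local):
    # down_output: [num_slots, hidden]; the other tensors: [num_tokens, topk]
    out = torch.zeros(topk_ids.size(0), down_output.size(1), dtype=down_output.dtype)
    for src in range(topk_ids.size(0)):
        for k in range(topk_ids.size(1)):
            if is_local(int(topk_ids[src, k])):
                w = topk_weights[src, k].to(down_output.dtype)
                out[src] += w * down_output[int(src2dst[src, k])]
    return out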
@@ -624,232 +435,6 @@ def post_reorder_triton_kernel_for_cutlass_moe(
         tl.store(store_ptr + offset, sum_vec, mask=mask)
 
 
-@triton.jit
-def compute_m_range(
-    pid,
-    batch_size,
-    seg_indptr,
-    weight_indices,
-    m_num_tiles_indptr,
-    BLOCK_SIZE_M: tl.constexpr,
-):
-    idx = 0
-    for bs in range(batch_size):
-        tiles = tl.load(m_num_tiles_indptr + bs)
-        if pid >= tiles:
-            idx = bs
-
-    idx_start = tl.load(m_num_tiles_indptr + idx)
-
-    m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
-    m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
-    expert_id = tl.load(weight_indices + idx)
-    return m_range_start, m_range_end, expert_id
-
-
-@triton.jit
-def grouped_gemm_triton_kernel(
-    a,
-    b,
-    c,
-    batch_size,
-    N,
-    K,
-    seg_indptr,
-    weight_indices,
-    m_num_tiles_indptr,
-    scale_a,
-    scale_b,
-    use_fp8_w8a8: tl.constexpr,
-    group_n: tl.constexpr,
-    group_k: tl.constexpr,
-    a_stride_0: tl.constexpr,
-    b_stride_0: tl.constexpr,
-    b_stride_1: tl.constexpr,
-    as_stride_0: tl.constexpr,
-    as_stride_1: tl.constexpr,
-    bs_stride_0: tl.constexpr,
-    bs_stride_2: tl.constexpr,
-    bs_stride_1: tl.constexpr,
-    use_per_token_if_dynamic: tl.constexpr,
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
-):
-    c_dtype = c.dtype.element_ty
-
-    pid_m = tl.program_id(0)
-    pid_n = tl.program_id(1)
-    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
-    if pid_m >= total_m_block:
-        return
-
-    m_range_start, m_range_end, expert_id = compute_m_range(
-        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
-    )
-    if m_range_end - m_range_start == 0:
-        return
-
-    n_range_start = pid_n * BLOCK_SIZE_N
-    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)
-
-    offs_am = tl.arange(0, BLOCK_SIZE_M)
-    offs_bn = tl.arange(0, BLOCK_SIZE_N)
-
-    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
-    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
-    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
-    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
-
-    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
-    b_ptr = b + (
-        (expert_id * b_stride_0)
-        + (n_range_start + offs_bn[:, None]) * b_stride_1
-        + offs_k[None, :]
-    )
-
-    if group_k > 0 and group_n > 0:
-        a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
-        offs_bsn = (n_range_start + offs_bn) // group_n
-        b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
-
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        a_tile = tl.load(
-            a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
-        )
-        b_tile = tl.load(
-            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
-        )
-
-        if group_k > 0 and group_n > 0:
-            k_start = k * BLOCK_SIZE_K
-            offs_ks = k_start // group_k
-            a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
-            b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
-            accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
-        else:
-            accumulator = tl.dot(a_tile, b_tile.T, accumulator)
-        a_ptr += BLOCK_SIZE_K
-        b_ptr += BLOCK_SIZE_K
-
-    if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
-        if use_per_token_if_dynamic:
-            scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None]))
-        else:
-            scale_a_value = tl.load(scale_a + expert_id)
-        scale_b_value = tl.load(scale_b + expert_id)
-        accumulator *= scale_a_value * scale_b_value
-
-    c_tile = accumulator.to(c_dtype)
-
-    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
-    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
-    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
-    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
-    tl.store(c_ptr, c_tile, mask=c_mask)
-
-
-@triton.jit
-def compute_m_num_tiles_indptr(
-    m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
-):
-    for bs in range(batch_size):
-        m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
-        cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
-        pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
-        tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
-
-
-def grouped_gemm_triton(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    c: torch.Tensor,
-    batch_size: int,
-    weight_column_major: bool,
-    seg_indptr: Optional[torch.Tensor] = None,
-    weight_indices: Optional[torch.Tensor] = None,
-    use_fp8_w8a8: bool = False,
-    scale_a: torch.Tensor = None,
-    scale_b: torch.Tensor = None,
-    block_shape: Optional[List[int]] = None,
-    c_dtype=None,
-    use_per_token_if_dynamic: bool = True,
-):
-    assert weight_column_major == True  # TODO: more
-    if use_fp8_w8a8 and block_shape is None:
-        assert scale_a is not None and scale_b is not None
-
-    if block_shape is not None:
-        a_original = a
-
-        assert len(block_shape) == 2
-        block_n, block_k = block_shape[0], block_shape[1]
-        a, scale_a = per_token_group_quant_fp8(a, block_k)
-
-        assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
-        assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
-        assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
-
-        dispose_tensor(a_original)
-
-    # TODO: adjust config or tune kernel
-    # Reduce block size to prevent L40 shared memory overflow.
-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 32,
-        "BLOCK_SIZE_K": 128,
-    }
-
-    m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
-    compute_m_num_tiles_indptr[(1,)](
-        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
-    )
-
-    if c is None:
-        assert c_dtype is not None
-        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
-
-    grid = lambda META: (
-        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
-        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
-    )
-
-    if use_fp8_w8a8 and block_shape is None and use_per_token_if_dynamic:
-        assert (
-            scale_a.shape[0] == a.shape[0]
-        ), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}"
-
-    grouped_gemm_triton_kernel[grid](
-        a,
-        b,
-        c,
-        batch_size,
-        b.size(1),
-        b.size(2),
-        seg_indptr,
-        weight_indices,
-        m_num_tiles_indptr,
-        scale_a,
-        scale_b,
-        use_fp8_w8a8,
-        0 if block_shape is None else block_shape[0],
-        0 if block_shape is None else block_shape[1],
-        a.stride(0),
-        b.stride(0),
-        b.stride(1),
-        scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
-        scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
-        scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
-        scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
-        scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
-        use_per_token_if_dynamic,
-        **config,
-    )
-    return c
-
-
 @triton.jit
 def _fwd_kernel_ep_scatter_1(
     num_recv_tokens_per_expert,
@@ -1234,7 +819,7 @@ def deepgemm_compute_src2dst_triton_kernel(
     mask = dst_id < num_toks
     src_id = tl.load(reorder_ids + dst_id, mask=mask)
     expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
-    expert_dst_start = tl.load(seg_indptr + expert_id)
+    expert_dst_start = tl.load(seg_indptr + expert_id, mask=(expert_id >= 0))
     expert_dst_offset = dst_id - expert_dst_start
     dst_id = expert_id * m_max + expert_dst_offset
     tl.store(src2dst + src_id, dst_id, mask=mask)
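For reference, the destination index this kernel computes places each assignment at expert_id * m_max plus its rank within the expert's sorted segment; the new mask=(expert_id >= 0) guards the seg_indptr load against negative (non-local) ids. A loop-based sketch (the helper is illustrative):

import torch

def deepgemm_src2dst_reference(topk_ids: torch.Tensor, m_max: int) -> torch.Tensor:
    flat = topk_ids.view(-1)
    order = flat.sort(stable=True).indices  # position after sort -> original slot
    src2dst = torch.full_like(flat, -1, dtype=torch.int32)
    seg_start = {}
    for dst, src in enumerate(order.tolist()):
        e = int(flat[src])
        if e < 0:
            continue  # assignment owned by another rank
        seg_start.setdefault(e, dst)  # first sorted position of expert e
        src2dst[src] = e * m_max + (dst - seg_start[e])
    return src2dst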
@@ -1248,10 +833,7 @@ def fill_gateup_input_triton_kernel(
     gateup_input_scale_ptr,
     src2dst_ptr,
     topk_ids_ptr,
-    start_expert_id,
-    end_expert_id,
     topk,
-    m_max,
     hidden_size,
     scale_size,
     BLOCK_SIZE: tl.constexpr,
@@ -1267,10 +849,9 @@ def fill_gateup_input_triton_kernel(
     vec = tl.arange(0, BLOCK_SIZE)
     for idx in range(topk):
         expert_id = tl.load(topk_ids_ptr + idx)
-        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+        if expert_id >= 0:
             dst_idx_int32 = tl.load(src2dst_ptr + idx)
             dst_idx = dst_idx_int32.to(tl.int64)
-            dst_idx = dst_idx - start_expert_id * m_max
             dst_ptr = gateup_input_ptr + dst_idx * hidden_size
             for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                 offset = start_offset + vec
@@ -1287,31 +868,31 @@ def fill_gateup_input_triton_kernel(
 
 def moe_ep_deepgemm_preprocess(
     topk_ids: torch.Tensor,
-    num_experts: int,
+    num_local_experts: int,
     hidden_states: torch.Tensor,
     top_k: int,
-    start_expert_id,
-    end_expert_id,
     block_shape,
     output_dtype: torch.dtype = torch.float8_e4m3fn,
 ):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
-    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    seg_indptr = torch.zeros(
+        num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
+    )
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
-    masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
+    masked_m = torch.empty(num_local_experts, device=topk_ids.device, dtype=torch.int32)
 
-    compute_seg_indptr_triton_kernel[(num_experts,)](
+    compute_seg_indptr_triton_kernel[(num_local_experts + 1,)](
         reorder_topk_ids, seg_indptr, topk_ids.numel()
     )
 
     grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
-    compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
+    compute_masked_m_triton_kernel[(num_local_experts,)](seg_indptr, masked_m)
 
     # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
-    m_max = (hidden_states.size(0) + 255) // 256 * 256
-    expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
+    m_max = (hidden_states.size(0) // 256 + 1) * 256
+    expected_m = (topk_ids.numel() - 1) // num_local_experts + 1
     gateup_input = torch.empty(
-        (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
+        (num_local_experts, m_max, hidden_states.size(1)),
         device=hidden_states.device,
         dtype=output_dtype,
     )
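The layout this produces for the masked grouped GEMM, sketched with made-up sizes: gateup_input is now indexed purely by local expert, and masked_m[e] tells DeepGEMM how many of the m_max padded rows of expert e are real.

import torch

num_local_experts, hidden, num_tokens = 4, 7168, 300
m_max = (num_tokens // 256 + 1) * 256  # 512: padded past num_tokens to a 256 multiple
gateup_input = torch.empty(
    num_local_experts, m_max, hidden, dtype=torch.float8_e4m3fn
)
masked_m = torch.tensor([37, 0, 201, 62], dtype=torch.int32)  # real rows per expert
# Expert e computes only on gateup_input[e, :masked_m[e], :]; rows beyond
# masked_m[e] are padding so every expert's M is block-aligned.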
@@ -1330,6 +911,8 @@ def moe_ep_deepgemm_preprocess(
         block_shape = [128, 128]
     assert len(block_shape) == 2
     block_n, block_k = block_shape[0], block_shape[1]
+
+    # TODO: fuse this with the preprocess
     hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
 
     gateup_input_scale = torch.empty(
@@ -1345,18 +928,14 @@ def moe_ep_deepgemm_preprocess(
         gateup_input_scale,
         src2dst,
         topk_ids,
-        start_expert_id,
-        end_expert_id,
         top_k,
-        m_max,
         hidden_states.size(1),
         scale.size(1),
         BLOCK_SIZE=1024,
     )
 
     return (
-        m_max,
-        masked_m[start_expert_id : (end_expert_id + 1)],
+        masked_m,
         expected_m,
         src2dst,
         gateup_input,