sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (84)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +6 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +23 -3
  11. sglang/srt/entrypoints/openai/protocol.py +3 -1
  12. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  13. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  14. sglang/srt/eplb/expert_distribution.py +5 -0
  15. sglang/srt/eplb/expert_location.py +17 -6
  16. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  17. sglang/srt/eplb/expert_location_updater.py +2 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/step3_detector.py +436 -0
  20. sglang/srt/hf_transformers_utils.py +2 -0
  21. sglang/srt/jinja_template_utils.py +4 -1
  22. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  23. sglang/srt/layers/moe/ep_moe/layer.py +98 -603
  24. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  29. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  30. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  31. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  32. sglang/srt/layers/moe/topk.py +6 -2
  33. sglang/srt/layers/quantization/fp8.py +0 -18
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -0
  35. sglang/srt/layers/quantization/unquant.py +0 -8
  36. sglang/srt/layers/quantization/w4afp8.py +1 -0
  37. sglang/srt/managers/cache_controller.py +143 -45
  38. sglang/srt/managers/data_parallel_controller.py +6 -0
  39. sglang/srt/managers/io_struct.py +12 -2
  40. sglang/srt/managers/scheduler.py +116 -669
  41. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  42. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  43. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  44. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  45. sglang/srt/managers/template_manager.py +62 -19
  46. sglang/srt/managers/tokenizer_manager.py +166 -83
  47. sglang/srt/managers/tp_worker.py +9 -0
  48. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  49. sglang/srt/mem_cache/hicache_storage.py +45 -11
  50. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  51. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  52. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  53. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  54. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  55. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  56. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  57. sglang/srt/model_executor/model_runner.py +20 -13
  58. sglang/srt/models/arcee.py +532 -0
  59. sglang/srt/models/deepseek_v2.py +15 -56
  60. sglang/srt/models/glm4_moe.py +3 -1
  61. sglang/srt/models/granitemoe.py +3 -0
  62. sglang/srt/models/grok.py +3 -0
  63. sglang/srt/models/hunyuan.py +1 -0
  64. sglang/srt/models/llama4.py +3 -0
  65. sglang/srt/models/mixtral.py +3 -0
  66. sglang/srt/models/olmoe.py +3 -0
  67. sglang/srt/models/phimoe.py +1 -0
  68. sglang/srt/models/qwen3_moe.py +12 -69
  69. sglang/srt/models/step3_vl.py +994 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/poll_based_barrier.py +31 -0
  73. sglang/srt/reasoning_parser.py +2 -1
  74. sglang/srt/server_args.py +18 -13
  75. sglang/srt/speculative/eagle_worker.py +2 -0
  76. sglang/srt/two_batch_overlap.py +8 -3
  77. sglang/test/test_utils.py +53 -0
  78. sglang/utils.py +0 -11
  79. sglang/version.py +1 -1
  80. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
  81. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
  82. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -1,7 +1,27 @@
+# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
+
+from __future__ import annotations
+
 import logging
 from dataclasses import dataclass
+from typing import (
+    TYPE_CHECKING,
+    List,
+    NamedTuple,
+    Optional,
+    Protocol,
+    Tuple,
+    Union,
+    runtime_checkable,
+)
 
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    BaseDispatcher,
+    BaseDispatcherConfig,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
@@ -24,7 +44,6 @@ except ImportError:
     use_deepep = False
 
 from enum import Enum, IntEnum, auto
-from typing import Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -41,6 +60,37 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 logger = logging.getLogger(__name__)
 
 
+class DeepEPNormalOutput(NamedTuple):
+    """DeepEP normal dispatch output."""
+
+    hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    num_recv_tokens_per_expert: List[int]
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_normal
+
+
+class DeepEPLLOutput(NamedTuple):
+    """DeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
+assert isinstance(DeepEPNormalOutput, DispatchOutput)
+assert isinstance(DeepEPLLOutput, DispatchOutput)
+
+
 class DeepEPDispatchMode(IntEnum):
     NORMAL = auto()
     LOW_LATENCY = auto()
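
Note: the two `assert isinstance(...)` checks above pass because `DispatchOutput` is a runtime-checkable Protocol, which only verifies that the checked object exposes the required attributes; both NamedTuple classes expose `format` as a property. A minimal sketch of what `base_dispatcher.py` plausibly defines, inferred from the imports and asserts in this hunk (variant names other than `deepep_normal`/`deepep_ll` and the abstract method set are assumptions):

    # Hedged sketch of sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py.
    # Only the four imported names are confirmed by this diff; bodies are illustrative.
    from abc import ABC, abstractmethod
    from enum import Enum, auto
    from typing import Protocol, runtime_checkable


    class DispatchOutputFormat(Enum):
        standard = auto()       # assumed, matching the new standard.py dispatcher
        deepep_normal = auto()
        deepep_ll = auto()


    @runtime_checkable
    class DispatchOutput(Protocol):
        """Anything exposing a `format` attribute qualifies at runtime."""

        @property
        def format(self) -> DispatchOutputFormat: ...


    class BaseDispatcherConfig(ABC):
        """Marker base class for dispatcher configs."""


    class BaseDispatcher(ABC):
        @abstractmethod
        def dispatch(self, *args, **kwargs) -> DispatchOutput: ...
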
@@ -107,6 +157,20 @@ class DeepEPBuffer:
         else:
             raise NotImplementedError
 
+        total_num_sms = torch.cuda.get_device_properties(
+            device="cuda"
+        ).multi_processor_count
+        if (
+            (deepep_mode != DeepEPMode.low_latency)
+            and not global_server_args_dict["enable_two_batch_overlap"]
+            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
+        ):
+            logger.warning(
+                f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
+                f"This may result in highly suboptimal performance. "
+                f"Consider using --deepep-config to change the behavior."
+            )
+
         cls._buffer = Buffer(
             group,
             num_nvl_bytes,
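
For reference, the threshold the new warning compares against can be probed directly; a small illustration (the device count shown is an example, not a guarantee):

    # Illustrative only: query the SM count used by the check above.
    import torch

    props = torch.cuda.get_device_properties("cuda")
    print(props.multi_processor_count)  # e.g. 132 SMs on an H100/H200 SXM part
    # The warning fires in non-low-latency mode (without two-batch overlap)
    # when DeepEPConfig.get_instance().num_sms < multi_processor_count // 2.
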
@@ -139,7 +203,7 @@ class DeepEPBuffer:
         cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
 
-class DeepEPConfig:
+class DeepEPConfig(BaseDispatcherConfig):
     _instance = None
 
     def __init__(self):
@@ -255,63 +319,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return hidden_states, topk_idx, topk_weights, previous_event
 
     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                num_recv_tokens_per_expert_list,
-                event,
-            ) = self._dispatch_core(
-                hidden_states, topk_idx, topk_weights, previous_event
-            )
-            event.current_stream_wait() if self.async_finish else ()
-            return (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                None,
-                num_recv_tokens_per_expert_list,
-                None,
-                None,
-                None,
-            )
-        else:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                num_recv_tokens_per_expert_list,
-                event,
-            ) = self._dispatch_core(
-                hidden_states, topk_idx, topk_weights, previous_event
-            )
-            event.current_stream_wait() if self.async_finish else ()
-            if hidden_states.shape[0] > 0:
-                reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
-                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
-                )
-            else:
-                reorder_topk_ids = torch.empty(
-                    (0,), device=hidden_states.device, dtype=torch.int64
-                )
-                seg_indptr = torch.zeros(
-                    (self.num_experts + 1,),
-                    device=hidden_states.device,
-                    dtype=torch.int64,
-                )
-
-            masked_m = expected_m = None
-            return (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                None,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            )
+        (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            num_recv_tokens_per_expert,
+            event,
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
+        event.current_stream_wait() if self.async_finish else ()
+        return DeepEPNormalOutput(
+            hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+        )
 
     def _dispatch_core(
         self,
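
Because `DeepEPNormalOutput` is a `NamedTuple`, the new return value is still a plain tuple to positional callers while also exposing named fields and the `format` tag. A hedged sketch with dummy tensors (the import path follows the TODO note at the top of this file):

    import torch
    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPNormalOutput

    out = DeepEPNormalOutput(
        hidden_states=torch.zeros(4, 8),
        topk_idx=torch.zeros(4, 2, dtype=torch.long),
        topk_weights=torch.ones(4, 2),
        num_recv_tokens_per_expert=[2, 2],
    )
    hs, idx, w, counts = out                 # positional unpacking still works
    assert out.hidden_states is hs           # ...as does named access
    assert out.format.name == "deepep_normal"
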
@@ -343,7 +361,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             self.handle,
             event,
         ) = buffer.dispatch(
@@ -362,7 +380,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )
 
         get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             num_tokens_per_rank=num_tokens_per_rank,
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             num_tokens_per_expert=num_tokens_per_expert,
@@ -372,58 +390,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             event,
         )
 
-    def _deepep_permute(
-        self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        fp8_dtype: Optional[torch.dtype] = None,
-        use_fp8_w8a8: bool = False,
-        use_block_quant: bool = False,
-    ):
-        """
-        Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
-        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
-        """
-        if _use_aiter:
-            # skip permutation here as aiter fused_moe has fused inside
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-            )
-            return reorder_topk_ids, seg_indptr, hidden_states
-
-        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            topk_idx, self.num_experts
-        )
-        num_total_tokens = reorder_topk_ids.numel()
-        gateup_input = torch.empty(
-            (int(num_total_tokens), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                fp8_dtype
-                if (use_fp8_w8a8 and not use_block_quant)
-                else hidden_states.dtype
-            ),
-        )
-        # PreReorder
-        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            self.src2dst,
-            topk_idx,
-            None,
-            self.router_topk,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-        )
-        return reorder_topk_ids, seg_indptr, gateup_input
-
     def combine_a(
         self,
         hidden_states: torch.Tensor,
@@ -544,15 +514,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-        reorder_topk_ids = seg_indptr = None
-
-        return (
+        return DeepEPLLOutput(
             hidden_states,
             topk_idx,
             topk_weights,
-            reorder_topk_ids,
-            None,
-            seg_indptr,
             masked_m,
             expected_m,
         )
@@ -636,7 +601,7 @@ class _Stage(Enum):
     AFTER_COMBINE_A = auto()
 
 
-class DeepEPDispatcher:
+class DeepEPDispatcher(BaseDispatcher):
     def __init__(
         self,
         group: torch.distributed.ProcessGroup,
@@ -676,7 +641,7 @@ class DeepEPDispatcher:
 
         self._stage = _Stage.INITIAL
 
-    def dispatch(self, *args, **kwargs) -> Tuple:
+    def dispatch(self, *args, **kwargs) -> DispatchOutput:
         self.dispatch_a(*args, **kwargs)
         ret = self.dispatch_b()
         return ret
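
With `DeepEPDispatcher` now subclassing `BaseDispatcher` and `dispatch` returning a typed `DispatchOutput`, consumers can branch on the format tag instead of unpacking a positional 8-tuple. A hedged consumer sketch; the field sets match the NamedTuples above, but the surrounding function is illustrative, not the exact sglang expert-layer code:

    from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
        DispatchOutput,
        DispatchOutputFormat,
    )

    def run_experts(dispatch_output: DispatchOutput):
        # Branch on the typed format tag rather than on tuple arity.
        if dispatch_output.format == DispatchOutputFormat.deepep_normal:
            hidden_states, topk_idx, topk_weights, num_recv = dispatch_output
            ...  # contiguous (grouped-GEMM) expert path
        elif dispatch_output.format == DispatchOutputFormat.deepep_ll:
            hidden_states_fp8, topk_idx, topk_weights, masked_m, expected_m = (
                dispatch_output
            )
            ...  # masked low-latency expert path
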
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    }
+}
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
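
Both JSON files map a token count M (the batch dimension of the MoE GEMM) to a tuned Triton kernel configuration for E=160 experts on an H200 with fp8_w8a8 quantization. A hedged sketch of the usual lookup pattern, which picks the entry with the numerically nearest key when M falls between tuning points (the helper name is illustrative, not the exact sglang API):

    import json

    def pick_moe_config(table: dict, M: int) -> dict:
        """Pick the tuned kernel config whose integer key is closest to M."""
        return table[min(table, key=lambda k: abs(int(k) - M))]

    with open("E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json") as f:
        configs = json.load(f)

    print(pick_moe_config(configs, M=300))
    # nearest key is "256" -> BLOCK_SIZE_M=16, BLOCK_SIZE_N=128, ...
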
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -413,18 +413,37 @@ def fused_moe_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens
 
-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (
         offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
     )
 
-    off_experts = tl.load(expert_ids_ptr + pid_m)
     b_ptrs = (
         b_ptr
         + off_experts * stride_be
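
`write_zeros_to_output` has to zero the output tile owned by this program id, so tokens routed to an expert that was filtered off the local EP rank (expert id -1) produce zeros instead of stale memory. A sketch consistent with the call site above; the real helper lives elsewhere in fused_moe.py, and this body is an assumption matching the arguments shown:

    import triton
    import triton.language as tl

    @triton.jit
    def write_zeros_to_output(
        c_ptr,
        stride_cm,
        stride_cn,
        pid_n,
        N,
        offs_token,
        token_mask,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        compute_type: tl.constexpr,
    ):
        # Zero-fill the (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile for this block,
        # masked to valid tokens and in-bounds output columns.
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
        c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
        tl.store(c_ptrs, accumulator.to(compute_type), mask=c_mask)
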
@@ -497,7 +516,6 @@
 
                     accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
                 else:
-                    # fix out of shared memory issue
                     if use_fp8_w8a8:
                         accumulator = tl.dot(a, b, acc=accumulator)
                     else:
@@ -568,7 +586,7 @@
     - The padding ensures that the total number of tokens is now divisible
       by block_size for proper block matrix operations.
     """
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
@@ -578,13 +596,9 @@
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
 
+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
     cumsum_buffer = torch.empty(
-        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-    )
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts,
-        dtype=torch.int32,
-        device=topk_ids.device,
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
     )
 
     # Threshold based on benchmark results
@@ -594,12 +608,11 @@
 
     sgl_moe_align_block_size(
         topk_ids,
-        num_experts,
+        num_experts + 1,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
-        token_cnts_buffer,
        cumsum_buffer,
        fuse_sorted_ids_padding,
    )
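
The bump from `num_experts` to `num_experts + 1` in the padding bound accounts for the extra `-1` bucket that EP-filtered tokens land in: each per-expert bucket is padded up to a multiple of `block_size`, so every bucket can waste at most `block_size - 1` slots. A quick numeric check (values chosen for illustration):

    # Worst-case size of sorted_ids after per-expert padding to block_size.
    block_size = 16
    num_experts = 8            # local experts on this EP rank
    num_topk_entries = 100     # topk_ids.numel()

    # Each of the num_experts buckets, plus the sentinel "-1" bucket for tokens
    # whose expert is not on this rank, may need up to block_size - 1 pad slots.
    max_num_tokens_padded = num_topk_entries + (num_experts + 1) * (block_size - 1)
    print(max_num_tokens_padded)  # 100 + 9 * 15 = 235
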