sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +23 -3
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +98 -603
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +6 -2
- sglang/srt/layers/quantization/fp8.py +0 -18
- sglang/srt/layers/quantization/modelopt_quant.py +2 -0
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/managers/cache_controller.py +143 -45
- sglang/srt/managers/data_parallel_controller.py +6 -0
- sglang/srt/managers/io_struct.py +12 -2
- sglang/srt/managers/scheduler.py +116 -669
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +166 -83
- sglang/srt/managers/tp_worker.py +9 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +45 -11
- sglang/srt/mem_cache/hiradix_cache.py +15 -4
- sglang/srt/mem_cache/memory_pool_host.py +73 -1
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/model_runner.py +20 -13
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +15 -56
- sglang/srt/models/glm4_moe.py +3 -1
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/qwen3_moe.py +12 -69
- sglang/srt/models/step3_vl.py +994 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +18 -13
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/test/test_utils.py +53 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/token_dispatcher.py:

```diff
@@ -1,7 +1,27 @@
+# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
+
+from __future__ import annotations
+
 import logging
 from dataclasses import dataclass
+from typing import (
+    TYPE_CHECKING,
+    List,
+    NamedTuple,
+    Optional,
+    Protocol,
+    Tuple,
+    Union,
+    runtime_checkable,
+)
 
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    BaseDispatcher,
+    BaseDispatcherConfig,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
@@ -24,7 +44,6 @@ except ImportError:
     use_deepep = False
 
 from enum import Enum, IntEnum, auto
-from typing import Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -41,6 +60,37 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 logger = logging.getLogger(__name__)
 
 
+class DeepEPNormalOutput(NamedTuple):
+    """DeepEP normal dispatch output."""
+
+    hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    num_recv_tokens_per_expert: List[int]
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_normal
+
+
+class DeepEPLLOutput(NamedTuple):
+    """DeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
+assert isinstance(DeepEPNormalOutput, DispatchOutput)
+assert isinstance(DeepEPLLOutput, DispatchOutput)
+
+
 class DeepEPDispatchMode(IntEnum):
     NORMAL = auto()
     LOW_LATENCY = auto()
```
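The two `assert isinstance(...)` lines above lean on a `runtime_checkable` `Protocol`: at runtime, `isinstance` against such a protocol only checks attribute presence, and a property defined in a `NamedTuple` body is visible on the class object itself, so even the class passes the check. A minimal sketch of the pattern, assuming `DispatchOutput` and `DispatchOutputFormat` are declared in `base_dispatcher.py` roughly as below (this diff does not show that file, and `ExampleOutput` is a made-up stand-in):

```python
# Minimal sketch of the runtime-checkable protocol pattern used above.
from enum import Enum, auto
from typing import NamedTuple, Protocol, runtime_checkable


class DispatchOutputFormat(Enum):
    deepep_normal = auto()
    deepep_ll = auto()


@runtime_checkable
class DispatchOutput(Protocol):
    """Any object exposing a .format attribute counts as a DispatchOutput."""

    @property
    def format(self) -> DispatchOutputFormat: ...


class ExampleOutput(NamedTuple):
    payload: int

    @property
    def format(self) -> DispatchOutputFormat:
        return DispatchOutputFormat.deepep_normal


# isinstance() on a runtime_checkable Protocol only checks that the attribute
# exists; the property descriptor is an attribute of the class object too,
# so both the class and its instances pass -- exactly what the asserts in the
# diff rely on.
assert isinstance(ExampleOutput, DispatchOutput)
assert isinstance(ExampleOutput(payload=1), DispatchOutput)
```

This is presumably why the new `dispatch_b` implementations can return `DeepEPNormalOutput` / `DeepEPLLOutput` objects while callers switch on `output.format` instead of unpacking a fixed-width tuple.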
```diff
@@ -107,6 +157,20 @@ class DeepEPBuffer:
         else:
             raise NotImplementedError
 
+        total_num_sms = torch.cuda.get_device_properties(
+            device="cuda"
+        ).multi_processor_count
+        if (
+            (deepep_mode != DeepEPMode.low_latency)
+            and not global_server_args_dict["enable_two_batch_overlap"]
+            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
+        ):
+            logger.warning(
+                f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
+                f"This may result in highly suboptimal performance. "
+                f"Consider using --deepep-config to change the behavior."
+            )
+
         cls._buffer = Buffer(
             group,
             num_nvl_bytes,
@@ -139,7 +203,7 @@ class DeepEPBuffer:
         cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
 
-class DeepEPConfig:
+class DeepEPConfig(BaseDispatcherConfig):
     _instance = None
 
     def __init__(self):
@@ -255,63 +319,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return hidden_states, topk_idx, topk_weights, previous_event
 
     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-
-
-
-
-
-
-
-
-
-
-
-            return (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                None,
-                num_recv_tokens_per_expert_list,
-                None,
-                None,
-                None,
-            )
-        else:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                num_recv_tokens_per_expert_list,
-                event,
-            ) = self._dispatch_core(
-                hidden_states, topk_idx, topk_weights, previous_event
-            )
-            event.current_stream_wait() if self.async_finish else ()
-            if hidden_states.shape[0] > 0:
-                reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
-                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
-                )
-            else:
-                reorder_topk_ids = torch.empty(
-                    (0,), device=hidden_states.device, dtype=torch.int64
-                )
-                seg_indptr = torch.zeros(
-                    (self.num_experts + 1,),
-                    device=hidden_states.device,
-                    dtype=torch.int64,
-                )
-
-            masked_m = expected_m = None
-            return (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                None,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            )
+        (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            num_recv_tokens_per_expert,
+            event,
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
+        event.current_stream_wait() if self.async_finish else ()
+        return DeepEPNormalOutput(
+            hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+        )
 
     def _dispatch_core(
         self,
@@ -343,7 +361,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             self.handle,
             event,
         ) = buffer.dispatch(
@@ -362,7 +380,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )
 
         get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             num_tokens_per_rank=num_tokens_per_rank,
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             num_tokens_per_expert=num_tokens_per_expert,
@@ -372,58 +390,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             event,
         )
 
-    def _deepep_permute(
-        self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        fp8_dtype: Optional[torch.dtype] = None,
-        use_fp8_w8a8: bool = False,
-        use_block_quant: bool = False,
-    ):
-        """
-        Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
-        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
-        """
-        if _use_aiter:
-            # skip permutation here as aiter fused_moe has fused inside
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-            )
-            return reorder_topk_ids, seg_indptr, hidden_states
-
-        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            topk_idx, self.num_experts
-        )
-        num_total_tokens = reorder_topk_ids.numel()
-        gateup_input = torch.empty(
-            (int(num_total_tokens), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                fp8_dtype
-                if (use_fp8_w8a8 and not use_block_quant)
-                else hidden_states.dtype
-            ),
-        )
-        # PreReorder
-        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            self.src2dst,
-            topk_idx,
-            None,
-            self.router_topk,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-        )
-        return reorder_topk_ids, seg_indptr, gateup_input
-
     def combine_a(
         self,
         hidden_states: torch.Tensor,
@@ -544,15 +514,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-
-
-        return (
+        return DeepEPLLOutput(
             hidden_states,
             topk_idx,
             topk_weights,
-            reorder_topk_ids,
-            None,
-            seg_indptr,
             masked_m,
             expected_m,
         )
@@ -636,7 +601,7 @@ class _Stage(Enum):
     AFTER_COMBINE_A = auto()
 
 
-class DeepEPDispatcher:
+class DeepEPDispatcher(BaseDispatcher):
     def __init__(
         self,
         group: torch.distributed.ProcessGroup,
@@ -676,7 +641,7 @@ class DeepEPDispatcher:
 
         self._stage = _Stage.INITIAL
 
-    def dispatch(self, *args, **kwargs) ->
+    def dispatch(self, *args, **kwargs) -> DispatchOutput:
         self.dispatch_a(*args, **kwargs)
         ret = self.dispatch_b()
         return ret
```
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json (new file):

```diff
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    }
+}
```
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json (new file):

```diff
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
```
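Each key in these JSON tables is a token count M and each value is a Triton launch configuration for the fused MoE kernel at that batch-size bucket. A sketch of the usual lookup convention, assuming nearest-bucket selection as in the upstream fused-MoE tuning code (the exact rule lives in fused_moe.py, which this diff only partially shows; the table below is abbreviated and illustrative):

```python
# Hypothetical nearest-bucket lookup over a tuning table like the JSON above.
table = {
    "16": {"BLOCK_SIZE_M": 16, "num_warps": 4},
    "256": {"BLOCK_SIZE_M": 16, "num_warps": 4},
    "4096": {"BLOCK_SIZE_M": 128, "num_warps": 8},
}

def pick_config(table: dict, m: int) -> dict:
    # Keys are stringified token counts; pick the bucket closest to m.
    return table[min(table, key=lambda k: abs(int(k) - m))]

assert pick_config(table, 300) is table["256"]  # 300 is nearest to 256
```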
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py:

```diff
@@ -413,18 +413,37 @@ def fused_moe_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens
 
-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (
         offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
     )
 
-    off_experts = tl.load(expert_ids_ptr + pid_m)
     b_ptrs = (
         b_ptr
         + off_experts * stride_be
```
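The new `off_experts == -1` early exit supports expert parallelism: a block whose expert is not resident on the current rank still owns rows of the output buffer, so those rows must be zero-filled rather than skipped. A rough eager-mode analogue of that semantics, with made-up shapes (`write_zeros_to_output` itself is a Triton helper not shown in this hunk):

```python
# Eager-mode analogue of the kernel's early exit, with illustrative shapes.
# Rows routed to expert id -1 (an expert living on another EP rank) still
# need defined output, so they are zero-filled instead of multiplied.
import torch

hidden = torch.randn(6, 8)                        # 6 routed rows, hidden size 8
expert_of_row = torch.tensor([0, 1, -1, 0, -1, 1])
w = torch.randn(2, 8, 4)                          # 2 local experts, 8 -> 4

out = torch.zeros(6, 4)                           # zeros play write_zeros_to_output's role
local = expert_of_row >= 0
out[local] = torch.einsum("bk,bkn->bn", hidden[local], w[expert_of_row[local]])
# rows 2 and 4 stay zero, mirroring the kernel's behavior for -1 experts
```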
```diff
@@ -497,7 +516,6 @@ def fused_moe_kernel(
 
             accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
         else:
-            # fix out of shared memory issue
             if use_fp8_w8a8:
                 accumulator = tl.dot(a, b, acc=accumulator)
             else:
```
```diff
@@ -568,7 +586,7 @@ def moe_align_block_size(
     - The padding ensures that the total number of tokens is now divisible
       by block_size for proper block matrix operations.
     """
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
@@ -578,13 +596,9 @@ def moe_align_block_size(
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
 
+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
     cumsum_buffer = torch.empty(
-        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-    )
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts,
-        dtype=torch.int32,
-        device=topk_ids.device,
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
     )
 
     # Threshold based on benchmark results
@@ -594,12 +608,11 @@ def moe_align_block_size(
 
     sgl_moe_align_block_size(
         topk_ids,
-        num_experts,
+        num_experts + 1,
         block_size,
         sorted_ids,
         expert_ids,
         num_tokens_post_pad,
-        token_cnts_buffer,
         cumsum_buffer,
         fuse_sorted_ids_padding,
     )
```
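The bound change in `moe_align_block_size` follows directly from the comment added above: each expert's segment of `sorted_ids` is padded up to a multiple of `block_size`, wasting at most `block_size - 1` slots per segment, and with expert parallelism there is one extra segment for the `-1` (filtered) expert id. A quick back-of-envelope check with illustrative sizes (not taken from the diff):

```python
# Worst-case padded length for sorted_ids, using made-up example sizes.
num_tokens, top_k, num_experts, block_size = 7, 2, 4, 16

topk_ids_numel = num_tokens * top_k  # 14 (token, expert) routing pairs
# num_experts + 1 segments (the +1 is the -1 bucket for non-local experts),
# each padded up to a multiple of block_size => at most block_size - 1 waste.
max_num_tokens_padded = topk_ids_numel + (num_experts + 1) * (block_size - 1)
print(max_num_tokens_padded)  # 14 + 5 * 15 = 89
```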