sglang 0.4.9.post5__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +3 -0
- sglang/srt/entrypoints/http_server.py +13 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/layers/moe/ep_moe/layer.py +152 -37
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +6 -2
- sglang/srt/layers/quantization/modelopt_quant.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +4 -0
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/scheduler.py +29 -0
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/tokenizer_manager.py +43 -9
- sglang/srt/managers/tp_worker.py +5 -0
- sglang/srt/model_executor/model_runner.py +15 -13
- sglang/srt/models/deepseek_v2.py +13 -56
- sglang/srt/models/qwen3_moe.py +12 -69
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/server_args.py +8 -0
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/test/test_utils.py +53 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +2 -1
- {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +32 -25
- {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py
@@ -261,6 +261,9 @@ class ModelConfig:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_text_config.hidden_size
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
+        self.num_nextn_predict_layers = getattr(
+            self.hf_text_config, "num_nextn_predict_layers", None
+        )
         self.vocab_size = self.hf_text_config.vocab_size

         # Verify quantization
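
The `getattr` fallback above is what lets checkpoints that do not define `num_nextn_predict_layers` (anything without DeepSeek-style multi-token-prediction layers) keep loading. A minimal runnable sketch of the pattern; `DummyTextConfig` is a hypothetical stand-in for `hf_text_config`, and only the attribute name comes from the diff:

    # Hypothetical stand-in for a HuggingFace text config without MTP layers.
    class DummyTextConfig:
        hidden_size = 4096

    cfg = DummyTextConfig()
    # getattr with a default returns None instead of raising AttributeError
    # when the checkpoint's config.json omits the key.
    num_nextn_predict_layers = getattr(cfg, "num_nextn_predict_layers", None)
    assert num_nextn_predict_layers is None
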
sglang/srt/entrypoints/http_server.py
@@ -38,7 +38,7 @@ import orjson
 import requests
 import uvicorn
 import uvloop
-from fastapi import Depends, FastAPI, Request, UploadFile
+from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
@@ -174,6 +174,18 @@ app.add_middleware(
 )


+@app.exception_handler(HTTPException)
+async def validation_exception_handler(request: Request, exc: HTTPException):
+    """Enrich HTTP exception with status code and other details"""
+    error = ErrorResponse(
+        object="error",
+        message=exc.detail,
+        type=str(exc.status_code),
+        code=exc.status_code,
+    )
+    return ORJSONResponse(content=error.model_dump(), status_code=exc.status_code)
+
+
 # Custom exception handlers to change validation error status codes
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
sglang/srt/entrypoints/openai/protocol.py
@@ -317,7 +317,9 @@ class ToolCall(BaseModel):

 class ChatCompletionMessageGenericParam(BaseModel):
     role: Literal["system", "assistant", "tool"]
-    content: Union[str, List[ChatCompletionMessageContentTextPart], None]
+    content: Union[str, List[ChatCompletionMessageContentTextPart], None] = Field(
+        default=None
+    )
     tool_call_id: Optional[str] = None
     name: Optional[str] = None
     reasoning_content: Optional[str] = None
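
Under Pydantic v2, an annotation like `Union[..., None]` with no default makes the field nullable but still required, so requests that omitted `content` entirely were rejected; `Field(default=None)` makes the key itself optional. A sketch of the difference with simplified models:

    from typing import Optional
    from pydantic import BaseModel, Field, ValidationError

    class Before(BaseModel):
        role: str
        content: Optional[str]  # nullable but *required*: the key must be present

    class After(BaseModel):
        role: str
        content: Optional[str] = Field(default=None)  # key may be omitted entirely

    try:
        Before(role="assistant")          # missing `content` -> rejected
        raise AssertionError("unreachable")
    except ValidationError:
        pass

    assert After(role="assistant").content is None  # accepted with default
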
sglang/srt/entrypoints/openai/serving_base.py
@@ -4,7 +4,7 @@ import uuid
 from abc import ABC, abstractmethod
 from typing import Any, Optional, Union

-from fastapi import Request
+from fastapi import HTTPException, Request
 from fastapi.responses import ORJSONResponse, StreamingResponse

 from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
@@ -45,7 +45,10 @@ class OpenAIServingBase(ABC):
             return await self._handle_non_streaming_request(
                 adapted_request, processed_request, raw_request
             )
-
+        except HTTPException as e:
+            return self.create_error_response(
+                message=e.detail, err_type=str(e.status_code), status_code=e.status_code
+            )
         except Exception as e:
             logger.exception(f"Error in request: {e}")
             return self.create_error_response(
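
Note the ordering in the hunk above: the `except HTTPException` branch must precede the blanket `except Exception`, otherwise the original status code would be flattened into a generic error. A compressed sketch of that control flow (`handle` is hypothetical):

    from fastapi import HTTPException

    def handle():
        try:
            raise HTTPException(status_code=400, detail="bad prompt")
        except HTTPException as e:
            # Specific branch first: the original status code survives.
            return {"error": e.detail, "code": e.status_code}
        except Exception as e:
            # Blanket fallback: anything else becomes a generic error.
            return {"error": str(e), "code": 500}

    assert handle() == {"error": "bad prompt", "code": 400}
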
sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 import torch

@@ -50,6 +52,13 @@ from sglang.srt.utils import (
     next_power_of_2,
 )

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+        DeepEPLLOutput,
+        DeepEPNormalOutput,
+        DispatchOutput,
+    )
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
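
The `if TYPE_CHECKING:` block only runs under static type checkers, so layer.py can annotate against token_dispatcher types without importing that module at runtime and creating an import cycle. Combined with `from __future__ import annotations`, the annotations are never evaluated at runtime at all. A runnable sketch; `nonexistent_module` deliberately does not exist:

    from __future__ import annotations  # annotations become lazily-evaluated strings

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by mypy/pyright; never executed at runtime, so this
        # script runs even though nonexistent_module does not exist.
        from nonexistent_module import DispatchOutput

    def moe_impl(dispatch_output: DispatchOutput) -> None:
        ...

    moe_impl(object())  # fine: the annotation is never evaluated at runtime
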
@@ -791,11 +800,24 @@ class DeepEPMoE(EPMoE):
             routed_scaling_factor=routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
-
-
-
-
-
+
+        # TODO: move to the beginning of the file
+        from sglang.srt.distributed.parallel_state import get_tp_group
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
+            async_finish=True,  # TODO
+            return_recv_hook=True,
+        )

         if self.deepep_mode.enable_low_latency():
             assert (
@@ -837,37 +859,128 @@ class DeepEPMoE(EPMoE):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
-        masked_m: torch.Tensor,
-        expected_m: int,
-        num_recv_tokens_per_expert: List[int],
         forward_batch: ForwardBatch,
     ):
+        dispatch_output = self.dispatch(
+            hidden_states, topk_idx, topk_weights, forward_batch
+        )
+        hidden_states = self.moe_impl(dispatch_output)
+        hidden_states = self.combine(
+            hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+            forward_batch,
+        )
+        return hidden_states
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return self.deepep_dispatcher.dispatch(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def moe_impl(self, dispatch_output: DispatchOutput):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
-            return self.forward_aiter(
-
-
-
-        if resolved_deepep_mode == DeepEPMode.normal:
-            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-                return self.forward_deepgemm_contiguous(
-                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
-                )
+            return self.forward_aiter(dispatch_output)
+        if dispatch_output.format.is_deepep_normal():
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+                return self.forward_deepgemm_contiguous(dispatch_output)
             else:
-                return self.forward_normal(
-        elif
-            return self.forward_deepgemm_masked(
+                return self.forward_normal(dispatch_output)
+        elif dispatch_output.format.is_deepep_ll():
+            return self.forward_deepgemm_masked(dispatch_output)
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

-    def
+    def combine(
         self,
         hidden_states: torch.Tensor,
-
-
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
     ):
+        return self.deepep_dispatcher.combine(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def _prepare_for_normal(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+    ):
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_permute_triton_kernel,
+            deepep_run_moe_deep_preprocess,
+        )
+
+        if hidden_states.shape[0] == 0:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,),
+                device=hidden_states.device,
+                dtype=torch.int64,
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
+        else:
+            if _use_aiter:
+                # skip permutation here as aiter fused_moe has fused inside
+                reorder_topk_ids = torch.empty(
+                    (0,), device=hidden_states.device, dtype=torch.int64
+                )
+                seg_indptr = torch.zeros(
+                    (self.num_experts + 1,),
+                    device=hidden_states.device,
+                    dtype=torch.int64,
+                )
+                return reorder_topk_ids, seg_indptr, hidden_states
+
+            reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+                topk_idx, self.num_experts
+            )
+            num_total_tokens = reorder_topk_ids.numel()
+            gateup_input = torch.empty(
+                (int(num_total_tokens), hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            # PreReorder
+            deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+                hidden_states,
+                gateup_input,
+                self.src2dst,
+                topk_idx,
+                None,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+            return reorder_topk_ids, seg_indptr, gateup_input
+
+    def forward_normal(
+        self,
+        dispatch_output: DeepEPNormalOutput,
+    ):
+        hidden_states, topk_idx = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+        )
+        reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
+            hidden_states, topk_idx
+        )
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device

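
Structurally, the hunk above turns `forward` from a method that threads eight loose arguments into a three-stage pipeline: `dispatch` returns a typed output, `moe_impl` branches on `dispatch_output.format` to pick a kernel, and `combine` gathers the results back. A stubbed sketch of that control flow (no real kernels; the names mirror the diff but the bodies are stand-ins):

    from typing import List, NamedTuple

    class FakeDispatchOutput(NamedTuple):
        hidden_states: List[float]
        topk_idx: List[int]
        topk_weights: List[float]

    class PipelineSketch:
        def dispatch(self, hidden_states, topk_idx, topk_weights):
            # Stand-in for the all-to-all dispatch; returns a typed output.
            return FakeDispatchOutput(hidden_states, topk_idx, topk_weights)

        def moe_impl(self, dispatch_output: FakeDispatchOutput):
            # Stand-in for the expert GEMMs selected via dispatch_output.format.
            return dispatch_output.hidden_states

        def combine(self, hidden_states, topk_idx, topk_weights):
            # Stand-in for the all-to-all combine.
            return hidden_states

        def forward(self, hidden_states, topk_idx, topk_weights):
            out = self.dispatch(hidden_states, topk_idx, topk_weights)
            h = self.moe_impl(out)
            return self.combine(h, out.topk_idx, out.topk_weights)

    assert PipelineSketch().forward([1.0], [0], [1.0]) == [1.0]
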
@@ -983,10 +1096,13 @@ class DeepEPMoE(EPMoE):

     def forward_aiter(
         self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states, topk_idx, topk_weights = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+        )
         if hidden_states.shape[0] == 0:
             return hidden_states
         # in original deepep, idx == -1 meaning invalid and will not be processed.
@@ -1014,11 +1130,11 @@ class DeepEPMoE(EPMoE):

     def forward_deepgemm_contiguous(
         self,
-
-        topk_idx,
-        topk_weights,
-        num_recv_tokens_per_expert: List[int],
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+            dispatch_output
+        )
         hidden_states_fp8, hidden_states_scale = hidden_states_fp8
         assert self.quant_method is not None
         assert self.activation == "silu"
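
The assignment `... = (dispatch_output)` above works because a `NamedTuple` is still a tuple and unpacks positionally, in field order. A minimal sketch:

    from typing import List, NamedTuple, Tuple

    class Output(NamedTuple):
        hidden_states_fp8: Tuple[int, int]
        topk_idx: List[int]
        topk_weights: List[float]
        num_recv_tokens_per_expert: List[int]

    out = Output((1, 2), [0], [1.0], [3])
    # Field order matters: this is plain tuple unpacking.
    hidden_states_fp8, topk_idx, topk_weights, num_recv = out
    assert hidden_states_fp8 == (1, 2) and num_recv == [3]
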
@@ -1138,10 +1254,9 @@ class DeepEPMoE(EPMoE):

     def forward_deepgemm_masked(
         self,
-
-        masked_m: torch.Tensor,
-        expected_m: int,
+        dispatch_output: DeepEPLLOutput,
     ):
+        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.activation == "silu"

@@ -1268,7 +1383,7 @@ class FlashInferEPMoE(EPMoE):
             topk_group=self.topk_group,
             intermediate_size=self.w2_weight.shape[2],
             local_expert_offset=self.start_expert_id,
-            local_num_experts=self.
+            local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.routed_scaling_factor,
             tile_tokens_dim=_get_tile_tokens_dim(
                 hidden_states.shape[0], self.top_k, self.num_experts
sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -1,7 +1,27 @@
+# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
+
+from __future__ import annotations
+
 import logging
 from dataclasses import dataclass
+from typing import (
+    TYPE_CHECKING,
+    List,
+    NamedTuple,
+    Optional,
+    Protocol,
+    Tuple,
+    Union,
+    runtime_checkable,
+)

 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    BaseDispatcher,
+    BaseDispatcherConfig,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
@@ -24,7 +44,6 @@ except ImportError:
     use_deepep = False

 from enum import Enum, IntEnum, auto
-from typing import Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -41,6 +60,37 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 logger = logging.getLogger(__name__)


+class DeepEPNormalOutput(NamedTuple):
+    """DeepEP normal dispatch output."""
+
+    hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    num_recv_tokens_per_expert: List[int]
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_normal
+
+
+class DeepEPLLOutput(NamedTuple):
+    """DeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
+assert isinstance(DeepEPNormalOutput, DispatchOutput)
+assert isinstance(DeepEPLLOutput, DispatchOutput)
+
+
 class DeepEPDispatchMode(IntEnum):
     NORMAL = auto()
     LOW_LATENCY = auto()
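
The two module-level `assert isinstance(..., DispatchOutput)` lines are a structural-typing check: assuming `DispatchOutput` is a `runtime_checkable` `Protocol` (it comes from the new base_dispatcher.py, which is not shown in this diff), `isinstance` only tests attribute presence, so even the class object of a NamedTuple that defines a `format` property passes. A self-contained sketch with stand-in names:

    from enum import Enum, auto
    from typing import NamedTuple, Protocol, runtime_checkable

    class Format(Enum):
        normal = auto()

    @runtime_checkable
    class HasFormat(Protocol):
        @property
        def format(self) -> "Format": ...

    class NormalOutput(NamedTuple):
        payload: int

        @property
        def format(self) -> Format:
            return Format.normal

    # runtime_checkable isinstance() only checks that the attribute exists,
    # so the *class object* passes too (the property is a class attribute) --
    # which is exactly what the module-level asserts in the diff rely on.
    assert isinstance(NormalOutput, HasFormat)
    assert isinstance(NormalOutput(1), HasFormat)
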
@@ -107,6 +157,20 @@ class DeepEPBuffer:
         else:
             raise NotImplementedError

+        total_num_sms = torch.cuda.get_device_properties(
+            device="cuda"
+        ).multi_processor_count
+        if (
+            (deepep_mode != DeepEPMode.low_latency)
+            and not global_server_args_dict["enable_two_batch_overlap"]
+            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
+        ):
+            logger.warning(
+                f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
+                f"This may result in highly suboptimal performance. "
+                f"Consider using --deepep-config to change the behavior."
+            )
+
         cls._buffer = Buffer(
             group,
             num_nvl_bytes,
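
The new guard compares the configured DeepEP SM budget against the GPU's total streaming-multiprocessor count and warns when less than half is available. A sketch of the check in isolation (requires a CUDA device; `configured_num_sms` stands in for `DeepEPConfig.get_instance().num_sms`):

    import torch

    def check_deepep_sm_budget(configured_num_sms: int) -> None:
        # Total SMs on the current CUDA device.
        total = torch.cuda.get_device_properties(device="cuda").multi_processor_count
        if configured_num_sms < total // 2:
            print(
                f"Only {configured_num_sms} of {total} SMs configured for DeepEP "
                f"communication; this may be highly suboptimal."
            )
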
@@ -139,7 +203,7 @@ class DeepEPBuffer:
         cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY


-class DeepEPConfig:
+class DeepEPConfig(BaseDispatcherConfig):
     _instance = None

     def __init__(self):
@@ -255,63 +319,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return hidden_states, topk_idx, topk_weights, previous_event

     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-
-
-
-
-
-
-
-
-
-
-
-            return (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                None,
-                num_recv_tokens_per_expert_list,
-                None,
-                None,
-                None,
-            )
-        else:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                num_recv_tokens_per_expert_list,
-                event,
-            ) = self._dispatch_core(
-                hidden_states, topk_idx, topk_weights, previous_event
-            )
-            event.current_stream_wait() if self.async_finish else ()
-            if hidden_states.shape[0] > 0:
-                reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
-                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
-                )
-            else:
-                reorder_topk_ids = torch.empty(
-                    (0,), device=hidden_states.device, dtype=torch.int64
-                )
-                seg_indptr = torch.zeros(
-                    (self.num_experts + 1,),
-                    device=hidden_states.device,
-                    dtype=torch.int64,
-                )
-
-        masked_m = expected_m = None
-        return (
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            reorder_topk_ids,
-            None,
-            seg_indptr,
-            masked_m,
-            expected_m,
-        )
+        (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            num_recv_tokens_per_expert,
+            event,
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
+        event.current_stream_wait() if self.async_finish else ()
+        return DeepEPNormalOutput(
+            hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+        )

     def combine_a(
         self,
@@ -343,7 +361,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-
+            num_recv_tokens_per_expert,
             self.handle,
             event,
         ) = buffer.dispatch(
@@ -362,7 +380,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )

         get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-
+            num_recv_tokens_per_expert,
             num_tokens_per_rank=num_tokens_per_rank,
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             num_tokens_per_expert=num_tokens_per_expert,
@@ -372,58 +390,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-
+            num_recv_tokens_per_expert,
             event,
         )

-    def _deepep_permute(
-        self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        fp8_dtype: Optional[torch.dtype] = None,
-        use_fp8_w8a8: bool = False,
-        use_block_quant: bool = False,
-    ):
-        """
-        Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
-        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
-        """
-        if _use_aiter:
-            # skip permutation here as aiter fused_moe has fused inside
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-            )
-            return reorder_topk_ids, seg_indptr, hidden_states
-
-        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            topk_idx, self.num_experts
-        )
-        num_total_tokens = reorder_topk_ids.numel()
-        gateup_input = torch.empty(
-            (int(num_total_tokens), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                fp8_dtype
-                if (use_fp8_w8a8 and not use_block_quant)
-                else hidden_states.dtype
-            ),
-        )
-        # PreReorder
-        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            self.src2dst,
-            topk_idx,
-            None,
-            self.router_topk,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-        )
-        return reorder_topk_ids, seg_indptr, gateup_input
-
     def combine_a(
         self,
         hidden_states: torch.Tensor,
@@ -544,15 +514,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )

-
-
-        return (
+        return DeepEPLLOutput(
             hidden_states,
             topk_idx,
             topk_weights,
-            reorder_topk_ids,
-            None,
-            seg_indptr,
             masked_m,
             expected_m,
         )
@@ -636,7 +601,7 @@ class _Stage(Enum):
     AFTER_COMBINE_A = auto()


-class DeepEPDispatcher:
+class DeepEPDispatcher(BaseDispatcher):
     def __init__(
         self,
         group: torch.distributed.ProcessGroup,
@@ -676,7 +641,7 @@ class DeepEPDispatcher:

         self._stage = _Stage.INITIAL

-    def dispatch(self, *args, **kwargs) ->
+    def dispatch(self, *args, **kwargs) -> DispatchOutput:
         self.dispatch_a(*args, **kwargs)
         ret = self.dispatch_b()
         return ret
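
`dispatch` is a thin wrapper that runs its two phases back to back; the `dispatch_a`/`dispatch_b` split exists so callers such as two-batch overlap can launch communication for one micro-batch and run compute for another before collecting the result. A toy sketch of the shape of that API (no real communication happens):

    class TwoPhaseDispatcher:
        def dispatch_a(self, payload):
            # Stand-in for launching the asynchronous all-to-all.
            self._inflight = payload

        def dispatch_b(self):
            # Stand-in for waiting on the recv hook and building DispatchOutput.
            return self._inflight

        def dispatch(self, payload):
            # The one-shot path simply chains the two phases.
            self.dispatch_a(payload)
            return self.dispatch_b()

    assert TwoPhaseDispatcher().dispatch("tokens") == "tokens"
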