sglang 0.4.9.post5__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl

Files changed (32)
  1. sglang/srt/configs/model_config.py +3 -0
  2. sglang/srt/entrypoints/http_server.py +13 -1
  3. sglang/srt/entrypoints/openai/protocol.py +3 -1
  4. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  5. sglang/srt/layers/moe/ep_moe/layer.py +152 -37
  6. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  7. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  8. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  9. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  10. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  11. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  12. sglang/srt/layers/moe/topk.py +6 -2
  13. sglang/srt/layers/quantization/modelopt_quant.py +2 -0
  14. sglang/srt/managers/data_parallel_controller.py +4 -0
  15. sglang/srt/managers/io_struct.py +12 -0
  16. sglang/srt/managers/scheduler.py +29 -0
  17. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  18. sglang/srt/managers/tokenizer_manager.py +43 -9
  19. sglang/srt/managers/tp_worker.py +5 -0
  20. sglang/srt/model_executor/model_runner.py +15 -13
  21. sglang/srt/models/deepseek_v2.py +13 -56
  22. sglang/srt/models/qwen3_moe.py +12 -69
  23. sglang/srt/poll_based_barrier.py +31 -0
  24. sglang/srt/server_args.py +8 -0
  25. sglang/srt/two_batch_overlap.py +8 -3
  26. sglang/test/test_utils.py +53 -0
  27. sglang/version.py +1 -1
  28. {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +2 -1
  29. {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +32 -25
  30. {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
  31. {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
  32. {sglang-0.4.9.post5.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py
@@ -261,6 +261,9 @@ class ModelConfig:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_text_config.hidden_size
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
+        self.num_nextn_predict_layers = getattr(
+            self.hf_text_config, "num_nextn_predict_layers", None
+        )
         self.vocab_size = self.hf_text_config.vocab_size

         # Verify quantization
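The new field mirrors the num_nextn_predict_layers entry that DeepSeek-style MTP checkpoints carry in their HF config (DeepSeek-V3 sets it to 1); for configs without the entry, the getattr fallback leaves the attribute as None. A minimal illustration of the fallback, with SimpleNamespace standing in for an HF config object:

    from types import SimpleNamespace

    cfg = SimpleNamespace(hidden_size=4096)  # hypothetical config without the MTP field
    assert getattr(cfg, "num_nextn_predict_layers", None) is None

    cfg_mtp = SimpleNamespace(hidden_size=4096, num_nextn_predict_layers=1)
    assert getattr(cfg_mtp, "num_nextn_predict_layers", None) == 1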
sglang/srt/entrypoints/http_server.py
@@ -38,7 +38,7 @@ import orjson
 import requests
 import uvicorn
 import uvloop
-from fastapi import Depends, FastAPI, Request, UploadFile
+from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
@@ -174,6 +174,18 @@ app.add_middleware(
 )


+@app.exception_handler(HTTPException)
+async def validation_exception_handler(request: Request, exc: HTTPException):
+    """Enrich HTTP exception with status code and other details"""
+    error = ErrorResponse(
+        object="error",
+        message=exc.detail,
+        type=str(exc.status_code),
+        code=exc.status_code,
+    )
+    return ORJSONResponse(content=error.model_dump(), status_code=exc.status_code)
+
+
 # Custom exception handlers to change validation error status codes
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
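The new handler converts any raised HTTPException into the same ErrorResponse JSON shape the OpenAI endpoints emit, instead of FastAPI's default {"detail": ...} body. A stand-alone sketch of the pattern (ErrorResponse here is a simplified stand-in for the class in sglang.srt.entrypoints.openai.protocol):

    from fastapi import FastAPI, HTTPException, Request
    from fastapi.responses import ORJSONResponse
    from pydantic import BaseModel

    app = FastAPI()

    class ErrorResponse(BaseModel):  # simplified stand-in for sglang's ErrorResponse
        object: str = "error"
        message: str
        type: str
        code: int

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, exc: HTTPException):
        # Wrap the exception in a structured body instead of the default {"detail": ...}
        error = ErrorResponse(
            message=exc.detail, type=str(exc.status_code), code=exc.status_code
        )
        return ORJSONResponse(content=error.model_dump(), status_code=exc.status_code)

    @app.get("/boom")
    async def boom():
        # GET /boom now returns 429 with {"object": "error", "message": "rate limited", ...}
        raise HTTPException(status_code=429, detail="rate limited")

Note the function name validation_exception_handler is reused by the RequestValidationError handler below it; FastAPI registers handlers at decoration time, so the later rebinding of the module-level name is harmless, if a little confusing.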
sglang/srt/entrypoints/openai/protocol.py
@@ -317,7 +317,9 @@ class ToolCall(BaseModel):

 class ChatCompletionMessageGenericParam(BaseModel):
     role: Literal["system", "assistant", "tool"]
-    content: Union[str, List[ChatCompletionMessageContentTextPart], None]
+    content: Union[str, List[ChatCompletionMessageContentTextPart], None] = Field(
+        default=None
+    )
     tool_call_id: Optional[str] = None
     name: Optional[str] = None
     reasoning_content: Optional[str] = None
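Under pydantic v2 semantics, an Optional annotation alone does not make a field optional: without a default, content had to be present in the payload (even as an explicit null). With Field(default=None), messages like {"role": "assistant"} now validate. A minimal repro with a simplified model:

    from typing import Optional
    from pydantic import BaseModel, Field

    class Msg(BaseModel):  # simplified stand-in for ChatCompletionMessageGenericParam
        role: str
        content: Optional[str] = Field(default=None)

    # Validates now that content has a default; without it, pydantic v2
    # raised a missing-field error for this payload.
    Msg.model_validate({"role": "assistant"})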
sglang/srt/entrypoints/openai/serving_base.py
@@ -4,7 +4,7 @@ import uuid
 from abc import ABC, abstractmethod
 from typing import Any, Optional, Union

-from fastapi import Request
+from fastapi import HTTPException, Request
 from fastapi.responses import ORJSONResponse, StreamingResponse

 from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
@@ -45,7 +45,10 @@ class OpenAIServingBase(ABC):
             return await self._handle_non_streaming_request(
                 adapted_request, processed_request, raw_request
             )
-
+        except HTTPException as e:
+            return self.create_error_response(
+                message=e.detail, err_type=str(e.status_code), status_code=e.status_code
+            )
         except Exception as e:
             logger.exception(f"Error in request: {e}")
             return self.create_error_response(
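With this clause, code anywhere below the serving layer can abort a request by raising HTTPException and still produce a structured OpenAI-style error rather than falling into the generic 500 path. A hypothetical helper illustrating the pattern (not part of this diff):

    from fastapi import HTTPException

    def require_served_model(requested: str, served: str) -> None:
        # Hypothetical validation helper; any HTTPException raised while a
        # request is being handled is now converted by handle_request into
        # create_error_response(...) with the matching status code.
        if requested != served:
            raise HTTPException(status_code=404, detail=f"model {requested!r} is not served")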
sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 import torch

@@ -50,6 +52,13 @@ from sglang.srt.utils import (
     next_power_of_2,
 )

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+        DeepEPLLOutput,
+        DeepEPNormalOutput,
+        DispatchOutput,
+    )
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
@@ -791,11 +800,24 @@ class DeepEPMoE(EPMoE):
             routed_scaling_factor=routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-            assert self.use_fp8_w8a8, (
-                "DeepGEMM requires an fp8_w8a8 model; "
-                "alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
-            )
+
+        # TODO: move to the beginning of the file
+        from sglang.srt.distributed.parallel_state import get_tp_group
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
+            async_finish=True,  # TODO
+            return_recv_hook=True,
+        )

         if self.deepep_mode.enable_low_latency():
             assert (
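DeepEPMoE now constructs its own MaybeTboDeepEPDispatcher instead of relying on each model to wire one up, and the fp8 assert moves from construction time into the runtime DeepGEMM branch (see moe_impl in the next hunk). This is consistent with the large deletions in sglang/srt/models/deepseek_v2.py (-56) and sglang/srt/models/qwen3_moe.py (-69) in the file list above, which presumably shed their per-model dispatcher plumbing.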
@@ -837,37 +859,128 @@ class DeepEPMoE(EPMoE):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
-        masked_m: torch.Tensor,
-        expected_m: int,
-        num_recv_tokens_per_expert: List[int],
         forward_batch: ForwardBatch,
     ):
+        dispatch_output = self.dispatch(
+            hidden_states, topk_idx, topk_weights, forward_batch
+        )
+        hidden_states = self.moe_impl(dispatch_output)
+        hidden_states = self.combine(
+            hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+            forward_batch,
+        )
+        return hidden_states
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return self.deepep_dispatcher.dispatch(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def moe_impl(self, dispatch_output: DispatchOutput):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
-            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(
-            forward_batch.is_extend_in_batch
-        )
-        if resolved_deepep_mode == DeepEPMode.normal:
-            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-                return self.forward_deepgemm_contiguous(
-                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
-                )
+            return self.forward_aiter(dispatch_output)
+        if dispatch_output.format.is_deepep_normal():
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+                return self.forward_deepgemm_contiguous(dispatch_output)
             else:
-                return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
-            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
+                return self.forward_normal(dispatch_output)
+        elif dispatch_output.format.is_deepep_ll():
+            return self.forward_deepgemm_masked(dispatch_output)
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

-    def forward_normal(
+    def combine(
         self,
         hidden_states: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
     ):
+        return self.deepep_dispatcher.combine(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def _prepare_for_normal(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+    ):
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_permute_triton_kernel,
+            deepep_run_moe_deep_preprocess,
+        )
+
+        if hidden_states.shape[0] == 0:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,),
+                device=hidden_states.device,
+                dtype=torch.int64,
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
+        else:
+            if _use_aiter:
+                # skip permutation here as aiter fused_moe has fused inside
+                reorder_topk_ids = torch.empty(
+                    (0,), device=hidden_states.device, dtype=torch.int64
+                )
+                seg_indptr = torch.zeros(
+                    (self.num_experts + 1,),
+                    device=hidden_states.device,
+                    dtype=torch.int64,
+                )
+                return reorder_topk_ids, seg_indptr, hidden_states
+
+            reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+                topk_idx, self.num_experts
+            )
+            num_total_tokens = reorder_topk_ids.numel()
+            gateup_input = torch.empty(
+                (int(num_total_tokens), hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            # PreReorder
+            deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+                hidden_states,
+                gateup_input,
+                self.src2dst,
+                topk_idx,
+                None,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+            return reorder_topk_ids, seg_indptr, gateup_input
+
+    def forward_normal(
+        self,
+        dispatch_output: DeepEPNormalOutput,
+    ):
+        hidden_states, topk_idx = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+        )
+        reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
+            hidden_states, topk_idx
+        )
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device

@@ -983,10 +1096,13 @@ class DeepEPMoE(EPMoE):

     def forward_aiter(
         self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states, topk_idx, topk_weights = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+        )
         if hidden_states.shape[0] == 0:
             return hidden_states
         # in original deepep, idx == -1 meaning invalid and will not be processed.
@@ -1014,11 +1130,11 @@ class DeepEPMoE(EPMoE):

     def forward_deepgemm_contiguous(
         self,
-        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
-        topk_idx,
-        topk_weights,
-        num_recv_tokens_per_expert: List[int],
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+            dispatch_output
+        )
         hidden_states_fp8, hidden_states_scale = hidden_states_fp8
         assert self.quant_method is not None
         assert self.activation == "silu"
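The rewritten methods lean on NamedTuple semantics: a dispatch output can be unpacked positionally like a plain tuple, as here and in forward_deepgemm_masked below, or read by field name, as in moe_impl. A minimal illustration with a stand-in class:

    from typing import List, NamedTuple

    class NormalOut(NamedTuple):  # stand-in for DeepEPNormalOutput
        hidden_states: object
        topk_idx: object
        topk_weights: object
        num_recv_tokens_per_expert: List[int]

    out = NormalOut("h", "idx", "w", [3, 1])
    hidden_states, topk_idx, topk_weights, n_recv = out  # positional unpacking
    assert out.topk_idx is topk_idx                      # named access, same object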
@@ -1138,10 +1254,9 @@ class DeepEPMoE(EPMoE):

     def forward_deepgemm_masked(
         self,
-        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
-        masked_m: torch.Tensor,
-        expected_m: int,
+        dispatch_output: DeepEPLLOutput,
     ):
+        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.activation == "silu"

@@ -1268,7 +1383,7 @@ class FlashInferEPMoE(EPMoE):
             topk_group=self.topk_group,
             intermediate_size=self.w2_weight.shape[2],
             local_expert_offset=self.start_expert_id,
-            local_num_experts=self.num_experts_per_partition,
+            local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.routed_scaling_factor,
             tile_tokens_dim=_get_tile_tokens_dim(
                 hidden_states.shape[0], self.top_k, self.num_experts
sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -1,7 +1,27 @@
+# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
+
+from __future__ import annotations
+
 import logging
 from dataclasses import dataclass
+from typing import (
+    TYPE_CHECKING,
+    List,
+    NamedTuple,
+    Optional,
+    Protocol,
+    Tuple,
+    Union,
+    runtime_checkable,
+)

 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    BaseDispatcher,
+    BaseDispatcherConfig,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
@@ -24,7 +44,6 @@ except ImportError:
     use_deepep = False

 from enum import Enum, IntEnum, auto
-from typing import Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -41,6 +60,37 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 logger = logging.getLogger(__name__)


+class DeepEPNormalOutput(NamedTuple):
+    """DeepEP normal dispatch output."""
+
+    hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    num_recv_tokens_per_expert: List[int]
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_normal
+
+
+class DeepEPLLOutput(NamedTuple):
+    """DeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
+assert isinstance(DeepEPNormalOutput, DispatchOutput)
+assert isinstance(DeepEPLLOutput, DispatchOutput)
+
+
 class DeepEPDispatchMode(IntEnum):
     NORMAL = auto()
     LOW_LATENCY = auto()
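The new base_dispatcher.py (+48 lines in the file list) does not appear in this diff; inferred from how its four imported names are used here, it plausibly resembles the sketch below (everything beyond those four names is a guess). The import-time asserts above work because isinstance against a runtime_checkable Protocol only checks attribute presence, and each NamedTuple class object carries format as a property descriptor, so the check passes on the class itself:

    # Hypothetical reconstruction of sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py;
    # only BaseDispatcher, BaseDispatcherConfig, DispatchOutput, and
    # DispatchOutputFormat are confirmed by this diff.
    from abc import ABC, abstractmethod
    from enum import Enum, auto
    from typing import Protocol, runtime_checkable


    class DispatchOutputFormat(Enum):
        standard = auto()
        deepep_normal = auto()
        deepep_ll = auto()

        def is_deepep_normal(self) -> bool:
            return self == DispatchOutputFormat.deepep_normal

        def is_deepep_ll(self) -> bool:
            return self == DispatchOutputFormat.deepep_ll


    @runtime_checkable
    class DispatchOutput(Protocol):
        """Anything exposing a .format attribute qualifies as a dispatch output."""

        format: DispatchOutputFormat


    class BaseDispatcherConfig(ABC):
        pass


    class BaseDispatcher(ABC):
        @abstractmethod
        def dispatch(self, *args, **kwargs) -> DispatchOutput: ...

        @abstractmethod
        def combine(self, *args, **kwargs): ...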
@@ -107,6 +157,20 @@ class DeepEPBuffer:
         else:
             raise NotImplementedError

+        total_num_sms = torch.cuda.get_device_properties(
+            device="cuda"
+        ).multi_processor_count
+        if (
+            (deepep_mode != DeepEPMode.low_latency)
+            and not global_server_args_dict["enable_two_batch_overlap"]
+            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
+        ):
+            logger.warning(
+                f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
+                f"This may result in highly suboptimal performance. "
+                f"Consider using --deepep-config to change the behavior."
+            )
+
         cls._buffer = Buffer(
             group,
             num_nvl_bytes,
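The threshold is half the device's SM count: on an H100 with 132 SMs, for example, the warning fires when fewer than 66 SMs are reserved for DeepEP communication while the mode is not pure low latency and two-batch overlap is disabled.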
@@ -139,7 +203,7 @@ class DeepEPBuffer:
         cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY


-class DeepEPConfig:
+class DeepEPConfig(BaseDispatcherConfig):
     _instance = None

     def __init__(self):
@@ -255,63 +319,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return hidden_states, topk_idx, topk_weights, previous_event

     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                num_recv_tokens_per_expert_list,
-                event,
-            ) = self._dispatch_core(
-                hidden_states, topk_idx, topk_weights, previous_event
-            )
-            event.current_stream_wait() if self.async_finish else ()
-            return (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                None,
-                num_recv_tokens_per_expert_list,
-                None,
-                None,
-                None,
-            )
-        else:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                num_recv_tokens_per_expert_list,
-                event,
-            ) = self._dispatch_core(
-                hidden_states, topk_idx, topk_weights, previous_event
-            )
-            event.current_stream_wait() if self.async_finish else ()
-            if hidden_states.shape[0] > 0:
-                reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
-                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
-                )
-            else:
-                reorder_topk_ids = torch.empty(
-                    (0,), device=hidden_states.device, dtype=torch.int64
-                )
-                seg_indptr = torch.zeros(
-                    (self.num_experts + 1,),
-                    device=hidden_states.device,
-                    dtype=torch.int64,
-                )
-
-            masked_m = expected_m = None
-            return (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                None,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            )
+        (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            num_recv_tokens_per_expert,
+            event,
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
+        event.current_stream_wait() if self.async_finish else ()
+        return DeepEPNormalOutput(
+            hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+        )

     def _dispatch_core(
         self,
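dispatch_b collapses to a single path because token permutation no longer happens inside the dispatcher: both the DeepGEMM and the Triton consumers now receive the same DeepEPNormalOutput, and the non-DeepGEMM path performs its permutation in DeepEPMoE._prepare_for_normal, added in layer.py above.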
@@ -343,7 +361,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             self.handle,
             event,
         ) = buffer.dispatch(
@@ -362,7 +380,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )

         get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             num_tokens_per_rank=num_tokens_per_rank,
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             num_tokens_per_expert=num_tokens_per_expert,
@@ -372,58 +390,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             event,
         )

-    def _deepep_permute(
-        self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        fp8_dtype: Optional[torch.dtype] = None,
-        use_fp8_w8a8: bool = False,
-        use_block_quant: bool = False,
-    ):
-        """
-        Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
-        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
-        """
-        if _use_aiter:
-            # skip permutation here as aiter fused_moe has fused inside
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-            )
-            return reorder_topk_ids, seg_indptr, hidden_states
-
-        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            topk_idx, self.num_experts
-        )
-        num_total_tokens = reorder_topk_ids.numel()
-        gateup_input = torch.empty(
-            (int(num_total_tokens), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                fp8_dtype
-                if (use_fp8_w8a8 and not use_block_quant)
-                else hidden_states.dtype
-            ),
-        )
-        # PreReorder
-        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            self.src2dst,
-            topk_idx,
-            None,
-            self.router_topk,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-        )
-        return reorder_topk_ids, seg_indptr, gateup_input
-
     def combine_a(
         self,
         hidden_states: torch.Tensor,
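_deepep_permute is not lost: its body, minus the fp8 dtype plumbing, reappears as DeepEPMoE._prepare_for_normal in layer.py above, moving the permutation from dispatch time to compute time.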
@@ -544,15 +514,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )

-        reorder_topk_ids = seg_indptr = None
-
-        return (
+        return DeepEPLLOutput(
             hidden_states,
             topk_idx,
             topk_weights,
-            reorder_topk_ids,
-            None,
-            seg_indptr,
             masked_m,
             expected_m,
         )
@@ -636,7 +601,7 @@ class _Stage(Enum):
     AFTER_COMBINE_A = auto()


-class DeepEPDispatcher:
+class DeepEPDispatcher(BaseDispatcher):
     def __init__(
         self,
         group: torch.distributed.ProcessGroup,
@@ -676,7 +641,7 @@ class DeepEPDispatcher:

         self._stage = _Stage.INITIAL

-    def dispatch(self, *args, **kwargs) -> Tuple:
+    def dispatch(self, *args, **kwargs) -> DispatchOutput:
         self.dispatch_a(*args, **kwargs)
         ret = self.dispatch_b()
         return ret
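The Tuple to DispatchOutput annotation change is the payoff of the whole refactor: the two-phase dispatch_a/dispatch_b machinery is untouched, but callers now receive a self-describing object (DeepEPNormalOutput or DeepEPLLOutput) and branch on its format tag, as DeepEPMoE.moe_impl does above, instead of positionally unpacking an 8-tuple padded with None fields.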