sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81):
  1. sglang/bench_serving.py +49 -7
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/model_config.py +1 -0
  4. sglang/srt/constrained/base_grammar_backend.py +5 -1
  5. sglang/srt/custom_op.py +5 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  7. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  8. sglang/srt/entrypoints/engine.py +0 -5
  9. sglang/srt/layers/attention/flashattention_backend.py +394 -76
  10. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  11. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  12. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  13. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  14. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  15. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  20. sglang/srt/layers/moe/topk.py +49 -3
  21. sglang/srt/layers/quantization/__init__.py +4 -1
  22. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  23. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  24. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  25. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  26. sglang/srt/layers/quantization/utils.py +1 -1
  27. sglang/srt/layers/rotary_embedding.py +0 -12
  28. sglang/srt/managers/cache_controller.py +34 -11
  29. sglang/srt/managers/mm_utils.py +202 -156
  30. sglang/srt/managers/multimodal_processor.py +0 -2
  31. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  32. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  33. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  34. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  35. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  36. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  37. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  38. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  40. sglang/srt/managers/schedule_batch.py +185 -128
  41. sglang/srt/managers/scheduler.py +4 -4
  42. sglang/srt/managers/tokenizer_manager.py +1 -1
  43. sglang/srt/managers/utils.py +1 -6
  44. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  45. sglang/srt/mem_cache/memory_pool.py +72 -6
  46. sglang/srt/mem_cache/paged_allocator.py +39 -0
  47. sglang/srt/metrics/collector.py +23 -53
  48. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  49. sglang/srt/model_executor/forward_batch_info.py +10 -10
  50. sglang/srt/model_executor/model_runner.py +59 -57
  51. sglang/srt/model_loader/loader.py +8 -0
  52. sglang/srt/models/clip.py +12 -7
  53. sglang/srt/models/deepseek_janus_pro.py +10 -15
  54. sglang/srt/models/deepseek_v2.py +212 -121
  55. sglang/srt/models/deepseek_vl2.py +105 -104
  56. sglang/srt/models/gemma3_mm.py +14 -80
  57. sglang/srt/models/llama.py +4 -1
  58. sglang/srt/models/llava.py +31 -19
  59. sglang/srt/models/llavavid.py +16 -7
  60. sglang/srt/models/minicpmo.py +63 -147
  61. sglang/srt/models/minicpmv.py +17 -27
  62. sglang/srt/models/mllama.py +29 -14
  63. sglang/srt/models/qwen2.py +9 -6
  64. sglang/srt/models/qwen2_5_vl.py +21 -31
  65. sglang/srt/models/qwen2_vl.py +20 -21
  66. sglang/srt/openai_api/adapter.py +18 -6
  67. sglang/srt/platforms/interface.py +371 -0
  68. sglang/srt/server_args.py +99 -14
  69. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  70. sglang/srt/speculative/eagle_utils.py +140 -28
  71. sglang/srt/speculative/eagle_worker.py +93 -24
  72. sglang/srt/utils.py +104 -51
  73. sglang/test/test_custom_ops.py +55 -0
  74. sglang/test/test_utils.py +13 -26
  75. sglang/utils.py +2 -2
  76. sglang/version.py +1 -1
  77. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
  78. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
  79. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  80. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  81. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,5 @@
1
+ from sglang.srt.utils import DeepEPMode
2
+
1
3
  try:
2
4
  from deep_ep import Buffer
3
5
 
@@ -21,7 +23,7 @@ _buffer_normal = None
21
23
  _buffer_low_latency = None
22
24
 
23
25
 
24
- def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
26
+ def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
25
27
  """
26
28
  Copy from DeepEP example usage in model inference prefilling.
27
29
  https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
@@ -51,7 +53,7 @@ def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
51
53
  return _buffer_normal
52
54
 
53
55
 
54
- def get_buffer_low_latency(
56
+ def _get_buffer_low_latency(
55
57
  group: dist.ProcessGroup,
56
58
  num_max_dispatch_tokens_per_rank: int,
57
59
  hidden: int,
@@ -76,151 +78,103 @@ def get_buffer_low_latency(
76
78
  assert num_experts % group.size() == 0
77
79
  _buffer_low_latency = Buffer(
78
80
  group,
79
- 0,
80
- num_rdma_bytes,
81
+ num_rdma_bytes=num_rdma_bytes,
81
82
  low_latency_mode=True,
82
83
  num_qps_per_rank=num_experts // group.size(),
83
84
  )
84
85
  return _buffer_low_latency
85
86
 
86
87
 
87
- class DeepEPDispatcher:
88
- """
89
- Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
90
- https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
91
- """
92
-
88
+ class _DeepEPDispatcherImplBase:
93
89
  def __init__(
94
90
  self,
95
91
  group: torch.distributed.ProcessGroup,
96
92
  router_topk: int,
97
- permute_fusion: bool = False,
98
- capacity_factor: float = None,
99
- num_experts: int = None,
100
- num_local_experts: int = None,
101
- hidden_size: int = None,
102
- params_dtype: torch.dtype = None,
103
- async_finish: bool = False,
93
+ permute_fusion: bool,
94
+ num_experts: int,
95
+ num_local_experts: int,
96
+ hidden_size: int,
97
+ params_dtype: torch.dtype,
104
98
  ):
99
+ if not use_deepep:
100
+ raise ImportError(
101
+ "DeepEP is not installed. Please install DeepEP package from "
102
+ "https://github.com/deepseek-ai/deepep."
103
+ )
104
+
105
105
  self.group = group
106
106
  self.router_topk = router_topk
107
- self.capacity_factor = capacity_factor
108
107
  self.permute_fusion = permute_fusion
109
108
  self.num_experts = num_experts
110
109
  self.num_local_experts = num_local_experts
111
110
  self.hidden_size = hidden_size
112
- self.recv_expert_count = None
113
111
  self.params_dtype = params_dtype
114
112
  self.params_bytes = 2
115
- # Metadata
116
- self.token_indices = None
117
- self.token_probs = None
118
- # Handle used for combine operation
113
+
119
114
  self.handle = None
120
- self.async_finish = async_finish
121
115
 
122
- # `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
123
- # https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
124
- self.num_max_dispatch_tokens_per_rank = 128
116
+ def dispatch_a(
117
+ self,
118
+ hidden_states: torch.Tensor,
119
+ topk_idx: torch.Tensor,
120
+ topk_weights: torch.Tensor,
121
+ num_experts: int,
122
+ num_max_dispatch_tokens_per_rank: int,
123
+ ):
124
+ raise NotImplementedError
125
125
 
126
- if not use_deepep:
127
- raise ImportError(
128
- "DeepEP is not installed. Please install DeepEP package from "
129
- "https://github.com/deepseek-ai/deepep."
130
- )
131
- self.buffer_normal = get_buffer_normal(
132
- self.group, self.hidden_size * self.params_bytes
133
- )
134
- self.buffer_low_latency = None
135
- # Todo: enable low latency dispatch
136
- """
137
- self.buffer_low_latency = get_buffer_low_latency(
138
- self.group,
139
- self.num_max_dispatch_tokens_per_rank,
140
- self.hidden_size * self.params_bytes,
141
- self.num_experts,
142
- )
143
- """
126
+ def dispatch_b(self, *args, **kwargs):
127
+ raise NotImplementedError
144
128
 
145
- def deepep_permute(
129
+ def combine_a(
146
130
  self,
147
- hidden_states,
148
- fp8_dtype=None,
149
- use_fp8_w8a8=False,
150
- use_block_quant=False,
131
+ hidden_states: torch.Tensor,
132
+ topk_idx: torch.Tensor,
133
+ topk_weights: torch.Tensor,
151
134
  ):
152
- reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
153
- self.topk_idx, self.num_experts
154
- )
155
- num_total_tokens = reorder_topk_ids.numel()
156
- gateup_input = torch.empty(
157
- (int(num_total_tokens), hidden_states.shape[1]),
158
- device=hidden_states.device,
159
- dtype=(
160
- fp8_dtype
161
- if (use_fp8_w8a8 and not use_block_quant)
162
- else hidden_states.dtype
163
- ),
164
- )
165
- # PreReorder
166
- deepep_permute_triton_kernel[(hidden_states.shape[0],)](
167
- hidden_states,
168
- gateup_input,
169
- src2dst,
170
- self.topk_idx,
171
- None,
172
- self.router_topk,
173
- hidden_states.shape[1],
174
- BLOCK_SIZE=512,
135
+ raise NotImplementedError
136
+
137
+ def combine_b(self, *args, **kwargs):
138
+ raise NotImplementedError
139
+
140
+
141
+ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
142
+ def __init__(self, async_finish: bool, **kwargs):
143
+ super().__init__(**kwargs)
144
+
145
+ self.buffer_normal = _get_buffer_normal(
146
+ self.group, self.hidden_size * self.params_bytes
175
147
  )
176
- self.src2dst = src2dst
177
- return reorder_topk_ids, seg_indptr, gateup_input
148
+ self.async_finish = async_finish
149
+ self.src2dst = None
178
150
 
179
- def dispatch(
151
+ def dispatch_a(
180
152
  self,
181
153
  hidden_states: torch.Tensor,
182
154
  topk_idx: torch.Tensor,
183
155
  topk_weights: torch.Tensor,
184
156
  num_experts: int,
185
- forward_mode: ForwardMode,
186
- num_max_dispatch_tokens_per_rank: int = 128,
187
- ) -> Tuple[torch.Tensor, torch.Tensor]:
157
+ num_max_dispatch_tokens_per_rank: int,
158
+ ):
188
159
  topk_idx = topk_idx.to(torch.int64)
189
- # Todo: enable low latency dispatch
190
- if True: # not forward_mode.is_decode():
191
- (
192
- hidden_states,
193
- topk_idx,
194
- topk_weights,
195
- num_recv_tokens_per_expert_list,
196
- handle,
197
- event,
198
- ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
199
- self.tokens_per_expert = torch.tensor(
200
- num_recv_tokens_per_expert_list,
201
- device=hidden_states.device,
202
- dtype=torch.int64,
203
- )
204
- else:
205
- hidden_states, recv_expert_count, handle, event, hook = (
206
- self.dispatch_low_latency(
207
- hidden_states,
208
- topk_idx,
209
- num_max_dispatch_tokens_per_rank,
210
- num_experts,
211
- )
212
- )
213
- self.recv_expert_count = recv_expert_count
214
-
215
- if self.async_finish:
216
- event.current_stream_wait()
160
+ previous_event = Buffer.capture() if self.async_finish else None
161
+ return hidden_states, topk_idx, topk_weights, num_experts, previous_event
217
162
 
218
- self.handle = handle
219
- self.topk_idx = topk_idx
220
- self.topk_weights = topk_weights
163
+ def dispatch_b(
164
+ self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
165
+ ):
166
+ (
167
+ hidden_states,
168
+ topk_idx,
169
+ topk_weights,
170
+ event,
171
+ ) = self._dispatch_core(
172
+ hidden_states, topk_idx, topk_weights, num_experts, previous_event
173
+ )
174
+ event.current_stream_wait() if self.async_finish else ()
221
175
  if hidden_states.shape[0] > 0:
222
- reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
223
- hidden_states, fp8_dtype=hidden_states.dtype
176
+ reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
177
+ hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
224
178
  )
225
179
  else:
226
180
  reorder_topk_ids = torch.empty(
@@ -229,17 +183,27 @@ class DeepEPDispatcher:
229
183
  seg_indptr = torch.zeros(
230
184
  (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
231
185
  )
232
- return hidden_states, reorder_topk_ids, seg_indptr
233
186
 
234
- def dispatch_normal(
187
+ masked_m = expected_m = None
188
+
189
+ return (
190
+ hidden_states,
191
+ topk_idx,
192
+ topk_weights,
193
+ reorder_topk_ids,
194
+ seg_indptr,
195
+ masked_m,
196
+ expected_m,
197
+ )
198
+
199
+ def _dispatch_core(
235
200
  self,
236
201
  x: torch.Tensor,
237
202
  topk_idx: torch.Tensor,
238
203
  topk_weights: torch.Tensor,
239
204
  num_experts: int,
205
+ previous_event,
240
206
  ):
241
- previous_event = Buffer.capture() if self.async_finish else None
242
-
243
207
  (
244
208
  num_tokens_per_rank,
245
209
  num_tokens_per_rdma_rank,
@@ -254,12 +218,15 @@ class DeepEPDispatcher:
254
218
  allocate_on_comm_stream=previous_event is not None,
255
219
  )
256
220
 
221
+ # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
222
+ # However, doing this would incur an unknown synchronization error, but keeping
223
+ # `handle` as a member variable works.
257
224
  (
258
225
  recv_x,
259
226
  recv_topk_idx,
260
227
  recv_topk_weights,
261
- num_recv_tokens_per_expert_list,
262
- handle,
228
+ _, # num_recv_tokens_per_expert_list
229
+ self.handle,
263
230
  event,
264
231
  ) = self.buffer_normal.dispatch(
265
232
  x,
@@ -278,29 +245,191 @@ class DeepEPDispatcher:
278
245
  recv_x,
279
246
  recv_topk_idx,
280
247
  recv_topk_weights,
281
- num_recv_tokens_per_expert_list,
282
- handle,
283
248
  event,
284
249
  )
285
250
 
286
- def dispatch_low_latency(
251
+ def _deepep_permute(
252
+ self,
253
+ hidden_states: torch.Tensor,
254
+ topk_idx: torch.Tensor,
255
+ fp8_dtype: Optional[torch.dtype] = None,
256
+ use_fp8_w8a8: bool = False,
257
+ use_block_quant: bool = False,
258
+ ):
259
+ """
260
+ Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
261
+ https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
262
+ """
263
+
264
+ reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
265
+ topk_idx, self.num_experts
266
+ )
267
+ num_total_tokens = reorder_topk_ids.numel()
268
+ gateup_input = torch.empty(
269
+ (int(num_total_tokens), hidden_states.shape[1]),
270
+ device=hidden_states.device,
271
+ dtype=(
272
+ fp8_dtype
273
+ if (use_fp8_w8a8 and not use_block_quant)
274
+ else hidden_states.dtype
275
+ ),
276
+ )
277
+ # PreReorder
278
+ deepep_permute_triton_kernel[(hidden_states.shape[0],)](
279
+ hidden_states,
280
+ gateup_input,
281
+ self.src2dst,
282
+ topk_idx,
283
+ None,
284
+ self.router_topk,
285
+ hidden_states.shape[1],
286
+ BLOCK_SIZE=512,
287
+ )
288
+ return reorder_topk_ids, seg_indptr, gateup_input
289
+
290
+ def combine_a(
291
+ self,
292
+ hidden_states: torch.Tensor,
293
+ topk_idx: torch.Tensor,
294
+ topk_weights: torch.Tensor,
295
+ ):
296
+ if hidden_states.shape[0] > 0:
297
+ num_tokens = self.src2dst.shape[0] // self.router_topk
298
+ output = torch.empty(
299
+ (num_tokens, hidden_states.shape[1]),
300
+ device=hidden_states.device,
301
+ dtype=hidden_states.dtype,
302
+ )
303
+ deepep_post_reorder_triton_kernel[(num_tokens,)](
304
+ hidden_states,
305
+ output,
306
+ self.src2dst,
307
+ topk_idx,
308
+ topk_weights,
309
+ self.router_topk,
310
+ hidden_states.shape[1],
311
+ BLOCK_SIZE=512,
312
+ )
313
+ else:
314
+ output = torch.zeros(
315
+ (0, hidden_states.shape[1]),
316
+ device=hidden_states.device,
317
+ dtype=hidden_states.dtype,
318
+ )
319
+ previous_event = Buffer.capture() if self.async_finish else None
320
+ return output, previous_event
321
+
322
+ def combine_b(self, output, previous_event):
323
+ hidden_states, event = self._combine_core(output, previous_event)
324
+ event.current_stream_wait() if self.async_finish else ()
325
+ self.handle = None
326
+ self.src2dst = None
327
+ return hidden_states
328
+
329
+ def _combine_core(self, x: torch.Tensor, previous_event):
330
+ combined_x, _, event = self.buffer_normal.combine(
331
+ x,
332
+ self.handle,
333
+ async_finish=self.async_finish,
334
+ previous_event=previous_event,
335
+ allocate_on_comm_stream=previous_event is not None,
336
+ )
337
+ return combined_x, event
338
+
339
+
340
+ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
341
+ def __init__(self, return_recv_hook: bool, **kwargs):
342
+ super().__init__(**kwargs)
343
+
344
+ """
345
+ num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
346
+ https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
347
+ """
348
+ # TODO(ch-wan): allow users to set this value
349
+ self.num_max_dispatch_tokens_per_rank = 128
350
+ self.buffer_low_latency = _get_buffer_low_latency(
351
+ self.group,
352
+ self.num_max_dispatch_tokens_per_rank,
353
+ self.hidden_size,
354
+ self.num_experts,
355
+ )
356
+ self.return_recv_hook = return_recv_hook
357
+
358
+ def dispatch_a(
359
+ self,
360
+ hidden_states: torch.Tensor,
361
+ topk_idx: torch.Tensor,
362
+ topk_weights: torch.Tensor,
363
+ num_experts: int,
364
+ num_max_dispatch_tokens_per_rank: int,
365
+ ):
366
+ topk_idx = topk_idx.to(torch.int64)
367
+ expected_m = (
368
+ hidden_states.shape[0]
369
+ * self.buffer_low_latency.group_size
370
+ * topk_idx.shape[1]
371
+ + num_experts
372
+ ) // num_experts
373
+ hidden_states, masked_m, event, hook = self._dispatch_core(
374
+ hidden_states,
375
+ topk_idx,
376
+ num_max_dispatch_tokens_per_rank,
377
+ num_experts,
378
+ use_fp8=True,
379
+ )
380
+ return (
381
+ hidden_states,
382
+ topk_idx,
383
+ topk_weights,
384
+ masked_m,
385
+ expected_m,
386
+ event,
387
+ hook,
388
+ )
389
+
390
+ def dispatch_b(
391
+ self,
392
+ hidden_states,
393
+ topk_idx,
394
+ topk_weights,
395
+ masked_m,
396
+ expected_m,
397
+ event,
398
+ hook,
399
+ ):
400
+ hook() if self.return_recv_hook else event.current_stream_wait()
401
+
402
+ reorder_topk_ids = seg_indptr = None
403
+
404
+ return (
405
+ hidden_states,
406
+ topk_idx,
407
+ topk_weights,
408
+ reorder_topk_ids,
409
+ seg_indptr,
410
+ masked_m,
411
+ expected_m,
412
+ )
413
+
414
+ def _dispatch_core(
287
415
  self,
288
416
  hidden_states: torch.Tensor,
289
417
  topk_idx: torch.Tensor,
290
418
  num_max_dispatch_tokens_per_rank: int,
291
419
  num_experts: int,
420
+ use_fp8: bool = False,
292
421
  ):
293
422
  """
294
- # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'
295
- # Please please make sure to change DeepEP code in internode_ll.cu dispatch / combine first and then reinstall!
423
+ # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'.
424
+ # Please make sure to change DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall.
296
425
  # More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
297
- +
426
+
298
427
  diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
299
- index f60e933..cddaabf 100644
428
+ index 76ae2e2..8ecd08f 100644
300
429
  --- a/csrc/kernels/internode_ll.cu
301
430
  +++ b/csrc/kernels/internode_ll.cu
302
- @@ -307,14 +307,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
303
- int num_topk, int num_experts, int rank, int num_ranks,
431
+ @@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
432
+ int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
304
433
  void* workspace, cudaStream_t stream, int phases) {
305
434
  constexpr int kNumMaxTopK = 9;
306
435
  - constexpr int kNumWarpsPerGroup = 10;
@@ -308,16 +437,9 @@ class DeepEPDispatcher:
308
437
  + constexpr int kNumWarpsPerGroup = 8;
309
438
  + constexpr int kNumWarpGroups = 4;
310
439
  EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
311
- +
440
+
312
441
  const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
313
- const auto num_sms = cell_div(num_experts, kNumWarpGroups);
314
- EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
315
- - EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
316
- + // EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
317
- +
318
- // Workspace checks
319
- auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
320
- @@ -505,8 +505,8 @@ void combine(void* combined_x,
442
+ @@ -501,8 +501,8 @@ void combine(void* combined_x,
321
443
  int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
322
444
  int num_topk, int num_experts, int rank, int num_ranks,
323
445
  void* workspace, cudaStream_t stream, int phases) {
@@ -326,91 +448,152 @@ class DeepEPDispatcher:
326
448
  + constexpr int kNumWarpsPerGroup = 8;
327
449
  + constexpr int kNumWarpGroups = 4;
328
450
  constexpr int kNumMaxTopk = 9;
329
- +
451
+
330
452
  const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
331
453
  """
332
454
 
333
- recv_hidden_states, recv_expert_count, handle, event, hook = (
455
+ packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
334
456
  self.buffer_low_latency.low_latency_dispatch(
335
457
  hidden_states,
336
458
  topk_idx,
337
459
  num_max_dispatch_tokens_per_rank,
338
460
  num_experts,
339
- async_finish=self.async_finish,
340
- return_recv_hook=False, # True for double-batch overlapping, need call hook()
461
+ use_fp8=use_fp8,
462
+ async_finish=not self.return_recv_hook,
463
+ return_recv_hook=self.return_recv_hook,
341
464
  )
342
465
  )
343
- # hook()
344
- return recv_hidden_states, recv_expert_count, handle, event, hook
345
-
346
- def combine(
347
- self, hidden_states: torch.Tensor, forward_mode: ForwardMode
348
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
349
- # Todo: enable low latency combine
350
- if True: # not forward_mode.is_decode():
351
- if hidden_states.shape[0] > 0:
352
- num_tokens = self.src2dst.shape[0] // self.router_topk
353
- output = torch.empty(
354
- (num_tokens, hidden_states.shape[1]),
355
- device=hidden_states.device,
356
- dtype=hidden_states.dtype,
357
- )
358
- deepep_post_reorder_triton_kernel[(num_tokens,)](
359
- hidden_states,
360
- output,
361
- self.src2dst,
362
- self.topk_idx,
363
- self.topk_weights,
364
- self.router_topk,
365
- hidden_states.shape[1],
366
- BLOCK_SIZE=512,
367
- )
368
- else:
369
- output = torch.zeros(
370
- (0, hidden_states.shape[1]),
371
- device=hidden_states.device,
372
- dtype=hidden_states.dtype,
373
- )
374
- hidden_states, event = self.combine_normal(output, self.handle)
375
- else:
376
- hidden_states, event, hook = self.combine_low_latency(
377
- hidden_states, self.topk_idx, self.topk_weights, self.handle
378
- )
466
+ return packed_recv_hidden, packed_recv_count, event, hook
379
467
 
380
- if self.async_finish:
381
- event.current_stream_wait()
468
+ def combine_a(
469
+ self,
470
+ hidden_states: torch.Tensor,
471
+ topk_idx: torch.Tensor,
472
+ topk_weights: torch.Tensor,
473
+ ):
474
+ hidden_states, event, hook = self._combine_core(
475
+ hidden_states,
476
+ topk_idx,
477
+ topk_weights,
478
+ )
479
+ return hidden_states, event, hook
382
480
 
383
- self.handle = None
481
+ def combine_b(self, hidden_states, event, hook):
482
+ hook() if self.return_recv_hook else event.current_stream_wait()
384
483
  return hidden_states
385
484
 
386
- def combine_normal(self, x: torch.Tensor, handle: Tuple):
387
- previous_event = Buffer.capture() if self.async_finish else None
388
-
389
- combined_x, _, event = self.buffer_normal.combine(
390
- x,
391
- handle,
392
- async_finish=self.async_finish,
393
- previous_event=previous_event,
394
- allocate_on_comm_stream=previous_event is not None,
395
- )
396
- return combined_x, event
397
-
398
- def combine_low_latency(
485
+ def _combine_core(
399
486
  self,
400
487
  hidden_states: torch.Tensor,
401
488
  topk_idx: torch.Tensor,
402
489
  topk_weights: torch.Tensor,
403
- handle: Tuple,
404
490
  ):
405
- combined_hidden_states, event_overlap, hook = (
491
+ combined_hidden_states, event, hook = (
406
492
  self.buffer_low_latency.low_latency_combine(
407
493
  hidden_states,
408
494
  topk_idx,
409
495
  topk_weights,
410
- handle,
411
- async_finish=self.async_finish,
412
- return_recv_hook=False, # True for double-batch overlapping, need call hook()
496
+ self.handle,
497
+ async_finish=not self.return_recv_hook,
498
+ return_recv_hook=self.return_recv_hook,
413
499
  )
414
500
  )
415
- # hook()
416
- return combined_hidden_states, event_overlap, hook
501
+ self.handle = None
502
+ return combined_hidden_states, event, hook
503
+
504
+
505
+ class DeepEPDispatcher:
506
+ def __init__(
507
+ self,
508
+ group: torch.distributed.ProcessGroup,
509
+ router_topk: int,
510
+ permute_fusion: bool = False,
511
+ num_experts: int = None,
512
+ num_local_experts: int = None,
513
+ hidden_size: int = None,
514
+ params_dtype: torch.dtype = None,
515
+ deepep_mode: DeepEPMode = DeepEPMode.auto,
516
+ async_finish: bool = False,
517
+ return_recv_hook: bool = False,
518
+ ):
519
+ self.deepep_mode = deepep_mode
520
+
521
+ common_kwargs = dict(
522
+ group=group,
523
+ router_topk=router_topk,
524
+ permute_fusion=permute_fusion,
525
+ num_experts=num_experts,
526
+ num_local_experts=num_local_experts,
527
+ hidden_size=hidden_size,
528
+ params_dtype=params_dtype,
529
+ )
530
+
531
+ if self.deepep_mode.enable_normal():
532
+ self._normal_dispatcher = _DeepEPDispatcherImplNormal(
533
+ async_finish=async_finish,
534
+ **common_kwargs,
535
+ )
536
+ if self.deepep_mode.enable_low_latency():
537
+ self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
538
+ return_recv_hook=return_recv_hook,
539
+ **common_kwargs,
540
+ )
541
+
542
+ def dispatch(self, *args, **kwargs) -> Tuple:
543
+ self.dispatch_a(*args, **kwargs)
544
+ return self.dispatch_b()
545
+
546
+ def dispatch_a(
547
+ self,
548
+ hidden_states: torch.Tensor,
549
+ topk_idx: torch.Tensor,
550
+ topk_weights: torch.Tensor,
551
+ num_experts: int,
552
+ num_max_dispatch_tokens_per_rank: int = 128,
553
+ forward_mode: ForwardMode = None,
554
+ ):
555
+ inner_state = self._get_impl(forward_mode).dispatch_a(
556
+ hidden_states=hidden_states,
557
+ topk_idx=topk_idx,
558
+ topk_weights=topk_weights,
559
+ num_experts=num_experts,
560
+ num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
561
+ )
562
+ self._dispatch_intermediate_state = forward_mode, inner_state
563
+
564
+ def dispatch_b(self):
565
+ forward_mode, inner_state = self._dispatch_intermediate_state
566
+ del self._dispatch_intermediate_state
567
+ return self._get_impl(forward_mode).dispatch_b(*inner_state)
568
+
569
+ def combine(self, *args, **kwargs) -> Tuple:
570
+ self.combine_a(*args, **kwargs)
571
+ return self.combine_b()
572
+
573
+ def combine_a(
574
+ self,
575
+ hidden_states: torch.Tensor,
576
+ topk_idx: torch.Tensor,
577
+ topk_weights: torch.Tensor,
578
+ forward_mode: ForwardMode,
579
+ ):
580
+ inner_state = self._get_impl(forward_mode).combine_a(
581
+ hidden_states=hidden_states,
582
+ topk_idx=topk_idx,
583
+ topk_weights=topk_weights,
584
+ )
585
+ self._combine_intermediate_state = forward_mode, inner_state
586
+
587
+ def combine_b(self):
588
+ forward_mode, inner_state = self._combine_intermediate_state
589
+ del self._combine_intermediate_state
590
+ return self._get_impl(forward_mode).combine_b(*inner_state)
591
+
592
+ def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
593
+ resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
594
+ if resolved_deepep_mode == DeepEPMode.normal:
595
+ return self._normal_dispatcher
596
+ elif resolved_deepep_mode == DeepEPMode.low_latency:
597
+ return self._low_latency_dispatcher
598
+ else:
599
+ raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")