sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (134)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/lang/chat_template.py +24 -0
  4. sglang/srt/configs/model_config.py +40 -4
  5. sglang/srt/constrained/base_grammar_backend.py +26 -5
  6. sglang/srt/constrained/llguidance_backend.py +1 -0
  7. sglang/srt/constrained/outlines_backend.py +1 -0
  8. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  9. sglang/srt/constrained/xgrammar_backend.py +1 -0
  10. sglang/srt/conversation.py +29 -4
  11. sglang/srt/disaggregation/base/__init__.py +8 -0
  12. sglang/srt/disaggregation/base/conn.py +113 -0
  13. sglang/srt/disaggregation/decode.py +18 -5
  14. sglang/srt/disaggregation/mini_lb.py +53 -122
  15. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  16. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  17. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  18. sglang/srt/disaggregation/prefill.py +43 -19
  19. sglang/srt/disaggregation/utils.py +31 -0
  20. sglang/srt/entrypoints/EngineBase.py +53 -0
  21. sglang/srt/entrypoints/engine.py +36 -8
  22. sglang/srt/entrypoints/http_server.py +37 -8
  23. sglang/srt/entrypoints/http_server_engine.py +142 -0
  24. sglang/srt/entrypoints/verl_engine.py +37 -10
  25. sglang/srt/hf_transformers_utils.py +4 -0
  26. sglang/srt/layers/attention/flashattention_backend.py +609 -202
  27. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  28. sglang/srt/layers/attention/vision.py +1 -1
  29. sglang/srt/layers/dp_attention.py +2 -4
  30. sglang/srt/layers/elementwise.py +15 -2
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  33. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  49. sglang/srt/layers/moe/router.py +7 -1
  50. sglang/srt/layers/moe/topk.py +37 -16
  51. sglang/srt/layers/quantization/__init__.py +13 -5
  52. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  53. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  54. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  55. sglang/srt/layers/quantization/fp8.py +28 -14
  56. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  57. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  58. sglang/srt/layers/quantization/kv_cache.py +43 -52
  59. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  62. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  63. sglang/srt/layers/radix_attention.py +14 -0
  64. sglang/srt/layers/rotary_embedding.py +75 -1
  65. sglang/srt/managers/io_struct.py +254 -97
  66. sglang/srt/managers/mm_utils.py +3 -2
  67. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  68. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  69. sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
  70. sglang/srt/managers/schedule_batch.py +62 -21
  71. sglang/srt/managers/scheduler.py +71 -14
  72. sglang/srt/managers/tokenizer_manager.py +17 -3
  73. sglang/srt/managers/tp_worker.py +1 -0
  74. sglang/srt/mem_cache/memory_pool.py +14 -1
  75. sglang/srt/metrics/collector.py +9 -0
  76. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  77. sglang/srt/model_executor/forward_batch_info.py +234 -15
  78. sglang/srt/model_executor/model_runner.py +49 -9
  79. sglang/srt/model_loader/loader.py +31 -4
  80. sglang/srt/model_loader/weight_utils.py +4 -2
  81. sglang/srt/models/baichuan.py +2 -0
  82. sglang/srt/models/chatglm.py +1 -0
  83. sglang/srt/models/commandr.py +1 -0
  84. sglang/srt/models/dbrx.py +1 -0
  85. sglang/srt/models/deepseek.py +1 -0
  86. sglang/srt/models/deepseek_v2.py +248 -61
  87. sglang/srt/models/exaone.py +1 -0
  88. sglang/srt/models/gemma.py +1 -0
  89. sglang/srt/models/gemma2.py +1 -0
  90. sglang/srt/models/gemma3_causal.py +1 -0
  91. sglang/srt/models/gpt2.py +1 -0
  92. sglang/srt/models/gpt_bigcode.py +1 -0
  93. sglang/srt/models/granite.py +1 -0
  94. sglang/srt/models/grok.py +1 -0
  95. sglang/srt/models/internlm2.py +1 -0
  96. sglang/srt/models/llama.py +13 -4
  97. sglang/srt/models/llama4.py +487 -0
  98. sglang/srt/models/minicpm.py +1 -0
  99. sglang/srt/models/minicpm3.py +2 -0
  100. sglang/srt/models/mixtral.py +1 -0
  101. sglang/srt/models/mixtral_quant.py +1 -0
  102. sglang/srt/models/mllama.py +51 -8
  103. sglang/srt/models/mllama4.py +227 -0
  104. sglang/srt/models/olmo.py +1 -0
  105. sglang/srt/models/olmo2.py +1 -0
  106. sglang/srt/models/olmoe.py +1 -0
  107. sglang/srt/models/phi3_small.py +1 -0
  108. sglang/srt/models/qwen.py +1 -0
  109. sglang/srt/models/qwen2.py +1 -0
  110. sglang/srt/models/qwen2_5_vl.py +35 -70
  111. sglang/srt/models/qwen2_moe.py +1 -0
  112. sglang/srt/models/qwen2_vl.py +27 -25
  113. sglang/srt/models/stablelm.py +1 -0
  114. sglang/srt/models/xverse.py +1 -0
  115. sglang/srt/models/xverse_moe.py +1 -0
  116. sglang/srt/openai_api/adapter.py +4 -1
  117. sglang/srt/patch_torch.py +11 -0
  118. sglang/srt/server_args.py +34 -0
  119. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  120. sglang/srt/speculative/eagle_utils.py +1 -11
  121. sglang/srt/speculative/eagle_worker.py +6 -2
  122. sglang/srt/utils.py +120 -9
  123. sglang/test/attention/test_flashattn_backend.py +259 -221
  124. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  125. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  126. sglang/test/test_block_fp8.py +57 -0
  127. sglang/test/test_utils.py +19 -8
  128. sglang/version.py +1 -1
  129. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  130. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
  131. sglang/srt/disaggregation/conn.py +0 -81
  132. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  133. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  134. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
@@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend):
  self.max_context_len = model_runner.model_config.context_len
  self.skip_prefill = skip_prefill
  self.is_multimodal = model_runner.model_config.is_multimodal
+ self.kv_cache_dtype = model_runner.kv_cache_dtype
+ self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype

  assert not (
  model_runner.sliding_window_size is not None
@@ -391,6 +393,8 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch: ForwardBatch,
  save_kv_cache=True,
  ):
+ k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+ v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
  prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
  self._get_wrapper_idx(layer)
  ]
@@ -407,7 +411,7 @@ class FlashInferAttnBackend(AttentionBackend):
  assert v is not None
  if save_kv_cache:
  forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+ layer, cache_loc, k, v, k_scale, v_scale
  )

  o = prefill_wrapper_paged.forward(
@@ -417,8 +421,8 @@ class FlashInferAttnBackend(AttentionBackend):
  sm_scale=layer.scaling,
  window_left=layer.sliding_window_size,
  logits_soft_cap=logits_soft_cap,
- k_scale=layer.k_scale,
- v_scale=layer.v_scale,
+ k_scale=k_scale,
+ v_scale=v_scale,
  )
  else:
  o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -445,7 +449,7 @@ class FlashInferAttnBackend(AttentionBackend):

  if save_kv_cache:
  forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+ layer, cache_loc, k, v, k_scale, v_scale
  )

  return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -459,6 +463,8 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch: ForwardBatch,
  save_kv_cache=True,
  ):
+ k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+ v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
  decode_wrapper = self.forward_metadata.decode_wrappers[
  self._get_wrapper_idx(layer)
  ]
@@ -472,7 +478,7 @@ class FlashInferAttnBackend(AttentionBackend):
  assert v is not None
  if save_kv_cache:
  forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+ layer, cache_loc, k, v, k_scale, v_scale
  )

  o = decode_wrapper.forward(
@@ -480,8 +486,8 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
  sm_scale=layer.scaling,
  logits_soft_cap=layer.logit_cap,
- k_scale=layer.k_scale,
- v_scale=layer.v_scale,
+ k_scale=k_scale,
+ v_scale=v_scale,
  )

  return o.view(-1, layer.tp_q_head_num * layer.head_dim)
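
The hunks above thread the configured kv-cache dtype through the FlashInfer backend: per-layer K/V scales are now handed to set_kv_buffer and the attention kernels only when a non-"auto" --kv-cache-dtype is set. A minimal, hedged sketch of that selection rule (the helper below is illustrative, not sglang's API; the float scale fields mirror the diff):

# Hedged sketch of the scale-selection rule introduced above; `select_kv_scales`
# is hypothetical and only mirrors the condition used in the diff.
from typing import Optional, Tuple

def select_kv_scales(
    kv_cache_dtype_str: str,
    k_scale_float: Optional[float],
    v_scale_float: Optional[float],
) -> Tuple[Optional[float], Optional[float]]:
    # With --kv-cache-dtype auto the cache keeps the model dtype, so no
    # dequantization scales should be passed to the attention kernel.
    if kv_cache_dtype_str != "auto":
        return k_scale_float, v_scale_float
    return None, None

assert select_kv_scales("fp8_e4m3", 0.5, 0.25) == (0.5, 0.25)
assert select_kv_scales("auto", 0.5, 0.25) == (None, None)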
@@ -94,7 +94,7 @@ class VisionAttention(nn.Module):
  input_size=embed_dim,
  output_size=embed_dim,
  quant_config=quant_config,
- prefix=add_prefix("out_proj", prefix),
+ prefix=add_prefix("proj", prefix),
  )

  def forward(
@@ -192,8 +192,7 @@ def _dp_gather(

  if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
  assert (
- global_tokens.untyped_storage().data_ptr()
- != local_tokens.untyped_storage().data_ptr()
+ local_tokens.untyped_storage() is not global_tokens.untyped_storage()
  ), "aliasing between global_tokens and local_tokens not allowed"
  memcpy_triton(
  global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -243,8 +242,7 @@ def dp_scatter(
  assert global_tokens.is_contiguous()
  if local_tokens.shape[0] > 0:
  assert (
- local_tokens.untyped_storage().data_ptr()
- != global_tokens.untyped_storage().data_ptr()
+ local_tokens.untyped_storage() is not global_tokens.untyped_storage()
  ), "aliasing between local_tokens and global_tokens not allowed"
  memcpy_triton(
  local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
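
Both asserts above now compare the tensors' untyped storages directly instead of their storage data pointers. A small, self-contained illustration of the aliasing condition being guarded (the address-based helper below is for demonstration only, not the check used in the diff):

# Hedged illustration of storage aliasing: a slice shares its backing allocation
# with its base tensor, so an in-place copy between the two could overlap.
import torch

def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Demonstration-only check by backing-allocation address.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

global_tokens = torch.zeros(8, 4)
local_view = global_tokens[2:5]      # aliases global_tokens' storage
fresh_local = torch.zeros(3, 4)      # independent allocation

print(shares_storage(local_view, global_tokens))   # True  -> would trip the assert
print(shares_storage(fresh_local, global_tokens))  # False -> safe to memcpy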
@@ -4,6 +4,10 @@ import torch
  import triton
  import triton.language as tl

+ from sglang.srt.utils import is_hip
+
+ _is_hip = is_hip()
+
  fused_softcap_autotune = triton.autotune(
  configs=[
  triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -185,6 +189,9 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
  assert x.shape == residual.shape and x.dtype == residual.dtype
  output, mid = torch.empty_like(x), torch.empty_like(x)
  bs, hidden_dim = x.shape
+
+ min_num_warps = 16 if _is_hip else 32
+
  if autotune:
  fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
  output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
@@ -193,7 +200,10 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
  config = {
  "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
  "num_warps": max(
- min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+ min(
+ triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
+ ),
+ 4,
  ),
  }

@@ -250,10 +260,13 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
  else:
  output = torch.empty_like(x)
  bs, hidden_dim = x.shape
+
+ min_num_warps = 16 if _is_hip else 32
+
  config = {
  "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
  "num_warps": max(
- min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+ min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
  ),
  }

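The elementwise kernels above now cap the Triton launch at 16 warps on ROCm and 32 on CUDA, likely because a HIP wavefront is 64 lanes wide, so 16 wavefronts already fill a 1024-thread block. A hedged sketch of the resulting heuristic, using the same triton helpers as the diff (the standalone function is illustrative, not sglang code):

# Hedged sketch of the warp-count heuristic after this change; `pick_num_warps`
# is hypothetical.
import triton

def pick_num_warps(hidden_dim: int, is_hip: bool) -> int:
    # Roughly one warp per 256 columns, rounded to a power of two,
    # floored at 4 and capped at 16 on HIP vs. 32 on CUDA.
    cap = 16 if is_hip else 32
    return max(min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), cap), 4)

# e.g. hidden_dim=4096 -> 16 warps on both backends; 16384 -> 32 on CUDA but 16 on HIP.
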
@@ -47,6 +47,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
  "GPTQLinearMethod",
  "FBGEMMFp8LinearMethod",
  "ModelOptFp8LinearMethod",
+ "ModelOptFp4LinearMethod",
  "IPEXAWQLinearMethod",
  ]

@@ -7,6 +7,7 @@ try:
  except ImportError:
  use_deepep = False

+ from enum import IntEnum, auto
  from typing import Optional, Tuple

  import torch
@@ -19,70 +20,95 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardMode

- _buffer_normal = None
- _buffer_low_latency = None

+ class DeepEPDispatchMode(IntEnum):
+ NORMAL = auto()
+ LOW_LATENCY = auto()

- def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
- """
- Copy from DeepEP example usage in model inference prefilling.
- https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
- """

- global _buffer_normal
+ class DeepEPBuffer:

- num_nvl_bytes, num_rdma_bytes = 0, 0
- for config in (
- Buffer.get_dispatch_config(group.size()),
- Buffer.get_combine_config(group.size()),
- ):
- num_nvl_bytes = max(
- config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
- )
- num_rdma_bytes = max(
- config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
- )
+ _buffer = None
+ _dispatch_mode: Optional[DeepEPDispatchMode] = None
+ _hidden_size: Optional[int] = None
+ _num_max_dispatch_tokens_per_rank: Optional[int] = None
+ _num_experts: Optional[int] = None

- if (
- _buffer_normal is None
- or _buffer_normal.group != group
- or _buffer_normal.num_nvl_bytes < num_nvl_bytes
- or _buffer_normal.num_rdma_bytes < num_rdma_bytes
- ):
- _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
- return _buffer_normal
-
-
- def _get_buffer_low_latency(
- group: dist.ProcessGroup,
- num_max_dispatch_tokens_per_rank: int,
- hidden: int,
- num_experts: int,
- ):
- """
- Copy from DeepEP example usage in model inference decoding.
- https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
- """
-
- global _buffer_low_latency
- num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
- num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
- )
-
- if (
- _buffer_low_latency is None
- or _buffer_low_latency.group != group
- or not _buffer_low_latency.low_latency_mode
- or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
+ @classmethod
+ def get_deepep_buffer(
+ cls,
+ group: dist.ProcessGroup,
+ hidden_size: int,
+ param_bytes: int,
+ deepep_mode: DeepEPMode,
+ num_max_dispatch_tokens_per_rank: int = None,
+ num_experts: int = None,
  ):
- assert num_experts % group.size() == 0
- _buffer_low_latency = Buffer(
+ if cls._buffer is not None:
+ return cls._buffer
+
+ cls._hidden_size = hidden_size
+ cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+ cls._num_experts = num_experts
+
+ num_nvl_bytes, num_rdma_bytes = 0, 0
+ if deepep_mode.enable_normal():
+ hidden_bytes = hidden_size * param_bytes
+ for config in (
+ Buffer.get_dispatch_config(group.size()),
+ Buffer.get_combine_config(group.size()),
+ ):
+ num_nvl_bytes = max(
+ config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
+ num_nvl_bytes,
+ )
+ num_rdma_bytes = max(
+ config.get_rdma_buffer_size_hint(hidden_bytes, group.size()),
+ num_rdma_bytes,
+ )
+ if deepep_mode.enable_low_latency():
+ assert num_max_dispatch_tokens_per_rank is not None
+ assert num_experts is not None and num_experts % group.size() == 0
+ num_rdma_bytes = max(
+ Buffer.get_low_latency_rdma_size_hint(
+ num_max_dispatch_tokens_per_rank,
+ hidden_size,
+ group.size(),
+ num_experts,
+ ),
+ num_rdma_bytes,
+ )
+
+ cls._buffer = Buffer(
  group,
- num_rdma_bytes=num_rdma_bytes,
- low_latency_mode=True,
- num_qps_per_rank=num_experts // group.size(),
+ num_nvl_bytes,
+ num_rdma_bytes,
+ low_latency_mode=deepep_mode.enable_low_latency(),
+ num_qps_per_rank=(
+ num_experts // group.size() if deepep_mode.enable_low_latency() else 1
+ ),
  )
- return _buffer_low_latency
+ return cls._buffer
+
+ @classmethod
+ def clean_buffer(cls):
+ if not cls._buffer.low_latency_mode:
+ return
+ cls._buffer.clean_low_latency_buffer(
+ cls._num_max_dispatch_tokens_per_rank,
+ cls._hidden_size,
+ cls._num_experts,
+ )
+
+ @classmethod
+ def set_dispatch_mode_as_normal(cls):
+ cls._dispatch_mode = DeepEPDispatchMode.NORMAL
+
+ @classmethod
+ def set_dispatch_mode_as_low_latency(cls):
+ if cls._dispatch_mode == DeepEPDispatchMode.NORMAL:
+ cls.clean_buffer()
+ cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY


  class _DeepEPDispatcherImplBase:
@@ -95,6 +121,7 @@ class _DeepEPDispatcherImplBase:
  num_local_experts: int,
  hidden_size: int,
  params_dtype: torch.dtype,
+ deepep_mode: DeepEPMode,
  ):
  if not use_deepep:
  raise ImportError(
@@ -109,7 +136,10 @@ class _DeepEPDispatcherImplBase:
  self.num_local_experts = num_local_experts
  self.hidden_size = hidden_size
  self.params_dtype = params_dtype
+ self.deepep_mode = deepep_mode
+
  self.params_bytes = 2
+ self.num_max_dispatch_tokens_per_rank = 128

  self.handle = None

@@ -118,8 +148,6 @@ class _DeepEPDispatcherImplBase:
  hidden_states: torch.Tensor,
  topk_idx: torch.Tensor,
  topk_weights: torch.Tensor,
- num_experts: int,
- num_max_dispatch_tokens_per_rank: int,
  ):
  raise NotImplementedError

@@ -137,14 +165,14 @@ class _DeepEPDispatcherImplBase:
  def combine_b(self, *args, **kwargs):
  raise NotImplementedError

+ def _get_buffer(self):
+ raise NotImplementedError
+

  class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  def __init__(self, async_finish: bool, **kwargs):
  super().__init__(**kwargs)

- self.buffer_normal = _get_buffer_normal(
- self.group, self.hidden_size * self.params_bytes
- )
  self.async_finish = async_finish
  self.src2dst = None

@@ -153,24 +181,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  hidden_states: torch.Tensor,
  topk_idx: torch.Tensor,
  topk_weights: torch.Tensor,
- num_experts: int,
- num_max_dispatch_tokens_per_rank: int,
  ):
  topk_idx = topk_idx.to(torch.int64)
  previous_event = Buffer.capture() if self.async_finish else None
- return hidden_states, topk_idx, topk_weights, num_experts, previous_event
+ return hidden_states, topk_idx, topk_weights, previous_event

- def dispatch_b(
- self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
- ):
+ def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
  (
  hidden_states,
  topk_idx,
  topk_weights,
  event,
- ) = self._dispatch_core(
- hidden_states, topk_idx, topk_weights, num_experts, previous_event
- )
+ ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
  event.current_stream_wait() if self.async_finish else ()
  if hidden_states.shape[0] > 0:
  reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
@@ -181,7 +203,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  (0,), device=hidden_states.device, dtype=torch.int64
  )
  seg_indptr = torch.zeros(
- (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+ (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
  )

  masked_m = expected_m = None
@@ -201,18 +223,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  x: torch.Tensor,
  topk_idx: torch.Tensor,
  topk_weights: torch.Tensor,
- num_experts: int,
  previous_event,
  ):
+ buffer = self._get_buffer()
  (
  num_tokens_per_rank,
  num_tokens_per_rdma_rank,
  num_tokens_per_expert,
  is_token_in_rank,
  previous_event,
- ) = self.buffer_normal.get_dispatch_layout(
+ ) = buffer.get_dispatch_layout(
  topk_idx,
- num_experts,
+ self.num_experts,
  previous_event=previous_event,
  async_finish=self.async_finish,
  allocate_on_comm_stream=previous_event is not None,
@@ -221,6 +243,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
  # However, doing this would incur an unknown synchronization error, but keeping
  # `handle` as a member variable works.
+
  (
  recv_x,
  recv_topk_idx,
@@ -228,7 +251,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  recv_topk_weights,
  _, # num_recv_tokens_per_expert_list
  self.handle,
  event,
- ) = self.buffer_normal.dispatch(
+ ) = buffer.dispatch(
  x,
  topk_idx=topk_idx,
@@ -327,7 +350,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  return hidden_states

  def _combine_core(self, x: torch.Tensor, previous_event):
- combined_x, _, event = self.buffer_normal.combine(
+ buffer = self._get_buffer()
+ combined_x, _, event = buffer.combine(
  x,
  self.handle,
  async_finish=self.async_finish,
@@ -336,6 +360,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  )
  return combined_x, event

+ def _get_buffer(self):
+ DeepEPBuffer.set_dispatch_mode_as_normal()
+ return DeepEPBuffer.get_deepep_buffer(
+ self.group,
+ self.hidden_size,
+ self.params_bytes,
+ self.deepep_mode,
+ self.num_max_dispatch_tokens_per_rank,
+ self.num_experts,
+ )
+

  class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  def __init__(self, return_recv_hook: bool, **kwargs):
@@ -345,14 +380,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
  https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
  """
- # TODO(ch-wan): allow users to set this value
- self.num_max_dispatch_tokens_per_rank = 128
- self.buffer_low_latency = _get_buffer_low_latency(
- self.group,
- self.num_max_dispatch_tokens_per_rank,
- self.hidden_size,
- self.num_experts,
- )
  self.return_recv_hook = return_recv_hook

  def dispatch_a(
@@ -360,21 +387,16 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  hidden_states: torch.Tensor,
  topk_idx: torch.Tensor,
  topk_weights: torch.Tensor,
- num_experts: int,
- num_max_dispatch_tokens_per_rank: int,
  ):
+ buffer = self._get_buffer()
  topk_idx = topk_idx.to(torch.int64)
  expected_m = (
- hidden_states.shape[0]
- * self.buffer_low_latency.group_size
- * topk_idx.shape[1]
- + num_experts
- ) // num_experts
+ hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+ + self.num_experts
+ ) // self.num_experts
  hidden_states, masked_m, event, hook = self._dispatch_core(
  hidden_states,
  topk_idx,
- num_max_dispatch_tokens_per_rank,
- num_experts,
  use_fp8=True,
  )
  return (
@@ -415,8 +437,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  self,
  hidden_states: torch.Tensor,
  topk_idx: torch.Tensor,
- num_max_dispatch_tokens_per_rank: int,
- num_experts: int,
  use_fp8: bool = False,
  ):
  """
@@ -451,13 +471,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):

  const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
  """
-
+ buffer = self._get_buffer()
  packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
- self.buffer_low_latency.low_latency_dispatch(
+ buffer.low_latency_dispatch(
  hidden_states,
  topk_idx,
- num_max_dispatch_tokens_per_rank,
- num_experts,
+ self.num_max_dispatch_tokens_per_rank,
+ self.num_experts,
  use_fp8=use_fp8,
  async_finish=not self.return_recv_hook,
  return_recv_hook=self.return_recv_hook,
@@ -488,19 +508,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  topk_idx: torch.Tensor,
  topk_weights: torch.Tensor,
  ):
- combined_hidden_states, event, hook = (
- self.buffer_low_latency.low_latency_combine(
- hidden_states,
- topk_idx,
- topk_weights,
- self.handle,
- async_finish=not self.return_recv_hook,
- return_recv_hook=self.return_recv_hook,
- )
+ buffer = self._get_buffer()
+ combined_hidden_states, event, hook = buffer.low_latency_combine(
+ hidden_states,
+ topk_idx,
+ topk_weights,
+ self.handle,
+ async_finish=not self.return_recv_hook,
+ return_recv_hook=self.return_recv_hook,
  )
  self.handle = None
  return combined_hidden_states, event, hook

+ def _get_buffer(self):
+ DeepEPBuffer.set_dispatch_mode_as_low_latency()
+ return DeepEPBuffer.get_deepep_buffer(
+ self.group,
+ self.hidden_size,
+ self.params_bytes,
+ self.deepep_mode,
+ self.num_max_dispatch_tokens_per_rank,
+ self.num_experts,
+ )
+

  class DeepEPDispatcher:
  def __init__(
@@ -526,18 +556,19 @@ class DeepEPDispatcher:
  num_local_experts=num_local_experts,
  hidden_size=hidden_size,
  params_dtype=params_dtype,
+ deepep_mode=deepep_mode,
  )

- if self.deepep_mode.enable_normal():
- self._normal_dispatcher = _DeepEPDispatcherImplNormal(
- async_finish=async_finish,
- **common_kwargs,
- )
  if self.deepep_mode.enable_low_latency():
  self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
  return_recv_hook=return_recv_hook,
  **common_kwargs,
  )
+ if self.deepep_mode.enable_normal():
+ self._normal_dispatcher = _DeepEPDispatcherImplNormal(
+ async_finish=async_finish,
+ **common_kwargs,
+ )

  def dispatch(self, *args, **kwargs) -> Tuple:
  self.dispatch_a(*args, **kwargs)
@@ -548,16 +579,12 @@ class DeepEPDispatcher:
  hidden_states: torch.Tensor,
  topk_idx: torch.Tensor,
  topk_weights: torch.Tensor,
- num_experts: int,
- num_max_dispatch_tokens_per_rank: int = 128,
  forward_mode: ForwardMode = None,
  ):
  inner_state = self._get_impl(forward_mode).dispatch_a(
  hidden_states=hidden_states,
  topk_idx=topk_idx,
  topk_weights=topk_weights,
- num_experts=num_experts,
- num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
  )
  self._dispatch_intermediate_state = forward_mode, inner_state

@@ -589,7 +616,7 @@ class DeepEPDispatcher:
  del self._combine_intermediate_state
  return self._get_impl(forward_mode).combine_b(*inner_state)

- def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
+ def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
  resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
  if resolved_deepep_mode == DeepEPMode.normal:
  return self._normal_dispatcher
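
The token_dispatcher refactor above replaces the two module-level buffers with a single shared DeepEPBuffer plus a dispatch-mode flag; switching from normal to low-latency dispatch cleans the buffer first. A self-contained, hedged sketch of that pattern with a mock in place of deep_ep.Buffer (names loosely mirror the diff, but the class below is illustrative only):

# Hedged sketch of the shared-buffer-plus-mode pattern; SharedBuffer is a mock.
from enum import IntEnum, auto
from typing import Optional


class DispatchMode(IntEnum):
    NORMAL = auto()
    LOW_LATENCY = auto()


class SharedBuffer:
    _buffer: Optional[str] = None
    _dispatch_mode: Optional[DispatchMode] = None

    @classmethod
    def get(cls) -> str:
        # Allocated once and reused by every dispatcher implementation.
        if cls._buffer is None:
            cls._buffer = "allocated-once"  # real code sizes NVL/RDMA bytes here
        return cls._buffer

    @classmethod
    def set_dispatch_mode_as_normal(cls):
        cls._dispatch_mode = DispatchMode.NORMAL

    @classmethod
    def set_dispatch_mode_as_low_latency(cls):
        if cls._dispatch_mode == DispatchMode.NORMAL:
            print("cleaning low-latency buffer before mode switch")
        cls._dispatch_mode = DispatchMode.LOW_LATENCY


# Both dispatcher impls now fetch the same buffer instead of owning their own.
SharedBuffer.set_dispatch_mode_as_normal()
assert SharedBuffer.get() is SharedBuffer.get()
SharedBuffer.set_dispatch_mode_as_low_latency()  # triggers the clean step
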
@@ -23,9 +23,14 @@ def fused_moe_forward_native(
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
+
+ if apply_router_weight_on_input:
+ raise NotImplementedError
+
  topk_weights, topk_ids = select_experts(
  hidden_states=x,
  router_logits=router_logits,