sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/token_dispatcher.py

@@ -7,6 +7,7 @@ try:
 except ImportError:
     use_deepep = False
 
+from enum import IntEnum, auto
 from typing import Optional, Tuple
 
 import torch
@@ -19,70 +20,95 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
-_buffer_normal = None
-_buffer_low_latency = None
 
+class DeepEPDispatchMode(IntEnum):
+    NORMAL = auto()
+    LOW_LATENCY = auto()
 
-def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
-    """
-    Copy from DeepEP example usage in model inference prefilling.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
-    """
 
-    global _buffer_normal
+class DeepEPBuffer:
 
-    num_nvl_bytes, num_rdma_bytes = 0, 0
-    for config in (
-        Buffer.get_dispatch_config(group.size()),
-        Buffer.get_combine_config(group.size()),
-    ):
-        num_nvl_bytes = max(
-            config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
-        )
-        num_rdma_bytes = max(
-            config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
-        )
+    _buffer = None
+    _dispatch_mode: Optional[DeepEPDispatchMode] = None
+    _hidden_size: Optional[int] = None
+    _num_max_dispatch_tokens_per_rank: Optional[int] = None
+    _num_experts: Optional[int] = None
 
-    if (
-        _buffer_normal is None
-        or _buffer_normal.group != group
-        or _buffer_normal.num_nvl_bytes < num_nvl_bytes
-        or _buffer_normal.num_rdma_bytes < num_rdma_bytes
-    ):
-        _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
-    return _buffer_normal
-
-
-def _get_buffer_low_latency(
-    group: dist.ProcessGroup,
-    num_max_dispatch_tokens_per_rank: int,
-    hidden: int,
-    num_experts: int,
-):
-    """
-    Copy from DeepEP example usage in model inference decoding.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
-    """
-
-    global _buffer_low_latency
-    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
-        num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
-    )
-
-    if (
-        _buffer_low_latency is None
-        or _buffer_low_latency.group != group
-        or not _buffer_low_latency.low_latency_mode
-        or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
+    @classmethod
+    def get_deepep_buffer(
+        cls,
+        group: dist.ProcessGroup,
+        hidden_size: int,
+        param_bytes: int,
+        deepep_mode: DeepEPMode,
+        num_max_dispatch_tokens_per_rank: int = None,
+        num_experts: int = None,
     ):
-        assert num_experts % group.size() == 0
-        _buffer_low_latency = Buffer(
+        if cls._buffer is not None:
+            return cls._buffer
+
+        cls._hidden_size = hidden_size
+        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+        cls._num_experts = num_experts
+
+        num_nvl_bytes, num_rdma_bytes = 0, 0
+        if deepep_mode.enable_normal():
+            hidden_bytes = hidden_size * param_bytes
+            for config in (
+                Buffer.get_dispatch_config(group.size()),
+                Buffer.get_combine_config(group.size()),
+            ):
+                num_nvl_bytes = max(
+                    config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
+                    num_nvl_bytes,
+                )
+                num_rdma_bytes = max(
+                    config.get_rdma_buffer_size_hint(hidden_bytes, group.size()),
+                    num_rdma_bytes,
+                )
+        if deepep_mode.enable_low_latency():
+            assert num_max_dispatch_tokens_per_rank is not None
+            assert num_experts is not None and num_experts % group.size() == 0
+            num_rdma_bytes = max(
+                Buffer.get_low_latency_rdma_size_hint(
+                    num_max_dispatch_tokens_per_rank,
+                    hidden_size,
+                    group.size(),
+                    num_experts,
+                ),
+                num_rdma_bytes,
+            )
+
+        cls._buffer = Buffer(
             group,
-            num_rdma_bytes=num_rdma_bytes,
-            low_latency_mode=True,
-            num_qps_per_rank=num_experts // group.size(),
+            num_nvl_bytes,
+            num_rdma_bytes,
+            low_latency_mode=deepep_mode.enable_low_latency(),
+            num_qps_per_rank=(
+                num_experts // group.size() if deepep_mode.enable_low_latency() else 1
+            ),
         )
-    return _buffer_low_latency
+        return cls._buffer
+
+    @classmethod
+    def clean_buffer(cls):
+        if not cls._buffer.low_latency_mode:
+            return
+        cls._buffer.clean_low_latency_buffer(
+            cls._num_max_dispatch_tokens_per_rank,
+            cls._hidden_size,
+            cls._num_experts,
+        )
+
+    @classmethod
+    def set_dispatch_mode_as_normal(cls):
+        cls._dispatch_mode = DeepEPDispatchMode.NORMAL
+
+    @classmethod
+    def set_dispatch_mode_as_low_latency(cls):
+        if cls._dispatch_mode == DeepEPDispatchMode.NORMAL:
+            cls.clean_buffer()
+        cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
 
 class _DeepEPDispatcherImplBase:
@@ -95,6 +121,7 @@ class _DeepEPDispatcherImplBase:
         num_local_experts: int,
         hidden_size: int,
         params_dtype: torch.dtype,
+        deepep_mode: DeepEPMode,
     ):
         if not use_deepep:
             raise ImportError(
@@ -109,7 +136,10 @@ class _DeepEPDispatcherImplBase:
         self.num_local_experts = num_local_experts
         self.hidden_size = hidden_size
         self.params_dtype = params_dtype
+        self.deepep_mode = deepep_mode
+
         self.params_bytes = 2
+        self.num_max_dispatch_tokens_per_rank = 128
 
         self.handle = None
 
@@ -118,8 +148,6 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         raise NotImplementedError
 
@@ -137,14 +165,14 @@ class _DeepEPDispatcherImplBase:
     def combine_b(self, *args, **kwargs):
         raise NotImplementedError
 
+    def _get_buffer(self):
+        raise NotImplementedError
+
 
 class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def __init__(self, async_finish: bool, **kwargs):
         super().__init__(**kwargs)
 
-        self.buffer_normal = _get_buffer_normal(
-            self.group, self.hidden_size * self.params_bytes
-        )
         self.async_finish = async_finish
         self.src2dst = None
 
@@ -153,24 +181,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         topk_idx = topk_idx.to(torch.int64)
         previous_event = Buffer.capture() if self.async_finish else None
-        return hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        return hidden_states, topk_idx, topk_weights, previous_event
 
-    def dispatch_b(
-        self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
-    ):
+    def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
         (
             hidden_states,
             topk_idx,
             topk_weights,
             event,
-        ) = self._dispatch_core(
-            hidden_states, topk_idx, topk_weights, num_experts, previous_event
-        )
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
         event.current_stream_wait() if self.async_finish else ()
         if hidden_states.shape[0] > 0:
             reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
@@ -181,7 +203,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
                 (0,), device=hidden_states.device, dtype=torch.int64
             )
             seg_indptr = torch.zeros(
-                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
             )
 
         masked_m = expected_m = None
@@ -201,18 +223,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         x: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
         previous_event,
     ):
+        buffer = self._get_buffer()
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
             num_tokens_per_expert,
            is_token_in_rank,
             previous_event,
-        ) = self.buffer_normal.get_dispatch_layout(
+        ) = buffer.get_dispatch_layout(
             topk_idx,
-            num_experts,
+            self.num_experts,
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=previous_event is not None,
@@ -221,6 +243,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
         # However, doing this would incur an unknown synchronization error, but keeping
         # `handle` as a member variable works.
+
         (
             recv_x,
             recv_topk_idx,
@@ -228,7 +251,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             _,  # num_recv_tokens_per_expert_list
             self.handle,
             event,
-        ) = self.buffer_normal.dispatch(
+        ) = buffer.dispatch(
             x,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
@@ -327,7 +350,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return hidden_states
 
     def _combine_core(self, x: torch.Tensor, previous_event):
-        combined_x, _, event = self.buffer_normal.combine(
+        buffer = self._get_buffer()
+        combined_x, _, event = buffer.combine(
             x,
             self.handle,
             async_finish=self.async_finish,
@@ -336,6 +360,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )
         return combined_x, event
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_normal()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def __init__(self, return_recv_hook: bool, **kwargs):
@@ -345,14 +380,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
         https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
         """
-        # TODO(ch-wan): allow users to set this value
-        self.num_max_dispatch_tokens_per_rank = 128
-        self.buffer_low_latency = _get_buffer_low_latency(
-            self.group,
-            self.num_max_dispatch_tokens_per_rank,
-            self.hidden_size,
-            self.num_experts,
-        )
         self.return_recv_hook = return_recv_hook
 
     def dispatch_a(
@@ -360,21 +387,16 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
+        buffer = self._get_buffer()
         topk_idx = topk_idx.to(torch.int64)
         expected_m = (
-            hidden_states.shape[0]
-            * self.buffer_low_latency.group_size
-            * topk_idx.shape[1]
-            + num_experts
-        ) // num_experts
+            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            + self.num_experts
+        ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
-            num_max_dispatch_tokens_per_rank,
-            num_experts,
             use_fp8=True,
         )
         return (
@@ -415,8 +437,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
-        num_max_dispatch_tokens_per_rank: int,
-        num_experts: int,
         use_fp8: bool = False,
     ):
         """
@@ -451,13 +471,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
 
         const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
         """
-
+        buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
-            self.buffer_low_latency.low_latency_dispatch(
+            buffer.low_latency_dispatch(
                 hidden_states,
                 topk_idx,
-                num_max_dispatch_tokens_per_rank,
-                num_experts,
+                self.num_max_dispatch_tokens_per_rank,
+                self.num_experts,
                 use_fp8=use_fp8,
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
@@ -488,19 +508,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        combined_hidden_states, event, hook = (
-            self.buffer_low_latency.low_latency_combine(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                self.handle,
-                async_finish=not self.return_recv_hook,
-                return_recv_hook=self.return_recv_hook,
-            )
+        buffer = self._get_buffer()
+        combined_hidden_states, event, hook = buffer.low_latency_combine(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            self.handle,
+            async_finish=not self.return_recv_hook,
+            return_recv_hook=self.return_recv_hook,
        )
         self.handle = None
         return combined_hidden_states, event, hook
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_low_latency()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class DeepEPDispatcher:
     def __init__(
@@ -526,18 +556,19 @@ class DeepEPDispatcher:
             num_local_experts=num_local_experts,
             hidden_size=hidden_size,
             params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
         )
 
-        if self.deepep_mode.enable_normal():
-            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
-                async_finish=async_finish,
-                **common_kwargs,
-            )
         if self.deepep_mode.enable_low_latency():
             self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
                 return_recv_hook=return_recv_hook,
                 **common_kwargs,
             )
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
+                async_finish=async_finish,
+                **common_kwargs,
+            )
 
     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
@@ -548,16 +579,12 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int = 128,
         forward_mode: ForwardMode = None,
     ):
         inner_state = self._get_impl(forward_mode).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            num_experts=num_experts,
-            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
         )
         self._dispatch_intermediate_state = forward_mode, inner_state
 
@@ -589,7 +616,7 @@ class DeepEPDispatcher:
         del self._combine_intermediate_state
         return self._get_impl(forward_mode).combine_b(*inner_state)
 
-    def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
+    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
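
For orientation, a minimal usage sketch of the refactored buffer management above (not part of the diff): the two per-mode module-level buffers are replaced by a single DeepEPBuffer singleton that is sized once for whichever modes are enabled and then flipped between dispatch modes. The DeepEPMode import path, enum member, and the concrete sizes below are assumptions for illustration; running it requires DeepEP and an initialized torch.distributed process group.

import torch.distributed as dist

from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPBuffer
from sglang.srt.utils import DeepEPMode  # assumed import path

# Size the shared buffer once, covering both normal (prefill) and
# low-latency (decode) dispatch if the mode enables them.
buffer = DeepEPBuffer.get_deepep_buffer(
    dist.group.WORLD,
    hidden_size=7168,                      # example value
    param_bytes=2,                         # matches self.params_bytes above
    deepep_mode=DeepEPMode.auto,           # assumed enum member
    num_max_dispatch_tokens_per_rank=128,  # matches the default above
    num_experts=256,                       # example value
)

# The dispatcher impls flip the shared buffer between modes; switching from
# NORMAL to LOW_LATENCY triggers clean_buffer() first, as shown in the diff.
DeepEPBuffer.set_dispatch_mode_as_normal()
DeepEPBuffer.set_dispatch_mode_as_low_latency()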
sglang/srt/layers/moe/fused_moe_native.py

@@ -26,6 +26,7 @@ def fused_moe_forward_native(
     apply_router_weight_on_input: bool = False,
     inplace: bool = True,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
 
     if apply_router_weight_on_input:
@@ -41,6 +42,7 @@
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
+        routed_scaling_factor=routed_scaling_factor,
         torch_native=True,
     )
 
@@ -71,6 +73,7 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
 
     from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
@@ -86,6 +89,7 @@
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
         torch_native=True,
+        routed_scaling_factor=routed_scaling_factor,
     )
 
     # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
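
The routed_scaling_factor threaded through above is a per-model constant (e.g. from DeepSeek-style configs) applied when combining routed-expert outputs. As a rough illustration only, not the sglang implementation, the effect is:

from typing import Optional

import torch

def combine_routed(
    expert_outputs: torch.Tensor,    # [num_tokens, top_k, hidden]
    topk_weights: torch.Tensor,      # [num_tokens, top_k]
    routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
    # Weighted sum over the selected experts, optionally scaled by the
    # model's routed scaling factor.
    out = (expert_outputs * topk_weights.unsqueeze(-1)).sum(dim=1)
    if routed_scaling_factor is not None:
        out = out * routed_scaling_factor
    return out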
New fused MoE Triton tuning config (one of the configs/*.json files added above)

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
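
The new tuning file follows the same schema as the other fused MoE Triton configs: each top-level key is a token count and the value is the Triton launch configuration tuned for that size. A minimal sketch of how such a file is typically consumed (illustrative only, not sglang's exact lookup code; the function name is hypothetical):

import json

def pick_fused_moe_config(config_path: str, num_tokens: int) -> dict:
    with open(config_path) as f:
        configs = json.load(f)
    # Keys are strings ("1" ... "4096"); pick the tuned entry closest to
    # the runtime token count.
    best_key = min(configs, key=lambda k: abs(int(k) - num_tokens))
    return configs[best_key]

# Example: a batch of 24 tokens maps to the "24" entry above
# (BLOCK_SIZE_M=16, BLOCK_SIZE_N=128, BLOCK_SIZE_K=64, GROUP_SIZE_M=32).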