sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +72 -10
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +6 -16
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +582 -125
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +79 -6
- sglang/srt/layers/quantization/__init__.py +137 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +44 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -127
- sglang/srt/managers/scheduler.py +29 -23
- sglang/srt/managers/tokenizer_manager.py +1 -2
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +16 -13
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +64 -59
- sglang/srt/model_loader/loader.py +19 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +568 -0
- sglang/srt/models/deepseek_janus_pro.py +12 -17
- sglang/srt/models/deepseek_v2.py +339 -123
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +20 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +106 -93
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +120 -25
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +94 -25
- sglang/srt/utils.py +137 -51
- sglang/test/runners.py +27 -2
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +14 -27
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/token_dispatcher.py

@@ -1,3 +1,5 @@
+from sglang.srt.utils import DeepEPMode
+
 try:
     from deep_ep import Buffer

@@ -21,7 +23,7 @@ _buffer_normal = None
 _buffer_low_latency = None


-def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
+def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
     """
     Copy from DeepEP example usage in model inference prefilling.
     https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
@@ -51,7 +53,7 @@ def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
     return _buffer_normal


-def get_buffer_low_latency(
+def _get_buffer_low_latency(
     group: dist.ProcessGroup,
     num_max_dispatch_tokens_per_rank: int,
     hidden: int,
@@ -76,151 +78,103 @@ def get_buffer_low_latency(
         assert num_experts % group.size() == 0
         _buffer_low_latency = Buffer(
             group,
-
-            num_rdma_bytes,
+            num_rdma_bytes=num_rdma_bytes,
             low_latency_mode=True,
             num_qps_per_rank=num_experts // group.size(),
         )
     return _buffer_low_latency


-class
-    """
-    Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
-    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
-    """
-
+class _DeepEPDispatcherImplBase:
     def __init__(
         self,
         group: torch.distributed.ProcessGroup,
         router_topk: int,
-        permute_fusion: bool
-
-
-
-
-        params_dtype: torch.dtype = None,
-        async_finish: bool = False,
+        permute_fusion: bool,
+        num_experts: int,
+        num_local_experts: int,
+        hidden_size: int,
+        params_dtype: torch.dtype,
     ):
+        if not use_deepep:
+            raise ImportError(
+                "DeepEP is not installed. Please install DeepEP package from "
+                "https://github.com/deepseek-ai/deepep."
+            )
+
         self.group = group
         self.router_topk = router_topk
-        self.capacity_factor = capacity_factor
         self.permute_fusion = permute_fusion
         self.num_experts = num_experts
         self.num_local_experts = num_local_experts
         self.hidden_size = hidden_size
-        self.recv_expert_count = None
         self.params_dtype = params_dtype
         self.params_bytes = 2
-
-        self.token_indices = None
-        self.token_probs = None
-        # Handle used for combine operation
+
         self.handle = None
-        self.async_finish = async_finish

-
-
-
+    def dispatch_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
+    ):
+        raise NotImplementedError

-
-
-                "DeepEP is not installed. Please install DeepEP package from "
-                "https://github.com/deepseek-ai/deepep."
-            )
-        self.buffer_normal = get_buffer_normal(
-            self.group, self.hidden_size * self.params_bytes
-        )
-        self.buffer_low_latency = None
-        # Todo: enable low latency dispatch
-        """
-        self.buffer_low_latency = get_buffer_low_latency(
-            self.group,
-            self.num_max_dispatch_tokens_per_rank,
-            self.hidden_size * self.params_bytes,
-            self.num_experts,
-        )
-        """
+    def dispatch_b(self, *args, **kwargs):
+        raise NotImplementedError

-    def
+    def combine_a(
         self,
-        hidden_states,
-
-
-        use_block_quant=False,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
     ):
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-        # PreReorder
-        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            src2dst,
-            self.topk_idx,
-            None,
-            self.router_topk,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
+        raise NotImplementedError
+
+    def combine_b(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
+    def __init__(self, async_finish: bool, **kwargs):
+        super().__init__(**kwargs)
+
+        self.buffer_normal = _get_buffer_normal(
+            self.group, self.hidden_size * self.params_bytes
         )
-        self.
-
+        self.async_finish = async_finish
+        self.src2dst = None

-    def
+    def dispatch_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         num_experts: int,
-
-
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        num_max_dispatch_tokens_per_rank: int,
+    ):
         topk_idx = topk_idx.to(torch.int64)
-
-
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                num_recv_tokens_per_expert_list,
-                handle,
-                event,
-            ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
-            self.tokens_per_expert = torch.tensor(
-                num_recv_tokens_per_expert_list,
-                device=hidden_states.device,
-                dtype=torch.int64,
-            )
-        else:
-            hidden_states, recv_expert_count, handle, event, hook = (
-                self.dispatch_low_latency(
-                    hidden_states,
-                    topk_idx,
-                    num_max_dispatch_tokens_per_rank,
-                    num_experts,
-                )
-            )
-            self.recv_expert_count = recv_expert_count
-
-        if self.async_finish:
-            event.current_stream_wait()
+        previous_event = Buffer.capture() if self.async_finish else None
+        return hidden_states, topk_idx, topk_weights, num_experts, previous_event

-
-        self
-
+    def dispatch_b(
+        self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
+    ):
+        (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            event,
+        ) = self._dispatch_core(
+            hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        )
+        event.current_stream_wait() if self.async_finish else ()
         if hidden_states.shape[0] > 0:
-            reorder_topk_ids, seg_indptr, hidden_states = self.
-                hidden_states, fp8_dtype=hidden_states.dtype
+            reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
+                hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
             )
         else:
             reorder_topk_ids = torch.empty(
@@ -229,17 +183,27 @@ class DeepEPDispatcher:
             seg_indptr = torch.zeros(
                 (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
             )
-        return hidden_states, reorder_topk_ids, seg_indptr

-
+        masked_m = expected_m = None
+
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            reorder_topk_ids,
+            seg_indptr,
+            masked_m,
+            expected_m,
+        )
+
+    def _dispatch_core(
         self,
         x: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         num_experts: int,
+        previous_event,
     ):
-        previous_event = Buffer.capture() if self.async_finish else None
-
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
@@ -254,12 +218,15 @@ class DeepEPDispatcher:
             allocate_on_comm_stream=previous_event is not None,
         )

+        # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
+        # However, doing this would incur an unknown synchronization error, but keeping
+        # `handle` as a member variable works.
         (
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
-            handle,
+            _, # num_recv_tokens_per_expert_list
+            self.handle,
             event,
         ) = self.buffer_normal.dispatch(
             x,
@@ -278,29 +245,191 @@ class DeepEPDispatcher:
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
-            handle,
             event,
         )

-    def
+    def _deepep_permute(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        fp8_dtype: Optional[torch.dtype] = None,
+        use_fp8_w8a8: bool = False,
+        use_block_quant: bool = False,
+    ):
+        """
+        Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
+        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
+        """
+
+        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+            topk_idx, self.num_experts
+        )
+        num_total_tokens = reorder_topk_ids.numel()
+        gateup_input = torch.empty(
+            (int(num_total_tokens), hidden_states.shape[1]),
+            device=hidden_states.device,
+            dtype=(
+                fp8_dtype
+                if (use_fp8_w8a8 and not use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        # PreReorder
+        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+            hidden_states,
+            gateup_input,
+            self.src2dst,
+            topk_idx,
+            None,
+            self.router_topk,
+            hidden_states.shape[1],
+            BLOCK_SIZE=512,
+        )
+        return reorder_topk_ids, seg_indptr, gateup_input
+
+    def combine_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        if hidden_states.shape[0] > 0:
+            num_tokens = self.src2dst.shape[0] // self.router_topk
+            output = torch.empty(
+                (num_tokens, hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            deepep_post_reorder_triton_kernel[(num_tokens,)](
+                hidden_states,
+                output,
+                self.src2dst,
+                topk_idx,
+                topk_weights,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+        else:
+            output = torch.zeros(
+                (0, hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+        previous_event = Buffer.capture() if self.async_finish else None
+        return output, previous_event
+
+    def combine_b(self, output, previous_event):
+        hidden_states, event = self._combine_core(output, previous_event)
+        event.current_stream_wait() if self.async_finish else ()
+        self.handle = None
+        self.src2dst = None
+        return hidden_states
+
+    def _combine_core(self, x: torch.Tensor, previous_event):
+        combined_x, _, event = self.buffer_normal.combine(
+            x,
+            self.handle,
+            async_finish=self.async_finish,
+            previous_event=previous_event,
+            allocate_on_comm_stream=previous_event is not None,
+        )
+        return combined_x, event
+
+
+class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
+    def __init__(self, return_recv_hook: bool, **kwargs):
+        super().__init__(**kwargs)
+
+        """
+        num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
+        https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
+        """
+        # TODO(ch-wan): allow users to set this value
+        self.num_max_dispatch_tokens_per_rank = 128
+        self.buffer_low_latency = _get_buffer_low_latency(
+            self.group,
+            self.num_max_dispatch_tokens_per_rank,
+            self.hidden_size,
+            self.num_experts,
+        )
+        self.return_recv_hook = return_recv_hook
+
+    def dispatch_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
+    ):
+        topk_idx = topk_idx.to(torch.int64)
+        expected_m = (
+            hidden_states.shape[0]
+            * self.buffer_low_latency.group_size
+            * topk_idx.shape[1]
+            + num_experts
+        ) // num_experts
+        hidden_states, masked_m, event, hook = self._dispatch_core(
+            hidden_states,
+            topk_idx,
+            num_max_dispatch_tokens_per_rank,
+            num_experts,
+            use_fp8=True,
+        )
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            masked_m,
+            expected_m,
+            event,
+            hook,
+        )
+
+    def dispatch_b(
+        self,
+        hidden_states,
+        topk_idx,
+        topk_weights,
+        masked_m,
+        expected_m,
+        event,
+        hook,
+    ):
+        hook() if self.return_recv_hook else event.current_stream_wait()
+
+        reorder_topk_ids = seg_indptr = None
+
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            reorder_topk_ids,
+            seg_indptr,
+            masked_m,
+            expected_m,
+        )
+
+    def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         num_max_dispatch_tokens_per_rank: int,
         num_experts: int,
+        use_fp8: bool = False,
     ):
         """
-        # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'
-        # Please
+        # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'.
+        # Please make sure to change DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall.
         # More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
-
+
         diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
-        index
+        index 76ae2e2..8ecd08f 100644
         --- a/csrc/kernels/internode_ll.cu
         +++ b/csrc/kernels/internode_ll.cu
-        @@ -
-        int num_topk, int num_experts, int rank, int num_ranks,
+        @@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
+        int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
         void* workspace, cudaStream_t stream, int phases) {
         constexpr int kNumMaxTopK = 9;
         - constexpr int kNumWarpsPerGroup = 10;
@@ -308,16 +437,9 @@ class DeepEPDispatcher:
         + constexpr int kNumWarpsPerGroup = 8;
         + constexpr int kNumWarpGroups = 4;
         EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
-
+
         const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
-
-        EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
-        - EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
-        + // EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
-        +
-        // Workspace checks
-        auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
-        @@ -505,8 +505,8 @@ void combine(void* combined_x,
+        @@ -501,8 +501,8 @@ void combine(void* combined_x,
         int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
         int num_topk, int num_experts, int rank, int num_ranks,
         void* workspace, cudaStream_t stream, int phases) {
@@ -326,91 +448,152 @@ class DeepEPDispatcher:
         + constexpr int kNumWarpsPerGroup = 8;
         + constexpr int kNumWarpGroups = 4;
         constexpr int kNumMaxTopk = 9;
-
+
         const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
         """

-
+        packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
             self.buffer_low_latency.low_latency_dispatch(
                 hidden_states,
                 topk_idx,
                 num_max_dispatch_tokens_per_rank,
                 num_experts,
-
-
+                use_fp8=use_fp8,
+                async_finish=not self.return_recv_hook,
+                return_recv_hook=self.return_recv_hook,
             )
         )
-
-        return recv_hidden_states, recv_expert_count, handle, event, hook
-
-    def combine(
-        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        # Todo: enable low latency combine
-        if True: # not forward_mode.is_decode():
-            if hidden_states.shape[0] > 0:
-                num_tokens = self.src2dst.shape[0] // self.router_topk
-                output = torch.empty(
-                    (num_tokens, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-                deepep_post_reorder_triton_kernel[(num_tokens,)](
-                    hidden_states,
-                    output,
-                    self.src2dst,
-                    self.topk_idx,
-                    self.topk_weights,
-                    self.router_topk,
-                    hidden_states.shape[1],
-                    BLOCK_SIZE=512,
-                )
-            else:
-                output = torch.zeros(
-                    (0, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-            hidden_states, event = self.combine_normal(output, self.handle)
-        else:
-            hidden_states, event, hook = self.combine_low_latency(
-                hidden_states, self.topk_idx, self.topk_weights, self.handle
-            )
+        return packed_recv_hidden, packed_recv_count, event, hook

-
-
+    def combine_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        hidden_states, event, hook = self._combine_core(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+        )
+        return hidden_states, event, hook

-
+    def combine_b(self, hidden_states, event, hook):
+        hook() if self.return_recv_hook else event.current_stream_wait()
         return hidden_states

-    def
-        previous_event = Buffer.capture() if self.async_finish else None
-
-        combined_x, _, event = self.buffer_normal.combine(
-            x,
-            handle,
-            async_finish=self.async_finish,
-            previous_event=previous_event,
-            allocate_on_comm_stream=previous_event is not None,
-        )
-        return combined_x, event
-
-    def combine_low_latency(
+    def _combine_core(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        handle: Tuple,
     ):
-        combined_hidden_states,
+        combined_hidden_states, event, hook = (
             self.buffer_low_latency.low_latency_combine(
                 hidden_states,
                 topk_idx,
                 topk_weights,
-                handle,
-                async_finish=self.
-                return_recv_hook=
+                self.handle,
+                async_finish=not self.return_recv_hook,
+                return_recv_hook=self.return_recv_hook,
             )
         )
-
-        return combined_hidden_states,
+        self.handle = None
+        return combined_hidden_states, event, hook
+
+
+class DeepEPDispatcher:
+    def __init__(
+        self,
+        group: torch.distributed.ProcessGroup,
+        router_topk: int,
+        permute_fusion: bool = False,
+        num_experts: int = None,
+        num_local_experts: int = None,
+        hidden_size: int = None,
+        params_dtype: torch.dtype = None,
+        deepep_mode: DeepEPMode = DeepEPMode.auto,
+        async_finish: bool = False,
+        return_recv_hook: bool = False,
+    ):
+        self.deepep_mode = deepep_mode
+
+        common_kwargs = dict(
+            group=group,
+            router_topk=router_topk,
+            permute_fusion=permute_fusion,
+            num_experts=num_experts,
+            num_local_experts=num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+        )
+
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
+                async_finish=async_finish,
+                **common_kwargs,
+            )
+        if self.deepep_mode.enable_low_latency():
+            self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
+                return_recv_hook=return_recv_hook,
+                **common_kwargs,
+            )
+
+    def dispatch(self, *args, **kwargs) -> Tuple:
+        self.dispatch_a(*args, **kwargs)
+        return self.dispatch_b()
+
+    def dispatch_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int = 128,
+        forward_mode: ForwardMode = None,
+    ):
+        inner_state = self._get_impl(forward_mode).dispatch_a(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            num_experts=num_experts,
+            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
+        )
+        self._dispatch_intermediate_state = forward_mode, inner_state
+
+    def dispatch_b(self):
+        forward_mode, inner_state = self._dispatch_intermediate_state
+        del self._dispatch_intermediate_state
+        return self._get_impl(forward_mode).dispatch_b(*inner_state)
+
+    def combine(self, *args, **kwargs) -> Tuple:
+        self.combine_a(*args, **kwargs)
+        return self.combine_b()
+
+    def combine_a(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_mode: ForwardMode,
+    ):
+        inner_state = self._get_impl(forward_mode).combine_a(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+        )
+        self._combine_intermediate_state = forward_mode, inner_state
+
+    def combine_b(self):
+        forward_mode, inner_state = self._combine_intermediate_state
+        del self._combine_intermediate_state
+        return self._get_impl(forward_mode).combine_b(*inner_state)
+
+    def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
+        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        if resolved_deepep_mode == DeepEPMode.normal:
+            return self._normal_dispatcher
+        elif resolved_deepep_mode == DeepEPMode.low_latency:
+            return self._low_latency_dispatcher
+        else:
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
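The diff above reworks `DeepEPDispatcher` so that `dispatch`/`combine` are split into `*_a` and `*_b` phases, with per-mode implementation objects selected via `_get_impl`. As a rough illustration of that split-phase pattern only (simplified stand-in classes, not the sglang API; no DeepEP, torch.distributed, or CUDA dependencies), a minimal sketch:

```python
# Illustrative sketch of the two-phase dispatch pattern introduced in the diff:
# phase A starts the (potentially asynchronous) step and stashes intermediate
# state; phase B consumes that state and finishes the step. All names and
# bodies here are hypothetical placeholders.

class _ImplBase:
    def dispatch_a(self, tokens):
        raise NotImplementedError

    def dispatch_b(self, *inner_state):
        raise NotImplementedError


class _ImplNormal(_ImplBase):
    def dispatch_a(self, tokens):
        # Phase A: kick off communication and return whatever must be waited on.
        return (tokens, "normal-event")

    def dispatch_b(self, tokens, event):
        # Phase B: wait on the event (omitted here) and post-process the tokens.
        return [t * 2 for t in tokens]


class Dispatcher:
    """Front end that stores the intermediate state between the two phases."""

    def __init__(self):
        self._impl = _ImplNormal()

    def dispatch_a(self, tokens):
        self._dispatch_intermediate_state = self._impl.dispatch_a(tokens)

    def dispatch_b(self):
        inner_state = self._dispatch_intermediate_state
        del self._dispatch_intermediate_state
        return self._impl.dispatch_b(*inner_state)

    def dispatch(self, tokens):
        # The one-shot form is simply phase A followed immediately by phase B.
        self.dispatch_a(tokens)
        return self.dispatch_b()


if __name__ == "__main__":
    d = Dispatcher()
    d.dispatch_a([1, 2, 3])  # other work could overlap with communication here
    print(d.dispatch_b())    # -> [2, 4, 6]
```

Splitting each collective into an A/B pair lets a caller overlap independent work between the two calls, which is presumably why the new `DeepEPDispatcher` keeps `dispatch()`/`combine()` as thin wrappers that call the A and B halves back to back.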