sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -391,6 +393,8 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -407,7 +411,7 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v,
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )
 
             o = prefill_wrapper_paged.forward(
@@ -417,8 +421,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=
-                v_scale=
+                k_scale=k_scale,
+                v_scale=v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -445,7 +449,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v,
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -459,6 +463,8 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -472,7 +478,7 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v,
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )
 
         o = decode_wrapper.forward(
@@ -480,8 +486,8 @@ class FlashInferAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=
-            v_scale=
+            k_scale=k_scale,
+            v_scale=v_scale,
        )
 
        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
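Note: the hunks above thread optional KV-cache dequantization scales through the FlashInfer prefill and decode paths, but only when a non-default kv-cache dtype is configured. A minimal standalone sketch of that selection logic; the _Layer stand-in and the "fp8_e5m2" value below are illustrative assumptions, not taken from the diff:

    class _Layer:  # stand-in for the attention layer object (illustrative only)
        k_scale_float = 0.023
        v_scale_float = 0.031

    def resolve_kv_scales(layer, kv_cache_dtype_str):
        # "auto" means the KV cache keeps the model dtype, so no scales are needed.
        if kv_cache_dtype_str == "auto":
            return None, None
        # A quantized KV cache (e.g. an fp8 setting): forward the calibrated
        # per-layer scales to set_kv_buffer() and to the attention kernel.
        return layer.k_scale_float, layer.v_scale_float

    print(resolve_kv_scales(_Layer(), "auto"))      # (None, None)
    print(resolve_kv_scales(_Layer(), "fp8_e5m2"))  # (0.023, 0.031)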
sglang/srt/layers/dp_attention.py
CHANGED
@@ -192,8 +192,7 @@ def _dp_gather(
 
     if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
         assert (
-            global_tokens.untyped_storage().data_ptr()
-            != local_tokens.untyped_storage().data_ptr()
+            local_tokens.untyped_storage() is not global_tokens.untyped_storage()
         ), "aliasing between global_tokens and local_tokens not allowed"
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -243,8 +242,7 @@ def dp_scatter(
     assert global_tokens.is_contiguous()
     if local_tokens.shape[0] > 0:
         assert (
-            local_tokens.untyped_storage().data_ptr()
-            != global_tokens.untyped_storage().data_ptr()
+            local_tokens.untyped_storage() is not global_tokens.untyped_storage()
         ), "aliasing between local_tokens and global_tokens not allowed"
         memcpy_triton(
             local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
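Note: both hunks above rewrite the aliasing assertion in terms of the storage objects themselves rather than comparing data_ptr() values. A small standalone PyTorch snippet (not part of the diff) showing the aliasing situation the assertion guards against before calling memcpy_triton:

    import torch

    base = torch.arange(8, dtype=torch.float32)
    view = base[2:6]        # shares base's allocation
    fresh = torch.empty(4)  # separately allocated

    def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
        # Tensor.data_ptr() differs for an offset view even though it aliases,
        # so the overlap test is done on the underlying storage instead.
        return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

    print(view.data_ptr() == base.data_ptr())  # False, despite aliasing
    print(shares_storage(view, base))          # True  -> overlapping copy is unsafe
    print(shares_storage(fresh, base))         # False -> safe to copy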
sglang/srt/layers/elementwise.py
CHANGED
@@ -4,6 +4,10 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -185,6 +189,9 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
@@ -193,7 +200,10 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(
+                min(
+                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
+                ),
+                4,
             ),
         }
 
@@ -250,10 +260,13 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)),
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
        ),
    }

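Note: both changed functions now clamp the Triton warp-count heuristic at 16 on HIP instead of 32, presumably because an AMD wavefront is typically 64 lanes wide rather than 32. A pure-Python sketch of the heuristic as it reads after the change (no Triton needed; the helpers mirror triton.cdiv and triton.next_power_of_2):

    def next_power_of_2(n: int) -> int:
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    def cdiv(a: int, b: int) -> int:
        return -(-a // b)

    def pick_num_warps(hidden_dim: int, is_hip: bool) -> int:
        # Same formula as the diff: roughly one warp per 256 elements, rounded up
        # to a power of two, clamped to [4, 16] on HIP and [4, 32] otherwise.
        min_num_warps = 16 if is_hip else 32
        return max(min(next_power_of_2(cdiv(hidden_dim, 256)), min_num_warps), 4)

    print(pick_num_warps(4096, is_hip=False))   # 16
    print(pick_num_warps(16384, is_hip=False))  # 32
    print(pick_num_warps(16384, is_hip=True))   # 16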
sglang/srt/layers/linear.py
CHANGED
sglang/srt/layers/moe/ep_moe/token_dispatcher.py
CHANGED
@@ -7,6 +7,7 @@ try:
 except ImportError:
     use_deepep = False
 
+from enum import IntEnum, auto
 from typing import Optional, Tuple
 
 import torch
@@ -19,70 +20,95 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
-_buffer_normal = None
-_buffer_low_latency = None
 
+class DeepEPDispatchMode(IntEnum):
+    NORMAL = auto()
+    LOW_LATENCY = auto()
 
-def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
-    """
-    Copy from DeepEP example usage in model inference prefilling.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
-    """
 
-
+class DeepEPBuffer:
 
-
-
-
-
-
-        num_nvl_bytes = max(
-            config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
-        )
-        num_rdma_bytes = max(
-            config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
-        )
+    _buffer = None
+    _dispatch_mode: Optional[DeepEPDispatchMode] = None
+    _hidden_size: Optional[int] = None
+    _num_max_dispatch_tokens_per_rank: Optional[int] = None
+    _num_experts: Optional[int] = None
 
-
-
-
-
-
-
-
-
-
-
-def _get_buffer_low_latency(
-    group: dist.ProcessGroup,
-    num_max_dispatch_tokens_per_rank: int,
-    hidden: int,
-    num_experts: int,
-):
-    """
-    Copy from DeepEP example usage in model inference decoding.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
-    """
-
-    global _buffer_low_latency
-    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
-        num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
-    )
-
-    if (
-        _buffer_low_latency is None
-        or _buffer_low_latency.group != group
-        or not _buffer_low_latency.low_latency_mode
-        or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
+    @classmethod
+    def get_deepep_buffer(
+        cls,
+        group: dist.ProcessGroup,
+        hidden_size: int,
+        param_bytes: int,
+        deepep_mode: DeepEPMode,
+        num_max_dispatch_tokens_per_rank: int = None,
+        num_experts: int = None,
     ):
-
-
+        if cls._buffer is not None:
+            return cls._buffer
+
+        cls._hidden_size = hidden_size
+        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+        cls._num_experts = num_experts
+
+        num_nvl_bytes, num_rdma_bytes = 0, 0
+        if deepep_mode.enable_normal():
+            hidden_bytes = hidden_size * param_bytes
+            for config in (
+                Buffer.get_dispatch_config(group.size()),
+                Buffer.get_combine_config(group.size()),
+            ):
+                num_nvl_bytes = max(
+                    config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
+                    num_nvl_bytes,
+                )
+                num_rdma_bytes = max(
+                    config.get_rdma_buffer_size_hint(hidden_bytes, group.size()),
+                    num_rdma_bytes,
+                )
+        if deepep_mode.enable_low_latency():
+            assert num_max_dispatch_tokens_per_rank is not None
+            assert num_experts is not None and num_experts % group.size() == 0
+            num_rdma_bytes = max(
+                Buffer.get_low_latency_rdma_size_hint(
+                    num_max_dispatch_tokens_per_rank,
+                    hidden_size,
+                    group.size(),
+                    num_experts,
+                ),
+                num_rdma_bytes,
+            )
+
+        cls._buffer = Buffer(
             group,
-
-
-
+            num_nvl_bytes,
+            num_rdma_bytes,
+            low_latency_mode=deepep_mode.enable_low_latency(),
+            num_qps_per_rank=(
+                num_experts // group.size() if deepep_mode.enable_low_latency() else 1
+            ),
         )
-
+        return cls._buffer
+
+    @classmethod
+    def clean_buffer(cls):
+        if not cls._buffer.low_latency_mode:
+            return
+        cls._buffer.clean_low_latency_buffer(
+            cls._num_max_dispatch_tokens_per_rank,
+            cls._hidden_size,
+            cls._num_experts,
+        )
+
+    @classmethod
+    def set_dispatch_mode_as_normal(cls):
+        cls._dispatch_mode = DeepEPDispatchMode.NORMAL
+
+    @classmethod
+    def set_dispatch_mode_as_low_latency(cls):
+        if cls._dispatch_mode == DeepEPDispatchMode.NORMAL:
+            cls.clean_buffer()
+        cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
 
 class _DeepEPDispatcherImplBase:
@@ -95,6 +121,7 @@ class _DeepEPDispatcherImplBase:
         num_local_experts: int,
         hidden_size: int,
         params_dtype: torch.dtype,
+        deepep_mode: DeepEPMode,
     ):
         if not use_deepep:
             raise ImportError(
@@ -109,7 +136,10 @@ class _DeepEPDispatcherImplBase:
         self.num_local_experts = num_local_experts
         self.hidden_size = hidden_size
         self.params_dtype = params_dtype
+        self.deepep_mode = deepep_mode
+
         self.params_bytes = 2
+        self.num_max_dispatch_tokens_per_rank = 128
 
         self.handle = None
 
@@ -118,8 +148,6 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         raise NotImplementedError
 
@@ -137,14 +165,14 @@ class _DeepEPDispatcherImplBase:
     def combine_b(self, *args, **kwargs):
         raise NotImplementedError
 
+    def _get_buffer(self):
+        raise NotImplementedError
+
 
 class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def __init__(self, async_finish: bool, **kwargs):
         super().__init__(**kwargs)
 
-        self.buffer_normal = _get_buffer_normal(
-            self.group, self.hidden_size * self.params_bytes
-        )
         self.async_finish = async_finish
         self.src2dst = None
 
@@ -153,24 +181,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         topk_idx = topk_idx.to(torch.int64)
         previous_event = Buffer.capture() if self.async_finish else None
-        return hidden_states, topk_idx, topk_weights,
+        return hidden_states, topk_idx, topk_weights, previous_event
 
-    def dispatch_b(
-        self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
-    ):
+    def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
         (
             hidden_states,
             topk_idx,
             topk_weights,
             event,
-        ) = self._dispatch_core(
-            hidden_states, topk_idx, topk_weights, num_experts, previous_event
-        )
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
         event.current_stream_wait() if self.async_finish else ()
         if hidden_states.shape[0] > 0:
             reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
@@ -181,7 +203,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
                 (0,), device=hidden_states.device, dtype=torch.int64
             )
             seg_indptr = torch.zeros(
-                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
             )
 
         masked_m = expected_m = None
@@ -201,18 +223,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         x: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
         previous_event,
     ):
+        buffer = self._get_buffer()
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
             num_tokens_per_expert,
             is_token_in_rank,
             previous_event,
-        ) =
+        ) = buffer.get_dispatch_layout(
             topk_idx,
-            num_experts,
+            self.num_experts,
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=previous_event is not None,
@@ -221,6 +243,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
         # However, doing this would incur an unknown synchronization error, but keeping
         # `handle` as a member variable works.
+
         (
             recv_x,
             recv_topk_idx,
@@ -228,7 +251,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
            _,  # num_recv_tokens_per_expert_list
            self.handle,
            event,
-        ) =
+        ) = buffer.dispatch(
            x,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
@@ -327,7 +350,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
        return hidden_states
 
    def _combine_core(self, x: torch.Tensor, previous_event):
-
+        buffer = self._get_buffer()
+        combined_x, _, event = buffer.combine(
            x,
            self.handle,
            async_finish=self.async_finish,
@@ -336,6 +360,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
        )
        return combined_x, event
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_normal()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
    def __init__(self, return_recv_hook: bool, **kwargs):
@@ -345,14 +380,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
        https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
        """
-        # TODO(ch-wan): allow users to set this value
-        self.num_max_dispatch_tokens_per_rank = 128
-        self.buffer_low_latency = _get_buffer_low_latency(
-            self.group,
-            self.num_max_dispatch_tokens_per_rank,
-            self.hidden_size,
-            self.num_experts,
-        )
        self.return_recv_hook = return_recv_hook
 
    def dispatch_a(
@@ -360,21 +387,16 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
    ):
+        buffer = self._get_buffer()
        topk_idx = topk_idx.to(torch.int64)
        expected_m = (
-            hidden_states.shape[0]
-
-
-            + num_experts
-        ) // num_experts
+            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            + self.num_experts
+        ) // self.num_experts
        hidden_states, masked_m, event, hook = self._dispatch_core(
            hidden_states,
            topk_idx,
-            num_max_dispatch_tokens_per_rank,
-            num_experts,
            use_fp8=True,
        )
        return (
@@ -415,8 +437,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
-        num_max_dispatch_tokens_per_rank: int,
-        num_experts: int,
        use_fp8: bool = False,
    ):
        """
@@ -451,13 +471,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
 
        const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
        """
-
+        buffer = self._get_buffer()
        packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
-
+            buffer.low_latency_dispatch(
                hidden_states,
                topk_idx,
-                num_max_dispatch_tokens_per_rank,
-                num_experts,
+                self.num_max_dispatch_tokens_per_rank,
+                self.num_experts,
                use_fp8=use_fp8,
                async_finish=not self.return_recv_hook,
                return_recv_hook=self.return_recv_hook,
@@ -488,19 +508,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
    ):
-
-
-
-
-
-
-
-
-        )
+        buffer = self._get_buffer()
+        combined_hidden_states, event, hook = buffer.low_latency_combine(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            self.handle,
+            async_finish=not self.return_recv_hook,
+            return_recv_hook=self.return_recv_hook,
        )
        self.handle = None
        return combined_hidden_states, event, hook
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_low_latency()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class DeepEPDispatcher:
    def __init__(
@@ -526,18 +556,19 @@ class DeepEPDispatcher:
            num_local_experts=num_local_experts,
            hidden_size=hidden_size,
            params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
        )
 
-        if self.deepep_mode.enable_normal():
-            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
-                async_finish=async_finish,
-                **common_kwargs,
-            )
        if self.deepep_mode.enable_low_latency():
            self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
                return_recv_hook=return_recv_hook,
                **common_kwargs,
            )
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
+                async_finish=async_finish,
+                **common_kwargs,
+            )
 
    def dispatch(self, *args, **kwargs) -> Tuple:
        self.dispatch_a(*args, **kwargs)
@@ -548,16 +579,12 @@ class DeepEPDispatcher:
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int = 128,
        forward_mode: ForwardMode = None,
    ):
        inner_state = self._get_impl(forward_mode).dispatch_a(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
-            num_experts=num_experts,
-            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
        )
        self._dispatch_intermediate_state = forward_mode, inner_state
 
@@ -589,7 +616,7 @@ class DeepEPDispatcher:
        del self._combine_intermediate_state
        return self._get_impl(forward_mode).combine_b(*inner_state)
 
-    def _get_impl(self, forward_mode: ForwardMode) ->
+    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
        if resolved_deepep_mode == DeepEPMode.normal:
            return self._normal_dispatcher
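Note: the token_dispatcher rework above replaces two module-level buffer globals with a DeepEPBuffer class that lazily builds one shared deep_ep.Buffer, sized for both dispatch modes, and switches it between normal and low-latency dispatch, cleaning the low-latency state on a mode change. A minimal standalone sketch of that lazy-singleton-with-mode-switch pattern (illustrative only; it does not call the real deep_ep API):

    from enum import IntEnum, auto
    from typing import Optional

    class Mode(IntEnum):
        NORMAL = auto()
        LOW_LATENCY = auto()

    class SharedBuffer:
        """Caches one expensive resource and tracks which mode last used it."""

        _buffer: Optional[dict] = None
        _mode: Optional[Mode] = None

        @classmethod
        def get(cls, hidden_size: int) -> dict:
            if cls._buffer is None:
                # Allocated once, sized for the worst case of both modes,
                # mirroring how get_deepep_buffer() sizes the NVL/RDMA regions.
                cls._buffer = {"hidden_size": hidden_size}
            return cls._buffer

        @classmethod
        def set_mode(cls, mode: Mode) -> None:
            if cls._mode is not None and cls._mode != mode:
                # Point where the real code calls clean_low_latency_buffer()
                # before the buffer is reused by the other dispatch mode.
                pass
            cls._mode = mode

    SharedBuffer.set_mode(Mode.NORMAL)
    buf = SharedBuffer.get(hidden_size=4096)
    SharedBuffer.set_mode(Mode.LOW_LATENCY)
    assert SharedBuffer.get(hidden_size=4096) is buf  # same shared object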
sglang/srt/layers/moe/fused_moe_native.py
CHANGED
@@ -23,9 +23,14 @@ def fused_moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     inplace: bool = True,
     no_combine: bool = False,
 ) -> torch.Tensor:
+
+    if apply_router_weight_on_input:
+        raise NotImplementedError
+
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,