sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/token_dispatcher.py

@@ -7,6 +7,7 @@ try:
 except ImportError:
     use_deepep = False
 
+from enum import IntEnum, auto
 from typing import Optional, Tuple
 
 import torch
@@ -19,70 +20,95 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
-_buffer_normal = None
-_buffer_low_latency = None
 
+class DeepEPDispatchMode(IntEnum):
+    NORMAL = auto()
+    LOW_LATENCY = auto()
 
-def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
-    """
-    Copy from DeepEP example usage in model inference prefilling.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
-    """
 
-    global _buffer_normal
+class DeepEPBuffer:
 
-    num_nvl_bytes, num_rdma_bytes = 0, 0
-    for config in (
-        Buffer.get_dispatch_config(group.size()),
-        Buffer.get_combine_config(group.size()),
-    ):
-        num_nvl_bytes = max(
-            config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
-        )
-        num_rdma_bytes = max(
-            config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
-        )
+    _buffer = None
+    _dispatch_mode: Optional[DeepEPDispatchMode] = None
+    _hidden_size: Optional[int] = None
+    _num_max_dispatch_tokens_per_rank: Optional[int] = None
+    _num_experts: Optional[int] = None
 
-    if (
-        _buffer_normal is None
-        or _buffer_normal.group != group
-        or _buffer_normal.num_nvl_bytes < num_nvl_bytes
-        or _buffer_normal.num_rdma_bytes < num_rdma_bytes
-    ):
-        _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
-    return _buffer_normal
-
-
-def _get_buffer_low_latency(
-    group: dist.ProcessGroup,
-    num_max_dispatch_tokens_per_rank: int,
-    hidden: int,
-    num_experts: int,
-):
-    """
-    Copy from DeepEP example usage in model inference decoding.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
-    """
-
-    global _buffer_low_latency
-    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
-        num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
-    )
-
-    if (
-        _buffer_low_latency is None
-        or _buffer_low_latency.group != group
-        or not _buffer_low_latency.low_latency_mode
-        or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
+    @classmethod
+    def get_deepep_buffer(
+        cls,
+        group: dist.ProcessGroup,
+        hidden_size: int,
+        param_bytes: int,
+        deepep_mode: DeepEPMode,
+        num_max_dispatch_tokens_per_rank: int = None,
+        num_experts: int = None,
     ):
-        assert num_experts % group.size() == 0
-        _buffer_low_latency = Buffer(
+        if cls._buffer is not None:
+            return cls._buffer
+
+        cls._hidden_size = hidden_size
+        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+        cls._num_experts = num_experts
+
+        num_nvl_bytes, num_rdma_bytes = 0, 0
+        if deepep_mode.enable_normal():
+            hidden_bytes = hidden_size * param_bytes
+            for config in (
+                Buffer.get_dispatch_config(group.size()),
+                Buffer.get_combine_config(group.size()),
+            ):
+                num_nvl_bytes = max(
+                    config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
+                    num_nvl_bytes,
+                )
+                num_rdma_bytes = max(
+                    config.get_rdma_buffer_size_hint(hidden_bytes, group.size()),
+                    num_rdma_bytes,
+                )
+        if deepep_mode.enable_low_latency():
+            assert num_max_dispatch_tokens_per_rank is not None
+            assert num_experts is not None and num_experts % group.size() == 0
+            num_rdma_bytes = max(
+                Buffer.get_low_latency_rdma_size_hint(
+                    num_max_dispatch_tokens_per_rank,
+                    hidden_size,
+                    group.size(),
+                    num_experts,
+                ),
+                num_rdma_bytes,
+            )
+
+        cls._buffer = Buffer(
             group,
-            num_rdma_bytes=num_rdma_bytes,
-            low_latency_mode=True,
-            num_qps_per_rank=num_experts // group.size(),
+            num_nvl_bytes,
+            num_rdma_bytes,
+            low_latency_mode=deepep_mode.enable_low_latency(),
+            num_qps_per_rank=(
+                num_experts // group.size() if deepep_mode.enable_low_latency() else 1
+            ),
         )
-    return _buffer_low_latency
+        return cls._buffer
+
+    @classmethod
+    def clean_buffer(cls):
+        if not cls._buffer.low_latency_mode:
+            return
+        cls._buffer.clean_low_latency_buffer(
+            cls._num_max_dispatch_tokens_per_rank,
+            cls._hidden_size,
+            cls._num_experts,
+        )
+
+    @classmethod
+    def set_dispatch_mode_as_normal(cls):
+        cls._dispatch_mode = DeepEPDispatchMode.NORMAL
+
+    @classmethod
+    def set_dispatch_mode_as_low_latency(cls):
+        if cls._dispatch_mode == DeepEPDispatchMode.NORMAL:
+            cls.clean_buffer()
+        cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
 
 class _DeepEPDispatcherImplBase:
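Taken together, this hunk swaps the two module-level globals (`_buffer_normal`, `_buffer_low_latency`) for one lazily created, process-wide `DeepEPBuffer` that is sized once for every enabled dispatch mode, and it makes the dispatch-mode switch explicit so that a NORMAL to LOW_LATENCY transition can clean low-latency state. A minimal, dependency-free sketch of the same singleton-plus-mode pattern (`SharedBuffer`, `DispatchMode`, and the byte counts are illustrative stand-ins, not sglang or DeepEP API):

```python
from enum import IntEnum, auto
from typing import Optional


class DispatchMode(IntEnum):
    NORMAL = auto()
    LOW_LATENCY = auto()


class SharedBuffer:
    """Illustrative stand-in for DeepEPBuffer: one buffer, two dispatch modes."""

    _buffer: Optional[dict] = None
    _dispatch_mode: Optional[DispatchMode] = None

    @classmethod
    def get(cls, nvl_bytes: int, rdma_bytes: int, low_latency: bool) -> dict:
        # The first caller sizes the buffer; everyone after gets the cached one.
        if cls._buffer is None:
            cls._buffer = {
                "nvl_bytes": nvl_bytes,
                "rdma_bytes": rdma_bytes,
                "low_latency_mode": low_latency,
            }
        return cls._buffer

    @classmethod
    def set_mode(cls, mode: DispatchMode) -> None:
        # Mirrors set_dispatch_mode_as_low_latency: leaving NORMAL is the
        # point where low-latency queue state has to be cleaned.
        if cls._dispatch_mode is DispatchMode.NORMAL and mode is DispatchMode.LOW_LATENCY:
            print("NORMAL -> LOW_LATENCY: clean low-latency buffer state here")
        cls._dispatch_mode = mode


buf = SharedBuffer.get(nvl_bytes=1 << 20, rdma_bytes=1 << 20, low_latency=True)
SharedBuffer.set_mode(DispatchMode.NORMAL)
SharedBuffer.set_mode(DispatchMode.LOW_LATENCY)  # triggers the cleanup hook
assert buf is SharedBuffer.get(0, 0, False)      # same object; sizes were fixed once
```

Because the first call pins the sizes, `get_deepep_buffer` takes the max of the normal-mode and low-latency RDMA size hints up front instead of reallocating when the mode changes.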
@@ -95,6 +121,7 @@ class _DeepEPDispatcherImplBase:
         num_local_experts: int,
         hidden_size: int,
         params_dtype: torch.dtype,
+        deepep_mode: DeepEPMode,
     ):
         if not use_deepep:
             raise ImportError(
@@ -109,7 +136,10 @@ class _DeepEPDispatcherImplBase:
         self.num_local_experts = num_local_experts
         self.hidden_size = hidden_size
        self.params_dtype = params_dtype
+        self.deepep_mode = deepep_mode
+
         self.params_bytes = 2
+        self.num_max_dispatch_tokens_per_rank = 128
 
         self.handle = None
 
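Note the two constants that now live on the base class: `params_bytes = 2` is the per-element size that `get_deepep_buffer` multiplies into `hidden_size * param_bytes` when asking DeepEP for size hints (2 bytes is consistent with fp16/bf16 activations), and the per-rank dispatch-token cap that previously lived in the low-latency impl is pinned here at 128. A quick arithmetic check, with an illustrative hidden size:

```python
hidden_size = 7168   # illustrative hidden dim, not taken from the diff
params_bytes = 2     # fp16/bf16 element size, as set in __init__ above
hidden_bytes = hidden_size * params_bytes
assert hidden_bytes == 14336   # what get_*_buffer_size_hint receives
```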
@@ -118,8 +148,6 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         raise NotImplementedError
 
@@ -137,14 +165,14 @@ class _DeepEPDispatcherImplBase:
     def combine_b(self, *args, **kwargs):
         raise NotImplementedError
 
+    def _get_buffer(self):
+        raise NotImplementedError
+
 
 class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def __init__(self, async_finish: bool, **kwargs):
         super().__init__(**kwargs)
 
-        self.buffer_normal = _get_buffer_normal(
-            self.group, self.hidden_size * self.params_bytes
-        )
         self.async_finish = async_finish
         self.src2dst = None
 
@@ -153,24 +181,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         topk_idx = topk_idx.to(torch.int64)
         previous_event = Buffer.capture() if self.async_finish else None
-        return hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        return hidden_states, topk_idx, topk_weights, previous_event
 
-    def dispatch_b(
-        self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
-    ):
+    def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
         (
             hidden_states,
             topk_idx,
             topk_weights,
             event,
-        ) = self._dispatch_core(
-            hidden_states, topk_idx, topk_weights, num_experts, previous_event
-        )
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
         event.current_stream_wait() if self.async_finish else ()
         if hidden_states.shape[0] > 0:
             reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
@@ -181,7 +203,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
                 (0,), device=hidden_states.device, dtype=torch.int64
             )
             seg_indptr = torch.zeros(
-                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
             )
 
         masked_m = expected_m = None
@@ -201,18 +223,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         x: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
         previous_event,
     ):
+        buffer = self._get_buffer()
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
             num_tokens_per_expert,
             is_token_in_rank,
             previous_event,
-        ) = self.buffer_normal.get_dispatch_layout(
+        ) = buffer.get_dispatch_layout(
             topk_idx,
-            num_experts,
+            self.num_experts,
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=previous_event is not None,
@@ -221,6 +243,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
         # However, doing this would incur an unknown synchronization error, but keeping
         # `handle` as a member variable works.
+
         (
             recv_x,
             recv_topk_idx,
@@ -228,7 +251,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             _, # num_recv_tokens_per_expert_list
             self.handle,
             event,
-        ) = self.buffer_normal.dispatch(
+        ) = buffer.dispatch(
             x,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
@@ -327,7 +350,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return hidden_states
 
     def _combine_core(self, x: torch.Tensor, previous_event):
-        combined_x, _, event = self.buffer_normal.combine(
+        buffer = self._get_buffer()
+        combined_x, _, event = buffer.combine(
             x,
             self.handle,
             async_finish=self.async_finish,
@@ -336,6 +360,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )
         return combined_x, event
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_normal()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def __init__(self, return_recv_hook: bool, **kwargs):
@@ -345,14 +380,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
         https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
         """
-        # TODO(ch-wan): allow users to set this value
-        self.num_max_dispatch_tokens_per_rank = 128
-        self.buffer_low_latency = _get_buffer_low_latency(
-            self.group,
-            self.num_max_dispatch_tokens_per_rank,
-            self.hidden_size,
-            self.num_experts,
-        )
         self.return_recv_hook = return_recv_hook
 
     def dispatch_a(
@@ -360,21 +387,16 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
+        buffer = self._get_buffer()
         topk_idx = topk_idx.to(torch.int64)
         expected_m = (
-            hidden_states.shape[0]
-            * self.buffer_low_latency.group_size
-            * topk_idx.shape[1]
-            + num_experts
-        ) // num_experts
+            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            + self.num_experts
+        ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
-            num_max_dispatch_tokens_per_rank,
-            num_experts,
             use_fp8=True,
         )
         return (
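The rewritten `expected_m` keeps the same shape as before but reads its constants from `self` and the shared buffer: it estimates tokens per expert after dispatch as tokens-on-this-rank × group size × top-k, divided by the expert count and biased upward. A worked check with illustrative numbers:

```python
# Hypothetical sizes: 64 tokens on this rank, 8 ranks in the group,
# top-8 routing, 256 experts (all illustrative, not from the diff).
num_tokens, group_size, topk, num_experts = 64, 8, 8, 256

# Same expression as the new code, with self.num_experts spelled out:
expected_m = (num_tokens * group_size * topk + num_experts) // num_experts
# 64 * 8 * 8 = 4096 routed slots; (4096 + 256) // 256 == 17,
# i.e. roughly ceil(4096 / 256), with a +1 bias when it divides evenly.
assert expected_m == 17
```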
@@ -415,8 +437,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
-        num_max_dispatch_tokens_per_rank: int,
-        num_experts: int,
         use_fp8: bool = False,
     ):
         """
@@ -451,13 +471,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
 
         const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
         """
-
+        buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
-            self.buffer_low_latency.low_latency_dispatch(
+            buffer.low_latency_dispatch(
                 hidden_states,
                 topk_idx,
-                num_max_dispatch_tokens_per_rank,
-                num_experts,
+                self.num_max_dispatch_tokens_per_rank,
+                self.num_experts,
                 use_fp8=use_fp8,
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
@@ -488,19 +508,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        combined_hidden_states, event, hook = (
-            self.buffer_low_latency.low_latency_combine(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                self.handle,
-                async_finish=not self.return_recv_hook,
-                return_recv_hook=self.return_recv_hook,
-            )
+        buffer = self._get_buffer()
+        combined_hidden_states, event, hook = buffer.low_latency_combine(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            self.handle,
+            async_finish=not self.return_recv_hook,
+            return_recv_hook=self.return_recv_hook,
         )
         self.handle = None
         return combined_hidden_states, event, hook
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_low_latency()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class DeepEPDispatcher:
     def __init__(
@@ -526,18 +556,19 @@ class DeepEPDispatcher:
             num_local_experts=num_local_experts,
             hidden_size=hidden_size,
             params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
         )
 
-        if self.deepep_mode.enable_normal():
-            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
-                async_finish=async_finish,
-                **common_kwargs,
-            )
         if self.deepep_mode.enable_low_latency():
             self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
                 return_recv_hook=return_recv_hook,
                 **common_kwargs,
             )
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
+                async_finish=async_finish,
+                **common_kwargs,
+            )
 
     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
@@ -548,16 +579,12 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int = 128,
         forward_mode: ForwardMode = None,
     ):
         inner_state = self._get_impl(forward_mode).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            num_experts=num_experts,
-            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
         )
         self._dispatch_intermediate_state = forward_mode, inner_state
 
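Call sites shrink accordingly: `num_experts` and `num_max_dispatch_tokens_per_rank` no longer travel through every `dispatch()` call because the dispatcher now owns them. A hedged before/after sketch, with a stub standing in for `DeepEPDispatcher` (argument names follow the diff; the stub itself is illustrative):

```python
class StubDispatcher:
    def __init__(self, num_experts: int):
        # Routing constants now live on the dispatcher ...
        self.num_experts = num_experts
        self.num_max_dispatch_tokens_per_rank = 128  # fixed in 0.4.5.post2

    def dispatch(self, hidden_states, topk_idx, topk_weights, forward_mode=None):
        # ... so the per-call plumbing of num_experts is gone.
        return hidden_states, topk_idx, topk_weights


d = StubDispatcher(num_experts=256)
# old: d.dispatch(h, idx, w, num_experts=256, num_max_dispatch_tokens_per_rank=128)
# new:
d.dispatch("h", "idx", "w")
```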
@@ -589,7 +616,7 @@ class DeepEPDispatcher:
         del self._combine_intermediate_state
         return self._get_impl(forward_mode).combine_b(*inner_state)
 
-    def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
+    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
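`_get_impl` still picks the normal or low-latency implementation per forward pass via `DeepEPMode.resolve(forward_mode)`. A minimal stand-in for that resolution logic (the enum and the decode heuristic here are illustrative, not sglang's exact definition):

```python
from enum import Enum


class Mode(Enum):
    NORMAL = "normal"
    LOW_LATENCY = "low_latency"
    AUTO = "auto"


def resolve(mode: Mode, is_decode: bool) -> Mode:
    # auto: prefill-style batches take the normal path, decode takes low latency
    if mode is not Mode.AUTO:
        return mode
    return Mode.LOW_LATENCY if is_decode else Mode.NORMAL


assert resolve(Mode.AUTO, is_decode=True) is Mode.LOW_LATENCY
assert resolve(Mode.NORMAL, is_decode=True) is Mode.NORMAL
```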
sglang/srt/layers/moe/fused_moe_native.py

@@ -26,6 +26,7 @@ def fused_moe_forward_native(
     apply_router_weight_on_input: bool = False,
     inplace: bool = True,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
 
     if apply_router_weight_on_input:
@@ -41,6 +42,7 @@ def fused_moe_forward_native(
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
+        routed_scaling_factor=routed_scaling_factor,
         torch_native=True,
     )
 
@@ -71,6 +73,7 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
 
     from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
@@ -86,6 +89,7 @@ def moe_forward_native(
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
         torch_native=True,
+        routed_scaling_factor=routed_scaling_factor,
     )
 
     # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
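The four `fused_moe_native.py` hunks are pure parameter plumbing: `routed_scaling_factor` is accepted by both native entry points and forwarded into the top-k expert selection. In DeepSeek-V2-style models this factor rescales the routed-expert weights; a hedged sketch of what such a scaling step typically looks like (`scale_topk_weights` is a hypothetical helper, not sglang's kernel):

```python
from typing import Optional

import torch


def scale_topk_weights(
    topk_weights: torch.Tensor, routed_scaling_factor: Optional[float]
) -> torch.Tensor:
    # Rescale routed-expert weights by a model-config constant, the
    # DeepSeek-V2-style use of routed_scaling_factor; None means no-op.
    if routed_scaling_factor is not None:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights


w = torch.tensor([[0.6, 0.4]])
print(scale_topk_weights(w, 2.5))  # tensor([[1.5000, 1.0000]])
```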
sglang/srt/layers/moe/fused_moe_triton/configs/….json (new file)

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
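These config files map a token count M (the JSON keys) to a Triton kernel launch configuration; at runtime the fused-MoE kernel looks up the entry whose tuned M is closest to the actual batch size. A simplified sketch of that lookup (the trimmed table and `pick_config` are illustrative, and nearest-key selection is my reading of how these tables are consumed):

```python
import json

# Trimmed copy of the table above; in sglang these live under
# srt/layers/moe/fused_moe_triton/configs/*.json, keyed by M.
CONFIGS = json.loads("""
{
  "1":    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "256":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3}
}
""")


def pick_config(m: int) -> dict:
    # Nearest tuned batch size wins.
    best_key = min(CONFIGS, key=lambda k: abs(int(k) - m))
    return CONFIGS[best_key]


print(pick_config(300))  # falls back to the M=256 entry
```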