sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py

@@ -1,14 +1,10 @@
 from __future__ import annotations

 import logging
-from
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 import torch
-import triton
-import triton.language as tl

-from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
@@ -18,13 +14,10 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
-    moe_ep_deepgemm_preprocess,
-    post_reorder_triton_kernel,
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
-from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
@@ -36,19 +29,10 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     CUTEDSL_MOE_NVFP4_DISPATCH,
     ModelOptNvFp4FusedMoEMethod,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.offloader import get_offloader
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
-from sglang.srt.utils import (
-    ceil_div,
-    dispose_tensor,
-    get_bool_env_var,
-    get_int_env_var,
-    is_cuda,
-    is_hip,
-    is_npu,
-)
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
@@ -72,29 +56,13 @@ if _use_aiter:
 logger = logging.getLogger(__name__)


-# TODO(kaixih@nvidia): ideally we should merge this logic into
-# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
-@torch.compile
-def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
-    temp = x.to(torch.float32).view(torch.int32)
-    exp = torch.bitwise_right_shift(temp, 23)
-    mant = torch.bitwise_and(temp, 0x7FFFFF)
-    is_ru = torch.logical_and(
-        torch.logical_and((mant > 0), (exp != 0xFE)),
-        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
-    )
-    exp = torch.where(is_ru, exp + 1, exp)
-    new_x = exp.to(torch.uint8).view(torch.int)
-    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
-
-
-class EPMoE(FusedMoE):
+class DeepEPMoE(FusedMoE):
     """
-    MoE Expert Parallel Impl
-
-
+    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
     """

+    _has_printed = False
+
     def __init__(
         self,
         num_experts: int,
@@ -108,272 +76,29 @@ class EPMoE(FusedMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        gemm1_alpha: Optional[float] = None,
-        gemm1_clamp_limit: Optional[float] = None,
-        with_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,
+            top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
-
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
             activation=activation,
-            # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
-            gemm1_alpha=gemm1_alpha,
-            gemm1_clamp_limit=gemm1_clamp_limit,
-            with_bias=with_bias,
         )

-        self.intermediate_size = intermediate_size
-
         if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
-            self.block_shape = (
-                self.quant_method.quant_config.weight_block_size
-                if self.use_block_quant
-                else None
-            )
             self.use_fp8_w8a8 = True
             self.fp8_dtype = torch.float8_e4m3fn
-            self.activation_scheme = quant_config.activation_scheme
         else:
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
-            self.block_shape = None
-            self.activation_scheme = None
-
-    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
-            return self.forward_deepgemm(hidden_states, topk_output)
-        else:
-            return super().forward(hidden_states, topk_output)
-
-    def forward_deepgemm(
-        self,
-        hidden_states: torch.Tensor,
-        topk_output: TopKOutput,
-    ):
-
-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
-
-        assert self.quant_method is not None
-        assert self.moe_runner_config.activation == "silu"
-
-        hidden_states_shape = hidden_states.shape
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-
-        topk_weights, topk_ids, _ = topk_output
-
-        if not self.use_block_quant:
-            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
-            scale_block_size = 128
-            w13_weight_scale_n = 2 * (
-                (self.intermediate_size + scale_block_size - 1) // scale_block_size
-            )
-            w13_weight_scale_k = (
-                hidden_states_shape[-1] + scale_block_size - 1
-            ) // scale_block_size
-            w13_weight_scale = (
-                self.w13_weight_scale.unsqueeze(1)
-                .repeat_interleave(w13_weight_scale_n, dim=1)
-                .unsqueeze(2)
-                .repeat_interleave(w13_weight_scale_k, dim=2)
-            )
-            self.w13_weight_fp8 = (
-                self.w13_weight,
-                w13_weight_scale,
-            )
-            w2_weight_scale_n = (
-                hidden_states_shape[-1] + scale_block_size - 1
-            ) // scale_block_size
-            w2_weight_scale_k = (
-                self.intermediate_size + scale_block_size - 1
-            ) // scale_block_size
-            w2_weight_scale = (
-                self.w2_weight_scale.unsqueeze(1)
-                .repeat_interleave(w2_weight_scale_n, dim=1)
-                .unsqueeze(2)
-                .repeat_interleave(w2_weight_scale_k, dim=2)
-            )
-            self.w2_weight_fp8 = (
-                self.w2_weight,
-                w2_weight_scale,
-            )
-
-        # PreReorder
-        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
-            moe_ep_deepgemm_preprocess(
-                topk_ids,
-                self.num_experts,
-                hidden_states,
-                self.top_k,
-                self.start_expert_id,
-                self.end_expert_id,
-                self.block_shape,
-            )
-        )
-
-        dispose_tensor(hidden_states)
-
-        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            b, s_mn, s_k = gateup_input_scale.shape
-            assert (
-                s_mn % 4 == 0 and s_k % 4 == 0
-            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
-
-        # GroupGemm-0
-        gateup_input_fp8 = (
-            gateup_input,
-            (
-                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
-                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
-                    gateup_input_scale
-                )
-            ),
-        )
-        num_groups, m, k = gateup_input_fp8[0].size()
-        n = self.w13_weight.size(1)
-        gateup_output = torch.empty(
-            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
-        )
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            gateup_input_fp8,
-            self.w13_weight_fp8,
-            gateup_output,
-            masked_m,
-            expected_m,
-        )
-        del gateup_input
-        del gateup_input_fp8

-        # Act
-        down_input = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2,
-            ),
-            device=hidden_states_device,
-            dtype=self.fp8_dtype,
-        )
-        scale_block_size = 128
-        down_input_scale = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2 // scale_block_size,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float32,
-        )
-        silu_and_mul_masked_post_quant_fwd(
-            gateup_output,
-            down_input,
-            down_input_scale,
-            scale_block_size,
-            masked_m,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
-        del gateup_output
-
-        # GroupGemm-1
-        n = self.w2_weight.size(1)
-        down_input_fp8 = (
-            down_input,
-            (
-                down_input_scale
-                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
-            ),
-        )
-        down_output = torch.empty(
-            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
-        )
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8,
-            self.w2_weight_fp8,
-            down_output,
-            masked_m,
-            expected_m,
-        )
-        del down_input
-        del down_input_fp8
-
-        # PostReorder
-        output = torch.empty(
-            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
-        )
-        post_reorder_triton_kernel[(hidden_states_shape[0],)](
-            down_output,
-            output,
-            src2dst,
-            topk_ids,
-            topk_weights,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states_shape[1],
-            m_max * self.start_expert_id,
-            BLOCK_SIZE=512,
-        )
-        if self.moe_runner_config.routed_scaling_factor is not None:
-            output *= self.moe_runner_config.routed_scaling_factor
-        return output
-
-
-class DeepEPMoE(EPMoE):
-    """
-    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
-    """
-
-    _has_printed = False
-
-    def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-        layer_id: int,
-        num_fused_shared_experts: int = 0,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        activation: str = "silu",
-        routed_scaling_factor: Optional[float] = None,
-    ):
-        super().__init__(
-            num_experts=num_experts,
-            top_k=top_k,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            layer_id=layer_id,
-            num_fused_shared_experts=num_fused_shared_experts,
-            params_dtype=params_dtype,
-            quant_config=quant_config,
-            prefix=prefix,
-            activation=activation,
-            routed_scaling_factor=routed_scaling_factor,
-        )
         self.deepep_mode = get_deepep_mode()

         # TODO: move to the beginning of the file
@@ -567,7 +292,6 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128

-        # TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
         w13_weight_fp8 = (
             self.w13_weight,
             (
@@ -988,8 +712,6 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
         return FlashInferFusedMoE
     if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
-    if get_moe_expert_parallel_world_size() > 1:
-        return EPMoE
     return FusedMoE

sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -156,8 +156,7 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
-
-        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
+
         if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
@@ -207,15 +206,11 @@ class FusedMoE(torch.nn.Module):
             gemm1_clamp_limit=gemm1_clamp_limit,
         )

-
-
-
-
-
-        self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method(
-            self, prefix
-        )
-        assert self.quant_method is not None
+        self.quant_method: Optional[FusedMoEMethodBase] = None
+        if quant_config is not None:
+            self.quant_method = quant_config.get_quant_method(self, prefix)
+        if self.quant_method is None:
+            self.quant_method = UnquantizedFusedMoEMethod(self.use_triton_kernels)

         self.quant_method.create_weights(
             layer=self,
sglang/srt/layers/moe/moe_runner/deep_gemm.py

@@ -0,0 +1,304 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    MoeQuantInfo,
+    MoeRunnerConfig,
+    MoeRunnerCore,
+    RunnerInput,
+    RunnerOutput,
+    register_post_permute,
+    register_pre_permute,
+)
+from sglang.srt.layers.moe.utils import MoeRunnerBackend
+from sglang.srt.utils import dispose_tensor
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher.standard import (
+        StandardCombineInput,
+        StandardDispatchOutput,
+    )
+
+
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
+@dataclass
+class DeepGemmRunnerInput(RunnerInput):
+    hidden_states: torch.Tensor
+    hidden_states_scale: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+    use_masked_gemm: bool
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@dataclass
+class DeepGemmRunnerOutput(RunnerOutput):
+    hidden_states: torch.Tensor
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@dataclass
+class DeepGemmMoeQuantInfo(MoeQuantInfo):
+    w13_weight: torch.Tensor
+    w2_weight: torch.Tensor
+    use_fp8: bool
+    w13_scale: Optional[torch.Tensor] = None
+    w2_scale: Optional[torch.Tensor] = None
+    block_shape: Optional[List[int]] = None
+
+
+class DeepGemmRunnerCore(MoeRunnerCore):
+    def __init__(self, config: MoeRunnerConfig):
+        super().__init__(config)
+        assert self.config.activation == "silu"
+
+    def run(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> DeepGemmRunnerOutput:
+
+        if runner_input.use_masked_gemm:
+            hidden_states = self._run_masked_gemm(
+                runner_input,
+                quant_info,
+                running_state,
+            )
+        else:
+            hidden_states = self._run_contiguous_gemm(
+                runner_input,
+                quant_info,
+                running_state,
+            )
+        return DeepGemmRunnerOutput(hidden_states=hidden_states)
+
+    def _run_masked_gemm(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            silu_and_mul_masked_post_quant_fwd,
+        )
+        from sglang.srt.layers.quantization import deep_gemm_wrapper
+
+        hidden_states = runner_input.hidden_states
+        hidden_states_scale = runner_input.hidden_states_scale
+        masked_m = runner_input.masked_m
+        expected_m = runner_input.expected_m
+
+        w13_weight = quant_info.w13_weight
+        w2_weight = quant_info.w2_weight
+        w13_scale = quant_info.w13_scale
+        w2_scale = quant_info.w2_scale
+
+        hidden_states_device = running_state["hidden_states_device"]
+
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = hidden_states_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
+        # GroupGemm-0
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
+        else:
+            hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                hidden_states_scale
+            )
+
+        num_groups, m, k = hidden_states.shape
+        n = w13_weight.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            (hidden_states, hidden_states_scale),
+            (w13_weight, w13_scale),
+            gateup_output,
+            masked_m,
+            expected_m,
+        )
+        dispose_tensor(hidden_states)
+
+        # Act
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float8_e4m3fn,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        )
+        del gateup_output
+
+        # GroupGemm-1
+        n = w2_weight.shape[1]
+
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                down_input_scale
+            )
+
+        down_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            (down_input, down_input_scale),
+            (w2_weight, w2_scale),
+            down_output,
+            masked_m,
+            expected_m,
+        )
+        del down_input
+
+        return down_output
+
+    def _run_contiguous_gemm(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> torch.Tensor:
+        pass
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@register_pre_permute("standard", "deep_gemm")
+def pre_permute_standard_to_deep_gemm(
+    dispatch_output: StandardDispatchOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepGemmRunnerInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
+
+    hidden_states, topk_output = dispatch_output
+    topk_weights, topk_ids, _ = topk_output
+
+    hidden_states_shape = hidden_states.shape
+    hidden_states_dtype = hidden_states.dtype
+    hidden_states_device = hidden_states.device
+    hidden_states_ref = hidden_states
+
+    topk_weights, topk_ids = topk_weights, topk_ids
+
+    # PreReorder
+    masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
+        moe_ep_deepgemm_preprocess(
+            topk_ids,
+            runner_config.num_local_experts,
+            hidden_states,
+            runner_config.top_k,
+            quant_info.block_shape,
+        )
+    )
+
+    dispose_tensor(hidden_states_ref)
+
+    running_state["topk_ids"] = topk_ids
+    running_state["topk_weights"] = topk_weights
+    running_state["hidden_states_shape"] = hidden_states_shape
+    running_state["hidden_states_dtype"] = hidden_states_dtype
+    running_state["hidden_states_device"] = hidden_states_device
+    running_state["src2dst"] = src2dst
+
+    return DeepGemmRunnerInput(
+        hidden_states=hidden_states,
+        hidden_states_scale=hidden_states_scale,
+        masked_m=masked_m,
+        expected_m=expected_m,
+        use_masked_gemm=True,
+    )
+
+
+@register_post_permute("deep_gemm", "standard")
+def post_permute_deep_gemm_to_standard(
+    runner_output: DeepGemmRunnerOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> StandardCombineInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+    hidden_states_shape = running_state["hidden_states_shape"]
+    hidden_states_dtype = running_state["hidden_states_dtype"]
+    hidden_states_device = running_state["hidden_states_device"]
+    src2dst = running_state["src2dst"]
+    topk_ids = running_state["topk_ids"]
+    topk_weights = running_state["topk_weights"]
+
+    output = torch.empty(
+        hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+    )
+    post_reorder_triton_kernel[(hidden_states_shape[0],)](
+        runner_output.hidden_states,
+        output,
+        src2dst,
+        topk_ids,
+        topk_weights,
+        runner_config.top_k,
+        hidden_states_shape[1],
+        BLOCK_SIZE=512,
+    )
+
+    dispose_tensor(runner_output.hidden_states)
+
+    if runner_config.routed_scaling_factor is not None:
+        output *= runner_config.routed_scaling_factor
+
+    return StandardCombineInput(
+        hidden_states=output,
+    )