sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

+import importlib
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -12,23 +13,33 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
     is_hip,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

+has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
+
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+    if has_triton_kernels:
+        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+            triton_kernel_moe_forward,
+        )
 else:
     fused_experts = None  # type: ignore

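Note: the new `has_triton_kernels` guard is the standard optional-dependency check; `importlib.util.find_spec` only probes whether the `triton_kernels` package is importable, without importing it. A minimal, self-contained sketch of the same pattern (illustrative, not sglang code):

```python
import importlib.util

# Probe for an optional dependency without importing it; find_spec returns None
# when the package is not installed.
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

if has_triton_kernels:
    # Only reached when the optional package is actually present.
    from triton_kernels.matmul_ogs import matmul_ogs  # noqa: F401
else:
    matmul_ogs = None  # fall back to the default fused_experts path
```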
@@ -85,6 +96,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

+    def __init__(self, use_triton_kernels: bool = False):
+        super().__init__()
+        self.use_triton_kernels = use_triton_kernels
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -95,20 +110,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         **extra_weight_attrs,
     ):
         # Fused gate_up_proj (column parallel)
+        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
+        if self.use_triton_kernels:
+            w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
         w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)

         # down_proj (row parallel)
+        w2_weight_n, w2_weight_k = (
+            hidden_size,
+            intermediate_size,
+        )
+        if self.use_triton_kernels:
+            w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
         w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
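Note: with the Triton-kernels path enabled, `create_weights` allocates the expert weights with their last two dimensions swapped, since those kernels expect a transposed layout. A small sketch of the shape bookkeeping above (the helper below is purely illustrative, not part of sglang):

```python
# Illustrative only: mirrors the (n, k) swap in create_weights above.
def expert_weight_shapes(
    num_experts: int, hidden_size: int, intermediate_size: int, use_triton_kernels: bool
):
    w13_n, w13_k = 2 * intermediate_size, hidden_size  # fused gate_up_proj
    w2_n, w2_k = hidden_size, intermediate_size        # down_proj
    if use_triton_kernels:
        w13_n, w13_k = w13_k, w13_n
        w2_n, w2_k = w2_k, w2_n
    return (num_experts, w13_n, w13_k), (num_experts, w2_n, w2_k)

# Default layout vs. transposed layout for the Triton kernels.
assert expert_weight_shapes(8, 1024, 4096, False) == ((8, 8192, 1024), (8, 1024, 4096))
assert expert_weight_shapes(8, 1024, 4096, True) == ((8, 1024, 8192), (8, 4096, 1024))
```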
@@ -129,7 +149,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

         # Pack weight for get better performance on CPU
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

         return

@@ -190,59 +210,72 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )

-        if _use_aiter:
-            assert not no_combine, "unsupported"
-            if apply_router_weight_on_input:
-                assert (
-                    topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-                _, topk = topk_weights.shape
-                assert (
-                    topk == 1
-                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
-                x = x * topk_weights.to(x.dtype)
-                topk_weights = torch.ones_like(
-                    topk_weights, dtype=torch.float32
-                )  # topk_weights must be FP32 (float32)
-
-            return fused_moe(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                activation=(
-                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                ),
-            )
-        else:
-            return fused_experts(
+        if self.use_triton_kernels:
+            return triton_kernel_moe_forward(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=inplace and not no_combine,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
                 routed_scaling_factor=routed_scaling_factor,
             )

+            if _use_aiter:
+                assert not no_combine, "unsupported"
+                if apply_router_weight_on_input:
+                    assert (
+                        topk_weights.dim() == 2
+                    ), "`topk_weights` should be in shape (num_tokens, topk)"
+                    _, topk = topk_weights.shape
+                    assert (
+                        topk == 1
+                    ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+                    x = x * topk_weights.to(x.dtype)
+                    topk_weights = torch.ones_like(
+                        topk_weights, dtype=torch.float32
+                    )  # topk_weights must be FP32 (float32)
+
+                return fused_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
+                )
+            else:
+                return fused_experts(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    inplace=inplace and not no_combine,
+                    activation=activation,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                    no_combine=no_combine,
+                    routed_scaling_factor=routed_scaling_factor,
+                )
+
     def forward_cpu(
         self,
         layer: torch.nn.Module,
@@ -264,10 +297,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
         assert activation == "silu", f"activation = {activation} is not supported."

-        if (
-            getattr(layer, "use_intel_amx_backend", False)
-            and not apply_router_weight_on_input
-        ):
+        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
@@ -287,11 +317,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
-                topk_weights.to(
-                    torch.float
-                ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+                topk_weights,
                 topk_ids,
-                True,  # inplace
+                False,  # inplace # See [Note] inplace should be False in fused_experts.
                 False,  # use_int8_w8a8
                 False,  # use_fp8_w8a16
                 None,  # w1_scale
@@ -321,6 +349,44 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 routed_scaling_factor,
             )

+    def forward_npu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        return moe_forward_native(
+            layer,
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            num_fused_shared_experts,
+            custom_routing_function,
+            correction_bias,
+            activation,
+            apply_router_weight_on_input,
+            inplace,
+            no_combine,
+            routed_scaling_factor,
+        )
+
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")

@@ -438,9 +504,13 @@ class FusedMoE(torch.nn.Module):
         self.inplace = inplace
         self.no_combine = no_combine

+        self.use_triton_kernels = (
+            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
+        )
+
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
-            )
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
+                self.use_triton_kernels
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
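Note: the Triton-kernel MoE path is opt-in and CPU is excluded; it is only wired into the unquantized method. A minimal sketch of the same gating, using a hypothetical flags dict in place of `global_server_args_dict`:

```python
# Hypothetical stand-in for global_server_args_dict; mirrors the constructor gating above.
def should_use_triton_kernels(server_args: dict, is_cpu: bool) -> bool:
    return (not is_cpu) and bool(server_args.get("enable_triton_kernel_moe", False))

assert should_use_triton_kernels({"enable_triton_kernel_moe": True}, is_cpu=False)
assert not should_use_triton_kernels({"enable_triton_kernel_moe": True}, is_cpu=True)
assert not should_use_triton_kernels({}, is_cpu=False)
```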
@@ -537,11 +607,6 @@ class FusedMoE(torch.nn.Module):
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2

-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
-            )
-
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         # w3, up_proj: Load into second logical weight of w13.
@@ -552,7 +617,26 @@ class FusedMoE(torch.nn.Module):
             start = shard_size
         else:
             start = 0
-        expert_data = expert_data.narrow(shard_dim, start, shard_size)
+
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                start,
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
+            )
+        else:
+            if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )
+
+            expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)

     def _load_w2(
@@ -563,16 +647,54 @@ class FusedMoE(torch.nn.Module):
         loaded_weight: torch.tensor,
         tp_rank: int,
     ):
+        """Load w2 weights for down projection.
+
+        Args:
+            expert_data: The expert data tensor to load into
+            shard_dim: The dimension to shard along
+            shard_id: The shard ID (must be "w2")
+            loaded_weight: The weight tensor to load from
+            tp_rank: The tensor parallel rank
+        """
+        if not isinstance(expert_data, torch.Tensor) or not isinstance(
+            loaded_weight, torch.Tensor
+        ):
+            raise ValueError("expert_data and loaded_weight must be torch.Tensor")
+
+        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
+            raise ValueError(
+                f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
+            )
+
+        if shard_id != "w2":
+            raise ValueError(f"shard_id must be 'w2', got {shard_id}")

         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]

-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
             )
+        else:
+            if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
+                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
+                    raise ValueError(
+                        f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
+                    )
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )

         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
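Note: both `_load_w13` and `_load_w2` shard the checkpoint weight across tensor-parallel ranks with `Tensor.narrow`, optionally transposing first for the Triton-kernels layout. A self-contained sketch of the narrow-based sharding (illustrative, not sglang code):

```python
import torch

# Each rank copies one contiguous slice of the full weight along the sharded dim.
def shard_for_rank(full: torch.Tensor, shard_dim: int, tp_rank: int, tp_size: int) -> torch.Tensor:
    shard_size = full.shape[shard_dim] // tp_size
    return full.narrow(shard_dim, shard_size * tp_rank, shard_size)

full = torch.arange(24, dtype=torch.float32).reshape(4, 6)
rank0 = shard_for_rank(full, shard_dim=1, tp_rank=0, tp_size=2)  # columns 0..2
rank1 = shard_for_rank(full, shard_dim=1, tp_rank=1, tp_size=2)  # columns 3..5
assert torch.equal(torch.cat([rank0, rank1], dim=1), full)
```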
@@ -656,6 +778,8 @@ class FusedMoE(torch.nn.Module):
         # should be whatever dimension intermediate_size is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
+        if self.use_triton_kernels:
+            is_transposed = True
         if is_transposed:
             shard_dim = int(not shard_dim)

@@ -694,8 +818,21 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
             return
+
         if "ModelOpt" in self.quant_method.__class__.__name__:
-            if "weight_scale" in weight_name or "input_scale" in weight_name:
+            # Determine per-tensor weight scale patterns based on variant
+            is_fp4_variant = (
+                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
+            )
+
+            # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
+            per_tensor_conditions = (
+                "weight_scale_2" in weight_name
+                if is_fp4_variant
+                else "weight_scale" in weight_name
+            ) or "input_scale" in weight_name
+
+            if per_tensor_conditions:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
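Note: for ModelOpt checkpoints, the per-tensor scales are selected purely by weight name: FP4 variants use `weight_scale_2`, FP8 variants use `weight_scale`, and `input_scale` is per-tensor in both. A standalone restatement of that predicate (hypothetical helper, for illustration only):

```python
def is_per_tensor_scale(weight_name: str, is_fp4_variant: bool) -> bool:
    # FP4 checkpoints store the per-tensor weight scale under "weight_scale_2";
    # FP8 checkpoints store it under "weight_scale". Input scales are always per-tensor.
    weight_scale_key = "weight_scale_2" if is_fp4_variant else "weight_scale"
    return weight_scale_key in weight_name or "input_scale" in weight_name

assert is_per_tensor_scale("experts.w13_input_scale", is_fp4_variant=False)
assert is_per_tensor_scale("experts.w13_weight_scale_2", is_fp4_variant=True)
assert not is_per_tensor_scale("experts.w13_weight", is_fp4_variant=False)
```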
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py (new file)

@@ -0,0 +1,176 @@
+# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
+from typing import Optional
+
+import torch
+from sgl_kernel import gelu_and_mul, silu_and_mul
+from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+
+from sglang.srt.utils import direct_register_custom_op
+
+
+def triton_kernel_moe_forward(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    if not renormalize:
+        gating_output = torch.softmax(gating_output, dim=-1)
+    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
+
+    return triton_kernel_fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        routing_data,
+        gather_idx,
+        scatter_idx,
+        inplace=inplace,
+        activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        use_fp8_w8a8=use_fp8_w8a8,
+        per_channel_quant=per_channel_quant,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape,
+    )
+
+
+# This is a triton implementation of the fused_experts function
+def triton_kernel_fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    routing_data: RoutingData,
+    gather_indx: GatherIndx,
+    scatter_indx: ScatterIndx,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant == False, "per_channel_quant is not supported"
+    assert expert_map == None, "expert_map is not supported"
+    assert w1_scale == None, "w1_scale is not supported"
+    assert w2_scale == None, "w2_scale is not supported"
+    assert a1_scale == None, "a1_scale is not supported"
+    assert a2_scale == None, "a2_scale is not supported"
+    assert block_shape == None, "block_shape is not supported"
+
+    # type check
+    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
+    assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
+    assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
+
+    # Shape check
+    assert hidden_states.ndim == 2, "hidden_states must be 2D"
+    assert (
+        hidden_states.shape[-1] == w1.shape[-2]
+    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
+    assert (
+        w2.shape[-1] == w1.shape[1]
+    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
+
+    # feature check
+    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+
+    M, K = hidden_states.shape
+    E, _, N = w1.shape
+    n_expts_act = routing_data.n_expts_act
+    dtype = hidden_states.dtype
+
+    if global_num_experts == -1:
+        global_num_experts = E
+
+    # consistent with default implementation
+    intermediate_cache2 = torch.empty(
+        (M * n_expts_act, N // 2), device="cuda", dtype=dtype
+    )
+
+    intermediate_cache1 = matmul_ogs(
+        hidden_states,
+        w1,
+        None,
+        routing_data,
+        gather_indx=gather_indx,
+        gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
+    )
+
+    if activation == "silu":
+        silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+    elif activation == "gelu":
+        gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+    else:
+        raise ValueError(f"Unsupported FusedMoe activation: {activation}")
+
+    intermediate_cache3 = matmul_ogs(
+        intermediate_cache2,
+        w2,
+        None,
+        routing_data,
+        scatter_indx=scatter_indx,
+        gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
+    )
+
+    return intermediate_cache3
+
+
+def triton_kernel_moe_forward_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+direct_register_custom_op(
+    op_name="forward_cuda_triton",
+    op_func=triton_kernel_moe_forward,
+    mutates_args=[],
+    fake_impl=triton_kernel_moe_forward_fake,
+)
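Note: the new module closes by registering `triton_kernel_moe_forward` as a custom op with a shape-only fake implementation, so tracing-based compilation can reason about the op without running the kernels. `direct_register_custom_op` is sglang's own wrapper; the sketch below shows the analogous pattern with the public `torch.library` API (PyTorch 2.4+), as an assumption about intent rather than a description of that wrapper:

```python
import torch

# Register a custom op plus a "fake" (meta) implementation that only reports
# output metadata; this is what torch.compile needs during tracing.
@torch.library.custom_op("demo::moe_forward", mutates_args=())
def demo_moe_forward(hidden_states: torch.Tensor, gating_output: torch.Tensor) -> torch.Tensor:
    # A real implementation would run routing plus the two grouped matmuls here.
    return hidden_states.clone()

@demo_moe_forward.register_fake
def _(hidden_states: torch.Tensor, gating_output: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(hidden_states)

x = torch.randn(4, 8)
logits = torch.randn(4, 3)
assert torch.ops.demo.moe_forward(x, logits).shape == x.shape
```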