sglang-0.4.9-py3-none-any.whl → sglang-0.4.9.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +12 -1
- sglang/srt/conversation.py +35 -1
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/layers/communicator.py +3 -1
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
- sglang/srt/layers/moe/ep_moe/layer.py +140 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +8 -1
- sglang/srt/managers/mm_utils.py +4 -2
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +17 -5
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +113 -63
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/models/deepseek_v2.py +16 -2
- sglang/srt/models/mllama4.py +360 -79
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +62 -60
- sglang/srt/server_args.py +15 -0
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +37 -17
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py

```diff
@@ -12,6 +12,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
```
```diff
@@ -20,6 +21,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
+    pre_reorder_triton_kernel_for_cutlass_moe,
+    run_cutlass_moe_ep_preproess,
     run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
```
```diff
@@ -41,6 +44,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     sglang_per_token_quant_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
```
```diff
@@ -191,7 +195,7 @@ class EPMoE(torch.nn.Module):
             num_fused_shared_experts == 0
         ), "num_fused_shared_experts is not supported in EP"
         self.num_fused_shared_experts = num_fused_shared_experts
-        self.num_experts_per_partition = self.num_experts // self.tp_size
+        self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
 
```
```diff
@@ -215,6 +219,18 @@ class EPMoE(torch.nn.Module):
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
+            self.use_w4afp8 = False
+        elif isinstance(quant_config, W4AFp8Config):
+            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
+                quant_config
+            )
+            self.use_w4afp8 = True
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.fp8_dtype = torch.float8_e4m3fn
+            self.w13_weight_scale = None
+            self.w2_weight_scale = None
+            self.activation_scheme = quant_config.moe_activation_scheme
         else:
             self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                 quant_config
```
```diff
@@ -228,6 +244,7 @@ class EPMoE(torch.nn.Module):
             )
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
+            self.use_w4afp8 = False
 
         self.quant_method.create_weights(
             layer=self,
```
```diff
@@ -253,6 +270,49 @@ class EPMoE(torch.nn.Module):
             self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
         )
 
+    # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
+    # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
+    def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
+        """
+        Calculates how many experts should be assigned to each rank for EP and
+        creates a mapping from global to local expert index. Experts are
+        distributed evenly across ranks. Any remaining are assigned to the
+        last rank.
+
+        Returns:
+            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
+                - local_num_experts (int): The number of experts assigned
+                    to the current rank.
+                - expert_map (Optional[torch.Tensor]): A tensor of shape
+                    (global_num_experts,) mapping from global to local index.
+                    Contains global_num_experts for experts not assigned to the current rank.
+                    Returns None if ep_size is 1.
+        """
+        ep_size = self.tp_size
+        ep_rank = self.tp_rank
+        global_num_experts = self.num_experts
+
+        assert ep_size > 0
+        if ep_size == 1:
+            return (global_num_experts, None)
+
+        local_num_experts = global_num_experts // ep_size
+
+        expert_map = torch.full(
+            (global_num_experts,), self.num_experts, dtype=torch.int32
+        )
+        if ep_rank < (ep_size - 1):
+            expert_map[
+                ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
+            ] = torch.arange(0, local_num_experts, dtype=torch.int32)
+        else:
+            local_num_experts = global_num_experts - ep_rank * local_num_experts
+
+            expert_map[-local_num_experts:] = torch.arange(
+                0, local_num_experts, dtype=torch.int32
+            )
+        return (local_num_experts, expert_map)
+
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
             return self.forward_deepgemm(hidden_states, router_logits)
```
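The docstring above describes the partitioning scheme; the following is a minimal standalone sketch of the same logic (a hypothetical helper, not part of sglang) that makes the sentinel convention concrete:

```python
import torch

# Standalone sketch of the mapping described above (hypothetical helper, not sglang API):
# experts are split evenly, the remainder goes to the last rank, and experts owned by
# other ranks are marked with the sentinel value `global_num_experts` instead of -1.
def expert_map_sketch(ep_rank: int, ep_size: int, global_num_experts: int):
    if ep_size == 1:
        return global_num_experts, None
    local_num_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), global_num_experts, dtype=torch.int32)
    if ep_rank < ep_size - 1:
        start = ep_rank * local_num_experts
        expert_map[start : start + local_num_experts] = torch.arange(
            local_num_experts, dtype=torch.int32
        )
    else:
        local_num_experts = global_num_experts - ep_rank * local_num_experts
        expert_map[-local_num_experts:] = torch.arange(local_num_experts, dtype=torch.int32)
    return local_num_experts, expert_map

# 10 experts over 4 ranks: ranks 0-2 own 2 experts each, the last rank owns 4.
print(expert_map_sketch(3, 4, 10))
# (4, tensor([10, 10, 10, 10, 10, 10,  0,  1,  2,  3], dtype=torch.int32))
```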
```diff
@@ -440,6 +500,51 @@ class EPMoE(torch.nn.Module):
             ),
         )
 
+        if self.use_w4afp8:
+            local_topk_ids = topk_ids
+            if self.expert_map is not None:
+                "Translate info from expert_map to topk_ids"
+                local_topk_ids = torch.where(
+                    self.expert_map[topk_ids] != self.num_experts,
+                    self.expert_map[topk_ids],
+                    self.num_experts,
+                )
+
+            output = cutlass_w4a8_moe(
+                self.start_expert_id,
+                self.end_expert_id,
+                self.num_experts,
+                hidden_states,
+                self.w13_weight,
+                self.w2_weight,
+                self.w13_weight_scale_inv,
+                self.w2_weight_scale_inv,
+                topk_weights,
+                topk_ids,
+                local_topk_ids,
+                self.quant_method.a_strides1,
+                self.quant_method.b_strides1,
+                self.quant_method.c_strides1,
+                self.quant_method.a_strides2,
+                self.quant_method.b_strides2,
+                self.quant_method.c_strides2,
+                self.quant_method.s_strides13,
+                self.quant_method.s_strides2,
+                self.quant_method.expert_offsets,
+                self.quant_method.problem_sizes1,
+                self.quant_method.problem_sizes2,
+                self.w13_input_scale,
+                self.w2_input_scale,
+            )
+            return output
+
+        if self.grouped_gemm_runner is None:
+            self.grouped_gemm_runner = GroupedGemmRunner(
+                hidden_states.device,
+                use_flashinfer=False,  # TODO: use flashinfer
+                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
+            )
+
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
             topk_ids, self.num_experts
         )
```
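Before calling the cutlass kernel, the hunk above remaps the globally routed `topk_ids` through `self.expert_map` so that experts owned by this rank get local indices while foreign experts keep the `num_experts` sentinel. A small sketch of that remapping with toy values (nothing here is sglang API):

```python
import torch

# Toy remapping in the style of the torch.where call above (illustration only).
num_experts = 8
expert_map = torch.full((num_experts,), num_experts, dtype=torch.int32)
expert_map[4:8] = torch.arange(4, dtype=torch.int32)  # this rank owns global experts 4..7

topk_ids = torch.tensor([[1, 5], [6, 7]])  # global expert ids chosen by the router
local_topk_ids = torch.where(
    expert_map[topk_ids] != num_experts, expert_map[topk_ids], num_experts
)
print(local_topk_ids)
# tensor([[8, 1],
#         [2, 3]], dtype=torch.int32)
```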
```diff
@@ -449,7 +554,7 @@ class EPMoE(torch.nn.Module):
             device=hidden_states.device,
             dtype=(
                 self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
                 else hidden_states.dtype
             ),
         )
```
```diff
@@ -656,6 +761,23 @@ class EPMoE(torch.nn.Module):
             ]
         ]
 
+    @classmethod
+    def make_expert_input_scale_params_mapping(
+        cls,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+        # (param_name, weight_name, expert_id, shard_id)
+        return [
+            (
+                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
+                f"experts.{expert_id}.{shard_id}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id in ["w1", "w2", "w3"]
+        ]
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
```
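For reference, reproducing the list comprehension above outside the class shows the (param prefix, checkpoint name prefix, expert_id, shard_id) tuples it yields; the strings are generated here, not tied to any particular checkpoint:

```python
# Standalone reproduction of the comprehension above for num_experts=2.
num_experts = 2
mapping = [
    (
        "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
        f"experts.{expert_id}.{shard_id}.",
        expert_id,
        shard_id,
    )
    for expert_id in range(num_experts)
    for shard_id in ["w1", "w2", "w3"]
]
for entry in mapping:
    print(entry)
# ('experts.w13_', 'experts.0.w1.', 0, 'w1')
# ('experts.w2_', 'experts.0.w2.', 0, 'w2')
# ('experts.w13_', 'experts.0.w3.', 0, 'w3')
# ('experts.w13_', 'experts.1.w1.', 1, 'w1')
# ('experts.w2_', 'experts.1.w2.', 1, 'w2')
# ('experts.w13_', 'experts.1.w3.', 1, 'w3')
```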
```diff
@@ -727,6 +849,15 @@ class EPMoE(torch.nn.Module):
 
         # Input scales can be loaded directly and should be equal.
         if "input_scale" in weight_name:
+            if self.use_w4afp8:
+                if shard_id == "w1":
+                    param_data[expert_id][0] = loaded_weight
+                elif shard_id == "w3":
+                    param_data[expert_id][1] = loaded_weight
+                else:
+                    param_data[expert_id] = loaded_weight
+                return
+
             if (
                 (shard_id == "w1" or shard_id == "w3")
                 and param_data[expert_id] != 1
```
```diff
@@ -752,6 +883,13 @@ class EPMoE(torch.nn.Module):
                     ] = loaded_weight
                 else:  # w2
                     param_data[expert_id] = loaded_weight
+            elif self.use_w4afp8:
+                if shard_id == "w1":
+                    param_data[expert_id][: self.intermediate_size, :] = loaded_weight
+                elif shard_id == "w3":
+                    param_data[expert_id][self.intermediate_size :, :] = loaded_weight
+                else:
+                    param_data[expert_id] = loaded_weight
             # If we are in merged column case (gate_up_proj)
             else:
                 if shard_id in ("w1", "w3"):
```
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

```diff
@@ -1737,6 +1737,7 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
```
```diff
@@ -1822,6 +1823,7 @@
         topk_ids,
         inplace=inplace,
         activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
```
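These two hunks only thread the new `apply_router_weight_on_input` flag from `fused_moe` down to `fused_experts`. The pre-scaling pattern the flag refers to (visible verbatim in the aiter branch further below) can be sketched as follows; shapes and expert count are made up:

```python
import torch

# Sketch of the pre-scaling that `apply_router_weight_on_input` refers to
# (made-up shapes; the real flag is consumed inside the fused kernels).
x = torch.randn(4, 16)                  # 4 tokens, hidden size 16
topk_weights = torch.rand(4, 1)         # topk == 1 routing weights
topk_ids = torch.randint(0, 8, (4, 1))  # routed expert per token, 8 experts assumed

assert topk_weights.shape[1] == 1, "only topk=1 supports apply_router_weight_on_input"
x = x * topk_weights.to(x.dtype)        # fold the router weight into the activations
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)  # combine step no longer rescales
# x, topk_weights and topk_ids would then be handed to the fused MoE kernel.
```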
sglang/srt/layers/moe/fused_moe_triton/layer.py

```diff
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
+import importlib
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
```
```diff
@@ -19,6 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
     cpu_has_amx_support,
```
```diff
@@ -29,8 +31,15 @@ from sglang.srt.utils import (
     use_intel_amx_backend,
 )
 
+has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
+
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+    if has_triton_kernels:
+        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+            triton_kernel_moe_forward,
+        )
 else:
     fused_experts = None  # type: ignore
 
```
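`triton_kernels` is treated as an optional dependency: `importlib.util.find_spec` only checks whether the module could be imported, so the flag can be computed without risking an `ImportError`. The same probe pattern, shown with an arbitrary package name:

```python
import importlib.util

# find_spec returns None when the package is absent, so the import can be gated.
has_numpy = importlib.util.find_spec("numpy") is not None

if has_numpy:
    import numpy as np
    print("numpy available:", np.__version__)
else:
    print("numpy not installed; taking the fallback path")
```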
```diff
@@ -87,6 +96,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
+    def __init__(self, use_triton_kernels: bool = False):
+        super().__init__()
+        self.use_triton_kernels = use_triton_kernels
+
     def create_weights(
         self,
         layer: torch.nn.Module,
```
```diff
@@ -97,20 +110,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         **extra_weight_attrs,
     ):
         # Fused gate_up_proj (column parallel)
+        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
+        if self.use_triton_kernels:
+            w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
         w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
+        w2_weight_n, w2_weight_k = (
+            hidden_size,
+            intermediate_size,
+        )
+        if self.use_triton_kernels:
+            w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
         w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
```
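When the triton_kernels path is active, the expert weights are allocated with their last two dimensions swapped, so checkpoint shards stored in the usual `[out_features, in_features]` order have to be transposed at load time (which is what the later `loaded_weight.transpose(-2, -1)` hunks do). A sketch with made-up sizes:

```python
import torch

# Made-up sizes illustrating the two allocation layouts set up above.
num_experts, hidden_size, intermediate_size = 4, 128, 256

# Default layout: [E, out_features, in_features], like stacked nn.Linear weights.
w13_default = torch.empty(num_experts, 2 * intermediate_size, hidden_size)
w2_default = torch.empty(num_experts, hidden_size, intermediate_size)

# triton_kernels layout: last two dims swapped.
w13_triton = torch.empty(num_experts, hidden_size, 2 * intermediate_size)
w2_triton = torch.empty(num_experts, intermediate_size, hidden_size)

# A checkpoint shard in [out_features, in_features] order fits directly into the
# default layout but must be transposed for the swapped layout.
shard = torch.randn(2 * intermediate_size, hidden_size)
w13_default[0].copy_(shard)
w13_triton[0].copy_(shard.transpose(-2, -1))
```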
```diff
@@ -192,59 +210,72 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
-        if _use_aiter:
-            assert not no_combine, "unsupported"
-            if apply_router_weight_on_input:
-                assert (
-                    topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-                _, topk = topk_weights.shape
-                assert (
-                    topk == 1
-                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
-                x = x * topk_weights.to(x.dtype)
-                topk_weights = torch.ones_like(
-                    topk_weights, dtype=torch.float32
-                )  # topk_weights must be FP32 (float32)
-
-            return fused_moe(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                activation=(
-                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                ),
-            )
-        else:
-            return fused_experts(
+        if self.use_triton_kernels:
+            return triton_kernel_moe_forward(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=inplace and not no_combine,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
                 routed_scaling_factor=routed_scaling_factor,
             )
 
+            if _use_aiter:
+                assert not no_combine, "unsupported"
+                if apply_router_weight_on_input:
+                    assert (
+                        topk_weights.dim() == 2
+                    ), "`topk_weights` should be in shape (num_tokens, topk)"
+                    _, topk = topk_weights.shape
+                    assert (
+                        topk == 1
+                    ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+                    x = x * topk_weights.to(x.dtype)
+                    topk_weights = torch.ones_like(
+                        topk_weights, dtype=torch.float32
+                    )  # topk_weights must be FP32 (float32)
+
+                return fused_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
+                )
+            else:
+                return fused_experts(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    inplace=inplace and not no_combine,
+                    activation=activation,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                    no_combine=no_combine,
+                    routed_scaling_factor=routed_scaling_factor,
+                )
+
     def forward_cpu(
         self,
         layer: torch.nn.Module,
```
```diff
@@ -286,9 +317,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights.to(
-                torch.float
-            ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+            topk_weights,
             topk_ids,
             False,  # inplace # See [Note] inplace should be False in fused_experts.
             False,  # use_int8_w8a8
```
```diff
@@ -475,9 +504,13 @@ class FusedMoE(torch.nn.Module):
         self.inplace = inplace
         self.no_combine = no_combine
 
+        self.use_triton_kernels = (
+            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
+        )
+
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
-            )
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
+                self.use_triton_kernels
+            )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
```
```diff
@@ -597,6 +630,8 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
```
```diff
@@ -612,6 +647,27 @@ class FusedMoE(torch.nn.Module):
         loaded_weight: torch.tensor,
         tp_rank: int,
     ):
+        """Load w2 weights for down projection.
+
+        Args:
+            expert_data: The expert data tensor to load into
+            shard_dim: The dimension to shard along
+            shard_id: The shard ID (must be "w2")
+            loaded_weight: The weight tensor to load from
+            tp_rank: The tensor parallel rank
+        """
+        if not isinstance(expert_data, torch.Tensor) or not isinstance(
+            loaded_weight, torch.Tensor
+        ):
+            raise ValueError("expert_data and loaded_weight must be torch.Tensor")
+
+        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
+            raise ValueError(
+                f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
+            )
+
+        if shard_id != "w2":
+            raise ValueError(f"shard_id must be 'w2', got {shard_id}")
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
```
```diff
@@ -630,6 +686,12 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
+                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
+                    raise ValueError(
+                        f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
+                    )
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
```
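The new check guards the `narrow()` call that slices one rank's shard out of the full checkpoint tensor. A short sketch of that slicing with assumed sizes:

```python
import torch

# Sketch of narrow()-based tensor-parallel sharding with a bounds check (assumed sizes).
full_weight = torch.randn(1024, 512)  # e.g. one expert's down_proj weight
tp_size, tp_rank = 4, 2
shard_dim = 1                         # "RowParallel": shard along the input dimension
shard_size = full_weight.shape[shard_dim] // tp_size

if shard_size * tp_rank + shard_size > full_weight.shape[shard_dim]:
    raise ValueError("shard would run past the end of the loaded weight")

shard = full_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
print(shard.shape)  # torch.Size([1024, 128])
```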
```diff
@@ -716,6 +778,8 @@ class FusedMoE(torch.nn.Module):
         # should be whatever dimension intermediate_size is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
+        if self.use_triton_kernels:
+            is_transposed = True
         if is_transposed:
             shard_dim = int(not shard_dim)
 
```
```diff
@@ -754,8 +818,21 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
             return
+
         if "ModelOpt" in self.quant_method.__class__.__name__:
-            if "weight_scale_2" in weight_name or "input_scale" in weight_name:
+            # Determine per-tensor weight scale patterns based on variant
+            is_fp4_variant = (
+                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
+            )
+
+            # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
+            per_tensor_conditions = (
+                "weight_scale_2" in weight_name
+                if is_fp4_variant
+                else "weight_scale" in weight_name
+            ) or "input_scale" in weight_name
+
+            if per_tensor_conditions:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
```