sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (47)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +12 -1
  3. sglang/srt/conversation.py +35 -1
  4. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  5. sglang/srt/entrypoints/http_server_engine.py +1 -1
  6. sglang/srt/layers/communicator.py +3 -1
  7. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  8. sglang/srt/layers/layernorm.py +2 -2
  9. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  10. sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
  11. sglang/srt/layers/moe/ep_moe/layer.py +140 -2
  12. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  13. sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
  14. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  15. sglang/srt/layers/quantization/__init__.py +2 -0
  16. sglang/srt/layers/quantization/fp8.py +28 -7
  17. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  18. sglang/srt/layers/quantization/w4afp8.py +264 -0
  19. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  20. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  21. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  22. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  23. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  24. sglang/srt/managers/cache_controller.py +41 -195
  25. sglang/srt/managers/io_struct.py +8 -1
  26. sglang/srt/managers/mm_utils.py +4 -2
  27. sglang/srt/managers/schedule_batch.py +1 -1
  28. sglang/srt/managers/scheduler.py +17 -5
  29. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  30. sglang/srt/mem_cache/memory_pool.py +113 -63
  31. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  32. sglang/srt/mem_cache/radix_cache.py +8 -4
  33. sglang/srt/models/deepseek_v2.py +16 -2
  34. sglang/srt/models/mllama4.py +360 -79
  35. sglang/srt/multimodal/mm_utils.py +2 -2
  36. sglang/srt/multimodal/processors/mllama4.py +62 -60
  37. sglang/srt/server_args.py +15 -0
  38. sglang/srt/two_batch_overlap.py +3 -0
  39. sglang/srt/utils.py +37 -17
  40. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  41. sglang/utils.py +5 -5
  42. sglang/version.py +1 -1
  43. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
  44. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
  45. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  46. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  47. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py

@@ -12,6 +12,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -20,6 +21,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
+    pre_reorder_triton_kernel_for_cutlass_moe,
+    run_cutlass_moe_ep_preproess,
     run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
@@ -41,6 +44,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     sglang_per_token_quant_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
@@ -191,7 +195,7 @@ class EPMoE(torch.nn.Module):
             num_fused_shared_experts == 0
         ), "num_fused_shared_experts is not supported in EP"
         self.num_fused_shared_experts = num_fused_shared_experts
-        self.num_experts_per_partition = self.num_experts // self.tp_size
+        self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
 
@@ -215,6 +219,18 @@ class EPMoE(torch.nn.Module):
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
+            self.use_w4afp8 = False
+        elif isinstance(quant_config, W4AFp8Config):
+            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
+                quant_config
+            )
+            self.use_w4afp8 = True
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.fp8_dtype = torch.float8_e4m3fn
+            self.w13_weight_scale = None
+            self.w2_weight_scale = None
+            self.activation_scheme = quant_config.moe_activation_scheme
         else:
             self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                 quant_config
@@ -228,6 +244,7 @@ class EPMoE(torch.nn.Module):
             )
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
+            self.use_w4afp8 = False
 
         self.quant_method.create_weights(
             layer=self,
@@ -253,6 +270,49 @@ class EPMoE(torch.nn.Module):
             self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
         )
 
+    # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
+    # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
+    def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
+        """
+        Calculates how many experts should be assigned to each rank for EP and
+        creates a mapping from global to local expert index. Experts are
+        distributed evenly across ranks. Any remaining are assigned to the
+        last rank.
+
+        Returns:
+            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
+                - local_num_experts (int): The number of experts assigned
+                  to the current rank.
+                - expert_map (Optional[torch.Tensor]): A tensor of shape
+                  (global_num_experts,) mapping from global to local index.
+                  Contains global_num_experts for experts not assigned to the current rank.
+                  Returns None if ep_size is 1.
+        """
+        ep_size = self.tp_size
+        ep_rank = self.tp_rank
+        global_num_experts = self.num_experts
+
+        assert ep_size > 0
+        if ep_size == 1:
+            return (global_num_experts, None)
+
+        local_num_experts = global_num_experts // ep_size
+
+        expert_map = torch.full(
+            (global_num_experts,), self.num_experts, dtype=torch.int32
+        )
+        if ep_rank < (ep_size - 1):
+            expert_map[
+                ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
+            ] = torch.arange(0, local_num_experts, dtype=torch.int32)
+        else:
+            local_num_experts = global_num_experts - ep_rank * local_num_experts
+
+            expert_map[-local_num_experts:] = torch.arange(
+                0, local_num_experts, dtype=torch.int32
+            )
+        return (local_num_experts, expert_map)
+
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
             return self.forward_deepgemm(hidden_states, router_logits)
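
For readers unfamiliar with the expert-parallel mapping added above, the following standalone sketch (not part of the package; sizes are made up) reproduces the partition logic of determine_expert_map for 10 experts over 4 ranks:

# Standalone sketch mirroring the partition logic above, illustration only.
import torch

def sketch_expert_map(global_num_experts: int, ep_size: int, ep_rank: int):
    if ep_size == 1:
        return global_num_experts, None
    local_num_experts = global_num_experts // ep_size
    # Unassigned slots hold the sentinel `global_num_experts`
    # (the sglang modification), instead of -1 as in the upstream vLLM helper.
    expert_map = torch.full((global_num_experts,), global_num_experts, dtype=torch.int32)
    if ep_rank < ep_size - 1:
        start = ep_rank * local_num_experts
        expert_map[start : start + local_num_experts] = torch.arange(
            local_num_experts, dtype=torch.int32
        )
    else:
        # The last rank picks up the remainder.
        local_num_experts = global_num_experts - ep_rank * local_num_experts
        expert_map[-local_num_experts:] = torch.arange(local_num_experts, dtype=torch.int32)
    return local_num_experts, expert_map

for rank in range(4):
    n, m = sketch_expert_map(10, 4, rank)
    print(rank, n, m.tolist())
# rank 0 -> 2 experts: [0, 1, 10, 10, 10, 10, 10, 10, 10, 10]
# rank 3 -> 4 experts: [10, 10, 10, 10, 10, 10, 0, 1, 2, 3]
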
@@ -440,6 +500,51 @@ class EPMoE(torch.nn.Module):
             ),
         )
 
+        if self.use_w4afp8:
+            local_topk_ids = topk_ids
+            if self.expert_map is not None:
+                "Translate info from expert_map to topk_ids"
+                local_topk_ids = torch.where(
+                    self.expert_map[topk_ids] != self.num_experts,
+                    self.expert_map[topk_ids],
+                    self.num_experts,
+                )
+
+            output = cutlass_w4a8_moe(
+                self.start_expert_id,
+                self.end_expert_id,
+                self.num_experts,
+                hidden_states,
+                self.w13_weight,
+                self.w2_weight,
+                self.w13_weight_scale_inv,
+                self.w2_weight_scale_inv,
+                topk_weights,
+                topk_ids,
+                local_topk_ids,
+                self.quant_method.a_strides1,
+                self.quant_method.b_strides1,
+                self.quant_method.c_strides1,
+                self.quant_method.a_strides2,
+                self.quant_method.b_strides2,
+                self.quant_method.c_strides2,
+                self.quant_method.s_strides13,
+                self.quant_method.s_strides2,
+                self.quant_method.expert_offsets,
+                self.quant_method.problem_sizes1,
+                self.quant_method.problem_sizes2,
+                self.w13_input_scale,
+                self.w2_input_scale,
+            )
+            return output
+
+        if self.grouped_gemm_runner is None:
+            self.grouped_gemm_runner = GroupedGemmRunner(
+                hidden_states.device,
+                use_flashinfer=False,  # TODO: use flashinfer
+                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
+            )
+
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
             topk_ids, self.num_experts
         )
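
The w4afp8 branch above remaps global top-k expert ids to rank-local ids through expert_map. A small illustration (not from the package) of that torch.where translation:

# Illustration only: global ids owned by this rank become local indices,
# everything else stays at the sentinel value `num_experts`.
import torch

num_experts = 8                      # global experts
expert_map = torch.full((num_experts,), num_experts, dtype=torch.int32)
expert_map[4:8] = torch.arange(4, dtype=torch.int32)   # this rank owns experts 4..7

topk_ids = torch.tensor([[1, 5], [6, 7]])               # router output (global ids)
local_topk_ids = torch.where(
    expert_map[topk_ids] != num_experts,                # owned by this rank?
    expert_map[topk_ids],                               # yes: local index 0..3
    num_experts,                                        # no: keep the sentinel
)
print(local_topk_ids.tolist())   # [[8, 1], [2, 3]]
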
@@ -449,7 +554,7 @@ class EPMoE(torch.nn.Module):
             device=hidden_states.device,
             dtype=(
                 self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
                 else hidden_states.dtype
             ),
         )
@@ -656,6 +761,23 @@ class EPMoE(torch.nn.Module):
             ]
         ]
 
+    @classmethod
+    def make_expert_input_scale_params_mapping(
+        cls,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+        # (param_name, weight_name, expert_id, shard_id)
+        return [
+            (
+                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
+                f"experts.{expert_id}.{shard_id}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id in ["w1", "w2", "w3"]
+        ]
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
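
For reference, this is the kind of tuple list make_expert_input_scale_params_mapping produces; the snippet below (illustration only) inlines the same comprehension for num_experts=2:

# Illustration only: each checkpoint tensor name suffix is paired with the
# stacked parameter prefix it should be loaded into.
mapping = [
    (
        "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
        f"experts.{expert_id}.{shard_id}.",
        expert_id,
        shard_id,
    )
    for expert_id in range(2)
    for shard_id in ["w1", "w2", "w3"]
]
for entry in mapping:
    print(entry)
# ('experts.w13_', 'experts.0.w1.', 0, 'w1')
# ('experts.w2_',  'experts.0.w2.', 0, 'w2')
# ('experts.w13_', 'experts.0.w3.', 0, 'w3')
# ... and the same three entries for expert_id=1
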
@@ -727,6 +849,15 @@ class EPMoE(torch.nn.Module):
 
         # Input scales can be loaded directly and should be equal.
         if "input_scale" in weight_name:
+            if self.use_w4afp8:
+                if shard_id == "w1":
+                    param_data[expert_id][0] = loaded_weight
+                elif shard_id == "w3":
+                    param_data[expert_id][1] = loaded_weight
+                else:
+                    param_data[expert_id] = loaded_weight
+                return
+
             if (
                 (shard_id == "w1" or shard_id == "w3")
                 and param_data[expert_id] != 1
@@ -752,6 +883,13 @@ class EPMoE(torch.nn.Module):
                     ] = loaded_weight
                 else:  # w2
                     param_data[expert_id] = loaded_weight
+            elif self.use_w4afp8:
+                if shard_id == "w1":
+                    param_data[expert_id][: self.intermediate_size, :] = loaded_weight
+                elif shard_id == "w3":
+                    param_data[expert_id][self.intermediate_size :, :] = loaded_weight
+                else:
+                    param_data[expert_id] = loaded_weight
             # If we are in merged column case (gate_up_proj)
             else:
                 if shard_id in ("w1", "w3"):
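
The new w4afp8 branch writes the w1 and w3 shards into the top and bottom halves of the stacked per-expert w13 buffer. A toy sketch with made-up shapes (not package code):

# Illustrative shapes only: [num_experts, 2 * intermediate_size, groups].
import torch

intermediate_size = 4
param_data = torch.zeros(2, 2 * intermediate_size, 3)
w1_shard = torch.ones(intermediate_size, 3)
w3_shard = torch.full((intermediate_size, 3), 2.0)

expert_id = 0
param_data[expert_id][:intermediate_size, :] = w1_shard     # shard_id == "w1"
param_data[expert_id][intermediate_size:, :] = w3_shard     # shard_id == "w3"
print(param_data[0, :, 0].tolist())   # [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]
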
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -1737,6 +1737,7 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
@@ -1822,6 +1823,7 @@
         topk_ids,
         inplace=inplace,
         activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
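
apply_router_weight_on_input folds the routing weight into the activations before the expert GEMMs instead of applying it at combine time; for a linear expert and top-1 routing the two orders agree, as this standalone sketch (not package code) shows:

# Numeric sketch: with top-1 routing, pre-scaling the input and combining with
# weight 1.0 matches post-scaling the output, because the expert here is linear.
import torch

x = torch.tensor([[2.0, 4.0]])                 # one token, hidden_size=2
topk_weights = torch.tensor([[0.25]])          # top-1 routing weight

def expert(t: torch.Tensor) -> torch.Tensor:   # stand-in for one expert MLP (linear)
    return 3.0 * t

out_after = expert(x) * topk_weights           # weight applied on the output
out_before = expert(x * topk_weights) * 1.0    # weight applied on the input
print(torch.allclose(out_after, out_before))   # True
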
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
+import importlib
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -19,6 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
     cpu_has_amx_support,
@@ -29,8 +31,15 @@ from sglang.srt.utils import (
     use_intel_amx_backend,
 )
 
+has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
+
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+    if has_triton_kernels:
+        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+            triton_kernel_moe_forward,
+        )
 else:
     fused_experts = None  # type: ignore
 
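
The has_triton_kernels guard probes for the optional triton_kernels package without importing it; a minimal sketch of the same pattern (illustration only):

# find_spec returns None when the package is not installed, without importing it.
import importlib.util

has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
print(has_triton_kernels)   # False unless the optional triton_kernels package is installed
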
@@ -87,6 +96,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
+    def __init__(self, use_triton_kernels: bool = False):
+        super().__init__()
+        self.use_triton_kernels = use_triton_kernels
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -97,20 +110,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         **extra_weight_attrs,
     ):
         # Fused gate_up_proj (column parallel)
+        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
+        if self.use_triton_kernels:
+            w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
         w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
+        w2_weight_n, w2_weight_k = (
+            hidden_size,
+            intermediate_size,
+        )
+        if self.use_triton_kernels:
+            w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
         w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
@@ -192,59 +210,72 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
-        if _use_aiter:
-            assert not no_combine, "unsupported"
-            if apply_router_weight_on_input:
-                assert (
-                    topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-                _, topk = topk_weights.shape
-                assert (
-                    topk == 1
-                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
-                x = x * topk_weights.to(x.dtype)
-                topk_weights = torch.ones_like(
-                    topk_weights, dtype=torch.float32
-                )  # topk_weights must be FP32 (float32)
-
-            return fused_moe(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                activation=(
-                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                ),
-            )
-        else:
-            return fused_experts(
+        if self.use_triton_kernels:
+            return triton_kernel_moe_forward(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=inplace and not no_combine,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
                 routed_scaling_factor=routed_scaling_factor,
             )
 
+            if _use_aiter:
+                assert not no_combine, "unsupported"
+                if apply_router_weight_on_input:
+                    assert (
+                        topk_weights.dim() == 2
+                    ), "`topk_weights` should be in shape (num_tokens, topk)"
+                    _, topk = topk_weights.shape
+                    assert (
+                        topk == 1
+                    ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+                    x = x * topk_weights.to(x.dtype)
+                    topk_weights = torch.ones_like(
+                        topk_weights, dtype=torch.float32
+                    )  # topk_weights must be FP32 (float32)
+
+                return fused_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
+                )
+            else:
+                return fused_experts(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    inplace=inplace and not no_combine,
+                    activation=activation,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                    no_combine=no_combine,
+                    routed_scaling_factor=routed_scaling_factor,
+                )
+
     def forward_cpu(
         self,
         layer: torch.nn.Module,
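
The refactor above changes the dispatch order of the unquantized CUDA forward path: the triton_kernels backend is checked first and bypasses select_experts entirely, otherwise routing runs and the result goes to the aiter or fused_experts backend. A schematic sketch (not package code):

def pick_backend(use_triton_kernels: bool, use_aiter: bool) -> str:
    # Mirrors the branch order in the hunk above, for illustration only.
    if use_triton_kernels:
        return "triton_kernel_moe_forward"   # routing handled inside the kernel
    if use_aiter:
        return "fused_moe (aiter)"
    return "fused_experts"

print(pick_backend(True, True))    # triton_kernel_moe_forward
print(pick_backend(False, True))   # fused_moe (aiter)
print(pick_backend(False, False))  # fused_experts
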
@@ -286,9 +317,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights.to(
-                torch.float
-            ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+            topk_weights,
             topk_ids,
             False,  # inplace # See [Note] inplace should be False in fused_experts.
             False,  # use_int8_w8a8
@@ -475,9 +504,13 @@ class FusedMoE(torch.nn.Module):
         self.inplace = inplace
         self.no_combine = no_combine
 
+        self.use_triton_kernels = (
+            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
+        )
+
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
+                self.use_triton_kernels
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
@@ -597,6 +630,8 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
@@ -612,6 +647,27 @@ class FusedMoE(torch.nn.Module):
         loaded_weight: torch.tensor,
         tp_rank: int,
     ):
+        """Load w2 weights for down projection.
+
+        Args:
+            expert_data: The expert data tensor to load into
+            shard_dim: The dimension to shard along
+            shard_id: The shard ID (must be "w2")
+            loaded_weight: The weight tensor to load from
+            tp_rank: The tensor parallel rank
+        """
+        if not isinstance(expert_data, torch.Tensor) or not isinstance(
+            loaded_weight, torch.Tensor
+        ):
+            raise ValueError("expert_data and loaded_weight must be torch.Tensor")
+
+        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
+            raise ValueError(
+                f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
+            )
+
+        if shard_id != "w2":
+            raise ValueError(f"shard_id must be 'w2', got {shard_id}")
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
@@ -630,6 +686,12 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
+                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
+                    raise ValueError(
+                        f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
+                    )
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
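
The narrow call above selects each tensor-parallel rank's contiguous slice of w2 along the sharded input dimension, and the new check rejects out-of-range shards. A toy illustration with made-up shapes (not package code):

# Each TP rank keeps a contiguous shard_size slice along shard_dim.
import torch

tp_size, shard_dim = 2, 1
loaded_weight = torch.arange(4 * 8, dtype=torch.float32).reshape(4, 8)  # [hidden, intermediate]
shard_size = loaded_weight.shape[shard_dim] // tp_size                   # 4 columns per rank

for tp_rank in range(tp_size):
    if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
        raise ValueError("shard exceeds loaded_weight dimension")        # mirrors the new guard
    shard = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
    print(tp_rank, shard.shape)   # 0 torch.Size([4, 4]); 1 torch.Size([4, 4])
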
@@ -716,6 +778,8 @@ class FusedMoE(torch.nn.Module):
         # should be whatever dimension intermediate_size is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
+        if self.use_triton_kernels:
+            is_transposed = True
         if is_transposed:
             shard_dim = int(not shard_dim)
 
@@ -754,8 +818,21 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
             return
+
         if "ModelOpt" in self.quant_method.__class__.__name__:
-            if "weight_scale_2" in weight_name or "input_scale" in weight_name:
+            # Determine per-tensor weight scale patterns based on variant
+            is_fp4_variant = (
+                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
+            )
+
+            # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
+            per_tensor_conditions = (
+                "weight_scale_2" in weight_name
+                if is_fp4_variant
+                else "weight_scale" in weight_name
+            ) or "input_scale" in weight_name
+
+            if per_tensor_conditions:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
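
The new per_tensor_conditions expression selects which checkpoint tensors are treated as per-tensor scales depending on the ModelOpt variant (FP4 keys off weight_scale_2, FP8 off weight_scale; input_scale always matches). An inline illustration (not package code):

def per_tensor_conditions(weight_name: str, is_fp4_variant: bool) -> bool:
    # Same boolean shape as the hunk above, shown standalone for clarity.
    return (
        "weight_scale_2" in weight_name
        if is_fp4_variant
        else "weight_scale" in weight_name
    ) or "input_scale" in weight_name

for name in ["w13_weight_scale", "w13_weight_scale_2", "w13_input_scale"]:
    print(name, per_tensor_conditions(name, True), per_tensor_conditions(name, False))
# w13_weight_scale     False True
# w13_weight_scale_2   True  True
# w13_input_scale      True  True
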