sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (99)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +36 -2
  3. sglang/srt/conversation.py +56 -3
  4. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  5. sglang/srt/disaggregation/ascend/conn.py +44 -0
  6. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +50 -18
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  9. sglang/srt/disaggregation/utils.py +25 -3
  10. sglang/srt/entrypoints/engine.py +1 -1
  11. sglang/srt/entrypoints/http_server.py +1 -0
  12. sglang/srt/entrypoints/http_server_engine.py +1 -1
  13. sglang/srt/entrypoints/openai/protocol.py +11 -0
  14. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  15. sglang/srt/function_call/function_call_parser.py +2 -0
  16. sglang/srt/function_call/kimik2_detector.py +220 -0
  17. sglang/srt/hf_transformers_utils.py +18 -0
  18. sglang/srt/jinja_template_utils.py +8 -0
  19. sglang/srt/layers/communicator.py +20 -5
  20. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  21. sglang/srt/layers/layernorm.py +2 -2
  22. sglang/srt/layers/linear.py +12 -2
  23. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  24. sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
  25. sglang/srt/layers/moe/ep_moe/layer.py +141 -2
  26. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  29. sglang/srt/layers/moe/topk.py +8 -2
  30. sglang/srt/layers/parameter.py +19 -3
  31. sglang/srt/layers/quantization/__init__.py +2 -0
  32. sglang/srt/layers/quantization/fp8.py +28 -7
  33. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  35. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  36. sglang/srt/layers/quantization/w4afp8.py +264 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  38. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  39. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  40. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  41. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  42. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  43. sglang/srt/managers/cache_controller.py +41 -195
  44. sglang/srt/managers/io_struct.py +35 -3
  45. sglang/srt/managers/mm_utils.py +59 -96
  46. sglang/srt/managers/schedule_batch.py +17 -6
  47. sglang/srt/managers/scheduler.py +38 -6
  48. sglang/srt/managers/tokenizer_manager.py +16 -0
  49. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  50. sglang/srt/mem_cache/memory_pool.py +176 -101
  51. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  52. sglang/srt/mem_cache/radix_cache.py +8 -4
  53. sglang/srt/model_executor/forward_batch_info.py +13 -1
  54. sglang/srt/model_loader/loader.py +23 -12
  55. sglang/srt/models/deepseek_janus_pro.py +1 -1
  56. sglang/srt/models/deepseek_v2.py +78 -19
  57. sglang/srt/models/deepseek_vl2.py +1 -1
  58. sglang/srt/models/gemma3_mm.py +1 -1
  59. sglang/srt/models/gemma3n_mm.py +6 -3
  60. sglang/srt/models/internvl.py +8 -2
  61. sglang/srt/models/kimi_vl.py +8 -2
  62. sglang/srt/models/llama.py +2 -0
  63. sglang/srt/models/llava.py +3 -1
  64. sglang/srt/models/llavavid.py +1 -1
  65. sglang/srt/models/minicpmo.py +1 -2
  66. sglang/srt/models/minicpmv.py +1 -1
  67. sglang/srt/models/mixtral_quant.py +4 -0
  68. sglang/srt/models/mllama4.py +372 -82
  69. sglang/srt/models/phi4mm.py +8 -2
  70. sglang/srt/models/phimoe.py +553 -0
  71. sglang/srt/models/qwen2.py +2 -0
  72. sglang/srt/models/qwen2_5_vl.py +10 -7
  73. sglang/srt/models/qwen2_vl.py +12 -1
  74. sglang/srt/models/vila.py +8 -2
  75. sglang/srt/multimodal/mm_utils.py +2 -2
  76. sglang/srt/multimodal/processors/base_processor.py +197 -137
  77. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  78. sglang/srt/multimodal/processors/gemma3.py +4 -2
  79. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  80. sglang/srt/multimodal/processors/internvl.py +1 -1
  81. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  82. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  83. sglang/srt/multimodal/processors/minicpm.py +4 -3
  84. sglang/srt/multimodal/processors/mllama4.py +63 -61
  85. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  86. sglang/srt/multimodal/processors/pixtral.py +1 -1
  87. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  88. sglang/srt/multimodal/processors/vila.py +1 -1
  89. sglang/srt/server_args.py +26 -4
  90. sglang/srt/two_batch_overlap.py +3 -0
  91. sglang/srt/utils.py +191 -48
  92. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  93. sglang/utils.py +5 -5
  94. sglang/version.py +1 -1
  95. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
  96. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
  97. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
+import importlib
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -19,6 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
     cpu_has_amx_support,
@@ -29,8 +31,15 @@ from sglang.srt.utils import (
     use_intel_amx_backend,
 )
 
+has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
+
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+    if has_triton_kernels:
+        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+            triton_kernel_moe_forward,
+        )
 else:
     fused_experts = None  # type: ignore
 
@@ -87,6 +96,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
+    def __init__(self, use_triton_kernels: bool = False):
+        super().__init__()
+        self.use_triton_kernels = use_triton_kernels
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -97,20 +110,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         **extra_weight_attrs,
     ):
         # Fused gate_up_proj (column parallel)
+        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
+        if self.use_triton_kernels:
+            w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
         w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
+        w2_weight_n, w2_weight_k = (
+            hidden_size,
+            intermediate_size,
+        )
+        if self.use_triton_kernels:
+            w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
         w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
            requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
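To make the layout change above concrete, here is a minimal sketch (not part of the diff; the sizes are illustrative) of the parameter shapes the two paths allocate:

import torch

num_experts, hidden_size, intermediate_size = 8, 1024, 4096

# Default path: w13 is (E, 2*I, H), w2 is (E, H, I).
w13_default = torch.empty(num_experts, 2 * intermediate_size, hidden_size)
w2_default = torch.empty(num_experts, hidden_size, intermediate_size)

# use_triton_kernels=True swaps the last two dims: w13 is (E, H, 2*I), w2 is (E, I, H).
w13_triton = torch.empty(num_experts, hidden_size, 2 * intermediate_size)
w2_triton = torch.empty(num_experts, intermediate_size, hidden_size)

assert w13_triton.shape == w13_default.transpose(-2, -1).shape
assert w2_triton.shape == w2_default.transpose(-2, -1).shape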
@@ -192,59 +210,72 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
-        if _use_aiter:
-            assert not no_combine, "unsupported"
-            if apply_router_weight_on_input:
-                assert (
-                    topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-                _, topk = topk_weights.shape
-                assert (
-                    topk == 1
-                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
-                x = x * topk_weights.to(x.dtype)
-                topk_weights = torch.ones_like(
-                    topk_weights, dtype=torch.float32
-                )  # topk_weights must be FP32 (float32)
-
-            return fused_moe(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                activation=(
-                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                ),
-            )
-        else:
-            return fused_experts(
+        if self.use_triton_kernels:
+            return triton_kernel_moe_forward(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=inplace and not no_combine,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
                 routed_scaling_factor=routed_scaling_factor,
             )
 
+            if _use_aiter:
+                assert not no_combine, "unsupported"
+                if apply_router_weight_on_input:
+                    assert (
+                        topk_weights.dim() == 2
+                    ), "`topk_weights` should be in shape (num_tokens, topk)"
+                    _, topk = topk_weights.shape
+                    assert (
+                        topk == 1
+                    ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+                    x = x * topk_weights.to(x.dtype)
+                    topk_weights = torch.ones_like(
+                        topk_weights, dtype=torch.float32
+                    )  # topk_weights must be FP32 (float32)
+
+                return fused_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
+                )
+            else:
+                return fused_experts(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    inplace=inplace and not no_combine,
+                    activation=activation,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                    no_combine=no_combine,
+                    routed_scaling_factor=routed_scaling_factor,
+                )
+
     def forward_cpu(
         self,
         layer: torch.nn.Module,
@@ -286,9 +317,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights.to(
-                torch.float
-            ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+            topk_weights,
            topk_ids,
             False,  # inplace # See [Note] inplace should be False in fused_experts.
             False,  # use_int8_w8a8
@@ -475,9 +504,13 @@ class FusedMoE(torch.nn.Module):
         self.inplace = inplace
         self.no_combine = no_combine
 
+        self.use_triton_kernels = (
+            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
+        )
+
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
+                self.use_triton_kernels
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
@@ -485,6 +518,7 @@ class FusedMoE(torch.nn.Module):
             self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None
 
+        self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.local_num_experts,
@@ -597,6 +631,8 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
@@ -612,6 +648,31 @@ class FusedMoE(torch.nn.Module):
         loaded_weight: torch.tensor,
         tp_rank: int,
     ):
+        """Load w2 weights for down projection.
+
+        Args:
+            expert_data: The expert data tensor to load into
+            shard_dim: The dimension to shard along
+            shard_id: The shard ID (must be "w2")
+            loaded_weight: The weight tensor to load from
+            tp_rank: The tensor parallel rank
+        """
+        if not isinstance(expert_data, torch.Tensor) or not isinstance(
+            loaded_weight, torch.Tensor
+        ):
+            raise ValueError("expert_data and loaded_weight must be torch.Tensor")
+
+        if (
+            self.quant_config is not None
+            and "modelopt" in self.quant_config.get_name()
+            and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
+        ):
+            raise ValueError(
+                f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
+            )
+
+        if shard_id != "w2":
+            raise ValueError(f"shard_id must be 'w2', got {shard_id}")
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
@@ -630,6 +691,12 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
+                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
+                    raise ValueError(
+                        f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
+                    )
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
@@ -716,6 +783,8 @@ class FusedMoE(torch.nn.Module):
         # should be whatever dimension intermediate_size is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
+        if self.use_triton_kernels:
+            is_transposed = True
         if is_transposed:
             shard_dim = int(not shard_dim)
 
@@ -754,8 +823,21 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
             return
+
         if "ModelOpt" in self.quant_method.__class__.__name__:
-            if "weight_scale_2" in weight_name or "input_scale" in weight_name:
+            # Determine per-tensor weight scale patterns based on variant
+            is_fp4_variant = (
+                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
+            )
+
+            # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
+            per_tensor_conditions = (
+                "weight_scale_2" in weight_name
+                if is_fp4_variant
+                else "weight_scale" in weight_name
+            ) or "input_scale" in weight_name
+
+            if per_tensor_conditions:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
@@ -773,7 +855,7 @@ class FusedMoE(torch.nn.Module):
             return
 
         # Case weight scales and zero_points
-        if "scale" in weight_name or "zero" in weight_name:
+        if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name:
             # load the weight scales and zp based on the quantization scheme
             # supported weight scales/zp can be found in
             # FusedMoeWeightScaleSupported
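For readers tracing the new per-tensor condition above, a standalone sketch (not from the package; the weight names are made up for illustration) evaluates the same boolean:

# Illustrative evaluation of the per_tensor_conditions expression from the hunk above.
def per_tensor(weight_name: str, is_fp4_variant: bool) -> bool:
    return (
        "weight_scale_2" in weight_name
        if is_fp4_variant
        else "weight_scale" in weight_name
    ) or "input_scale" in weight_name

# FP4 checkpoints: only "weight_scale_2" (or "input_scale") is treated as per-tensor.
assert per_tensor("experts.w13_weight_scale_2", is_fp4_variant=True)
assert not per_tensor("experts.w13_weight_scale", is_fp4_variant=True)
# FP8 checkpoints: plain "weight_scale" is per-tensor.
assert per_tensor("experts.w13_weight_scale", is_fp4_variant=False)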
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py (new file)

@@ -0,0 +1,176 @@
+# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
+from typing import Optional
+
+import torch
+from sgl_kernel import gelu_and_mul, silu_and_mul
+from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+
+from sglang.srt.utils import direct_register_custom_op
+
+
+def triton_kernel_moe_forward(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    if not renormalize:
+        gating_output = torch.softmax(gating_output, dim=-1)
+    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
+
+    return triton_kernel_fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        routing_data,
+        gather_idx,
+        scatter_idx,
+        inplace=inplace,
+        activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        use_fp8_w8a8=use_fp8_w8a8,
+        per_channel_quant=per_channel_quant,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape,
+    )
+
+
+# This is a triton implementation of the fused_experts function
+def triton_kernel_fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    routing_data: RoutingData,
+    gather_indx: GatherIndx,
+    scatter_indx: ScatterIndx,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant == False, "per_channel_quant is not supported"
+    assert expert_map == None, "expert_map is not supported"
+    assert w1_scale == None, "w1_scale is not supported"
+    assert w2_scale == None, "w2_scale is not supported"
+    assert a1_scale == None, "a1_scale is not supported"
+    assert a2_scale == None, "a2_scale is not supported"
+    assert block_shape == None, "block_shape is not supported"
+
+    # type check
+    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
+    assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
+    assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
+
+    # Shape check
+    assert hidden_states.ndim == 2, "hidden_states must be 2D"
+    assert (
+        hidden_states.shape[-1] == w1.shape[-2]
+    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
+    assert (
+        w2.shape[-1] == w1.shape[1]
+    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
+
+    # feature check
+    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+
+    M, K = hidden_states.shape
+    E, _, N = w1.shape
+    n_expts_act = routing_data.n_expts_act
+    dtype = hidden_states.dtype
+
+    if global_num_experts == -1:
+        global_num_experts = E
+
+    # consistent with default implementation
+    intermediate_cache2 = torch.empty(
+        (M * n_expts_act, N // 2), device="cuda", dtype=dtype
+    )
+
+    intermediate_cache1 = matmul_ogs(
+        hidden_states,
+        w1,
+        None,
+        routing_data,
+        gather_indx=gather_indx,
+        gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
+    )
+
+    if activation == "silu":
+        silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+    elif activation == "gelu":
+        gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+    else:
+        raise ValueError(f"Unsupported FusedMoe activation: {activation}")
+
+    intermediate_cache3 = matmul_ogs(
+        intermediate_cache2,
+        w2,
+        None,
+        routing_data,
+        scatter_indx=scatter_indx,
+        gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
+    )
+
+    return intermediate_cache3
+
+
+def triton_kernel_moe_forward_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+direct_register_custom_op(
+    op_name="forward_cuda_triton",
+    op_func=triton_kernel_moe_forward,
+    mutates_args=[],
+    fake_impl=triton_kernel_moe_forward_fake,
+)
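A quick orientation to the tensor layout this new path expects, as a sketch under illustrative sizes (not part of the package): hidden_states is (M, K), w1 is (E, K, N) with N = 2 * intermediate size, and w2 is (E, N // 2, K), i.e. both expert weights are stored transposed relative to the default fused_experts layout.

import torch

M, K, intermediate, E = 16, 1024, 4096, 8  # illustrative sizes
N = 2 * intermediate

hidden_states = torch.empty(M, K, dtype=torch.bfloat16)
w1 = torch.empty(E, K, N, dtype=torch.bfloat16)       # gate_up, transposed layout
w2 = torch.empty(E, N // 2, K, dtype=torch.bfloat16)  # down, transposed layout

# These mirror the shape checks in triton_kernel_fused_experts above.
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]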
sglang/srt/layers/moe/topk.py

@@ -83,13 +83,18 @@ def fused_topk_cpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
-    return torch.ops.sgl_kernel.topk_softmax_cpu(
+    topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
         gating_output=gating_output,
         topk=topk,
         renormalize=renormalize,
     )
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids
 
 
 def fused_topk(
@@ -303,7 +308,7 @@ def biased_grouped_topk_gpu(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    compiled: bool = True,
+    compiled: bool = not _is_npu,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -411,6 +416,7 @@ if _is_cpu and _is_cpu_amx_available:
     biased_grouped_topk = biased_grouped_topk_cpu
     grouped_topk = grouped_topk_cpu
     fused_topk_native = fused_topk_cpu
+    fused_topk = fused_topk_cpu
 else:
     biased_grouped_topk = biased_grouped_topk_gpu
     grouped_topk = grouped_topk_gpu
sglang/srt/layers/parameter.py

@@ -187,10 +187,26 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        if not use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, shard_id * shard_size, shard_size
+
+        if _is_cpu:
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
+            )
+
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_id * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
             )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, shard_id * shard_size, shard_size
+                )
 
         assert (
             param_data.shape == loaded_weight.shape
sglang/srt/layers/quantization/__init__.py

@@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
@@ -82,6 +83,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
+    "w4afp8": W4AFp8Config,
 }
 
 # VLLM-dependent quantization methods
sglang/srt/layers/quantization/fp8.py

@@ -1,7 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -88,7 +88,7 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-if _is_hip:
+if _is_hip and (_use_aiter or _use_hip_int4):
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
@@ -200,7 +200,7 @@ class Fp8LinearMethod(LinearMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Fp8Config):
+    def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -286,7 +286,10 @@ class Fp8LinearMethod(LinearMethodBase):
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
             if self.block_quant:
-                assert self.quant_config.activation_scheme == "dynamic"
+                if hasattr(self.quant_config, "activation_scheme"):
+                    assert self.quant_config.activation_scheme == "dynamic"
+                elif hasattr(self.quant_config, "linear_activation_scheme"):
+                    assert self.quant_config.linear_activation_scheme == "dynamic"
                 scale = BlockQuantScaleParameter(
                     data=torch.empty(
                         (output_size_per_partition + block_n - 1) // block_n,
@@ -308,7 +311,13 @@ class Fp8LinearMethod(LinearMethodBase):
                 layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
-            if self.quant_config.activation_scheme == "static":
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 scale = PerTensorScaleParameter(
                     data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                     weight_loader=weight_loader,
@@ -371,7 +380,13 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.weight_scale = torch.nn.Parameter(
                 layer.weight_scale.data, requires_grad=False
             )
-            if self.quant_config.activation_scheme == "static":
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 layer.input_scale = torch.nn.Parameter(
                     layer.input_scale.data, requires_grad=False
                 )
@@ -405,7 +420,13 @@ class Fp8LinearMethod(LinearMethodBase):
         # Update layer with new values.
         layer.weight = Parameter(weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-        if self.quant_config.activation_scheme == "static":
+        if (
+            hasattr(self.quant_config, "activation_scheme")
+            and self.quant_config.activation_scheme == "static"
+        ) or (
+            hasattr(self.quant_config, "linear_activation_scheme")
+            and self.quant_config.linear_activation_scheme == "static"
+        ):
             layer.input_scale = Parameter(
                 layer.input_scale.max(), requires_grad=False
             )
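The same static-activation check now appears three times in Fp8LinearMethod; a hypothetical helper (not part of the diff) that captures the condition it implements:

# Hypothetical helper, not in the package: equivalent to the repeated condition above.
# Fp8Config exposes `activation_scheme`; W4AFp8Config exposes `linear_activation_scheme`.
def _uses_static_activation(quant_config) -> bool:
    return (
        getattr(quant_config, "activation_scheme", None) == "static"
        or getattr(quant_config, "linear_activation_scheme", None) == "static"
    )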
sglang/srt/layers/quantization/fp8_kernel.py

@@ -160,8 +160,8 @@ def _per_token_group_quant_fp8_colmajor(
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
-    y_ptr += g_id * group_size
-    y_q_ptr += g_id * group_size
+    y_ptr += g_id.to(tl.int64) * group_size
+    y_q_ptr += g_id.to(tl.int64) * group_size
 
     # Convert g_id the flattened block coordinate to 2D so we can index
     # into the output y_scales matrix
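The cast matters because tl.program_id() yields a 32-bit index; a rough back-of-the-envelope check (illustrative numbers, not from the package) shows when the old offset arithmetic would overflow:

# g_id * group_size silently overflows int32 roughly once the quantized tensor
# holds more than 2**31 elements, which the .to(tl.int64) cast above avoids.
group_size = 128
num_elements = 3 * 2**30            # a tensor larger than 2**31 elements
last_group = num_elements // group_size - 1
offset = last_group * group_size     # what the kernel adds to y_ptr
print(offset > 2**31 - 1)            # True -> needs the int64 cast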