sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py

@@ -1,19 +1,17 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
+ from __future__ import annotations

  import logging
- from typing import Any, Callable, Dict, List, Optional
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

  import torch
  from torch.nn.parameter import Parameter

- from sglang.srt.layers.linear import (
- LinearBase,
- LinearMethodBase,
- UnquantizedLinearMethod,
- )
  from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
  from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
  from sglang.srt.layers.quantization.base_config import (
+ FusedMoEMethodBase,
+ LinearMethodBase,
  QuantizationConfig,
  QuantizeMethodBase,
  )
@@ -23,6 +21,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
  is_sm100_supported,
  )
  from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
  from sglang.srt.layers.quantization.utils import (
  convert_to_channelwise,
  is_layer_skipped,
@@ -32,6 +31,9 @@ from sglang.srt.layers.quantization.utils import (
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.utils import is_cuda, next_power_of_2

+ if TYPE_CHECKING:
+ from sglang.srt.layers.moe.topk import TopKOutput
+
  if is_cuda():
  from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

@@ -86,7 +88,7 @@ class ModelOptFp8Config(QuantizationConfig):
  return ["hf_quant_config.json"]

  @classmethod
- def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
+ def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
  quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
  kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
  "kv_cache_quant_algo"
@@ -109,7 +111,11 @@ class ModelOptFp8Config(QuantizationConfig):

  def get_quant_method(
  self, layer: torch.nn.Module, prefix: str
- ) -> Optional["QuantizeMethodBase"]:
+ ) -> Optional[QuantizeMethodBase]:
+
+ from sglang.srt.layers.linear import LinearBase
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
  if self.exclude_modules and any(
  module in prefix
  or (
@@ -125,9 +131,6 @@ class ModelOptFp8Config(QuantizationConfig):
  if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
  return ModelOptFp8KVCacheMethod(self)

- # Add MoE support
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-
  if isinstance(layer, FusedMoE):
  return ModelOptFp8MoEMethod(self)

@@ -246,7 +249,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
  super().__init__(quant_config)


- class ModelOptFp8MoEMethod:
+ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
  """MoE method for ModelOpt FP8.
  Supports loading FP8 checkpoints with static weight scale and activation scale.

@@ -254,30 +257,6 @@ class ModelOptFp8MoEMethod:
  quant_config: The ModelOpt quantization config.
  """

- def __new__(cls, *args, **kwargs):
- """
- Dynamic class composition pattern.
-
- This allows us to effectively "inject" FusedMoEMethodBase as a parent class
- at runtime while avoiding circular import issues.
- """
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
- if not hasattr(cls, "_initialized"):
- original_init = cls.__init__
- new_cls = type(
- cls.__name__,
- (FusedMoEMethodBase,),
- {
- "__init__": original_init,
- **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
- },
- )
- obj = super(new_cls, new_cls).__new__(new_cls)
- obj.__init__(*args, **kwargs)
- return obj
- return super().__new__(cls)
-
  def __init__(self, quant_config: ModelOptFp8Config):
  self.quant_config = quant_config
  self.cutlass_fp8_supported = cutlass_fp8_supported()
@@ -426,15 +405,8 @@ class ModelOptFp8MoEMethod:
  self,
  layer: torch.nn.Module,
  x: torch.Tensor,
- router_logits: torch.Tensor,
- top_k: int,
- renormalize: bool,
- use_grouped_topk: bool,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- num_fused_shared_experts: Optional[int] = None,
- custom_routing_function: Optional[Callable] = None,
- correction_bias: Optional[torch.Tensor] = None,
+ topk_output: TopKOutput,
+ *,
  activation: str = "silu",
  apply_router_weight_on_input: bool = False,
  inplace: bool = True,
@@ -442,29 +414,12 @@ class ModelOptFp8MoEMethod:
  routed_scaling_factor: Optional[float] = None,
  ) -> torch.Tensor:
  from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
- from sglang.srt.layers.moe.topk import select_experts
-
- # Expert selection
- topk_weights, topk_ids = select_experts(
- hidden_states=x,
- router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- num_fused_shared_experts=num_fused_shared_experts,
- custom_routing_function=custom_routing_function,
- correction_bias=correction_bias,
- routed_scaling_factor=routed_scaling_factor,
- )

  return fused_experts(
  x,
  layer.w13_weight,
  layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
+ topk_output=topk_output,
  inplace=inplace,
  activation=activation,
  use_fp8_w8a8=True,
@@ -514,7 +469,7 @@ class ModelOptFp4Config(QuantizationConfig):
  return ["hf_quant_config.json"]

  @classmethod
- def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
+ def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
  quant_config = cls.get_from_keys(config, ["quantization"])
  quant_method = quant_config["quant_algo"]
  if not quant_method in ["FP8", "NVFP4"]:
@@ -559,7 +514,8 @@ class ModelOptFp4Config(QuantizationConfig):

  def get_quant_method(
  self, layer: torch.nn.Module, prefix: str
- ) -> Optional["QuantizeMethodBase"]:
+ ) -> Optional[QuantizeMethodBase]:
+ from sglang.srt.layers.linear import LinearBase
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

  if isinstance(layer, LinearBase):
@@ -740,31 +696,13 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
  return out.view(*output_shape)


- class ModelOptNvFp4FusedMoEMethod:
+ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
  """
  MoE Method for FP4 Quantization with Blockscales and PerTensorScales
  Args:
  quant_config: NVFP4 Quant Config
  """

- def __new__(cls, *args, **kwargs):
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
- if not hasattr(cls, "_initialized"):
- original_init = cls.__init__
- new_cls = type(
- cls.__name__,
- (FusedMoEMethodBase,),
- {
- "__init__": original_init,
- **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
- },
- )
- obj = super(new_cls, new_cls).__new__(new_cls)
- obj.__init__(*args, **kwargs)
- return obj
- return super().__new__(cls)
-
  def __init__(self, quant_config: ModelOptFp4Config):
  self.quant_config = quant_config
  if not is_sm100_supported():
@@ -1002,15 +940,8 @@ class ModelOptNvFp4FusedMoEMethod:
  self,
  layer: torch.nn.Module,
  x: torch.Tensor,
- router_logits: torch.Tensor,
- top_k: int,
- renormalize: bool,
- use_grouped_topk: bool,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- num_fused_shared_experts: Optional[int] = None,
- custom_routing_function: Optional[Callable] = None,
- correction_bias: Optional[torch.Tensor] = None,
+ topk_output: TopKOutput,
+ *,
  activation: str = "silu",
  apply_router_weight_on_input: bool = False,
  inplace: bool = True,
@@ -1023,21 +954,6 @@ class ModelOptNvFp4FusedMoEMethod:
  ) -> torch.Tensor:

  assert activation == "silu", "Only SiLU activation is supported."
- from sglang.srt.layers.moe.topk import select_experts
-
- topk_weights, topk_ids = select_experts(
- hidden_states=x,
- router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- num_fused_shared_experts=num_fused_shared_experts,
- custom_routing_function=custom_routing_function,
- correction_bias=correction_bias,
- routed_scaling_factor=routed_scaling_factor,
- )

  if self.enable_flashinfer_moe:
  assert (
@@ -1045,6 +961,7 @@ class ModelOptNvFp4FusedMoEMethod:
  ), "apply_router_weight_on_input is not supported for Flashinfer"
  # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
  # and fp4 quantized weights loaded from the checkpoint
+ topk_weights, topk_ids, _ = topk_output
  output = flashinfer_cutlass_fused_moe(
  x,
  topk_ids.to(torch.int),
@@ -1070,6 +987,7 @@ class ModelOptNvFp4FusedMoEMethod:

  from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

+ topk_weights, topk_ids, _ = topk_output
  return cutlass_moe_fp4(
  a=x,
  a1_gscale=layer.w13_input_scale_quant,
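
Note: the recurring change in modelopt_quant.py above (and across the other MoE quantization backends in this release) is that apply() no longer takes raw router_logits plus a set of routing flags; expert selection now happens once in the caller and is passed in as a TopKOutput, which the diff unpacks as (topk_weights, topk_ids, ...). A minimal sketch of the new calling convention follows; TopKOutputSketch and moe_apply_sketch are illustrative stand-ins, not sglang APIs.

from typing import NamedTuple

import torch


class TopKOutputSketch(NamedTuple):
    # Illustrative stand-in for sglang.srt.layers.moe.topk.TopKOutput; the diff
    # only shows that it unpacks as (topk_weights, topk_ids, <third field>).
    topk_weights: torch.Tensor  # [num_tokens, top_k] routing weights
    topk_ids: torch.Tensor      # [num_tokens, top_k] expert indices
    extra: object = None        # placeholder for whatever the third slot holds


def moe_apply_sketch(x: torch.Tensor, topk_output: TopKOutputSketch) -> torch.Tensor:
    # New-style quant-method apply(): routing was already decided by the caller,
    # so the method just unpacks it instead of calling select_experts itself.
    topk_weights, topk_ids, _ = topk_output
    # ... a real method would now hand these tensors to fused_experts(...) ...
    return x


# Usage: the router / top-k step runs once, outside the quantization method.
tokens, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(tokens, num_experts)
weights, ids = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1)
out = moe_apply_sketch(torch.randn(tokens, 16), TopKOutputSketch(weights, ids))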
sglang/srt/layers/quantization/moe_wna16.py

@@ -1,23 +1,59 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py
+ from __future__ import annotations

  import logging
- from typing import Any, Callable, Dict, List, Optional
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional

+ import numpy as np
  import torch

  from sglang.srt.distributed import get_tensor_model_parallel_rank
  from sglang.srt.distributed.parallel_state import get_tp_group
- from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
  from sglang.srt.layers.quantization.awq import AWQConfig
  from sglang.srt.layers.quantization.base_config import (
+ FusedMoEMethodBase,
  QuantizationConfig,
  QuantizeMethodBase,
  )
  from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
  from sglang.srt.utils import get_device_capability, set_weight_attrs

  logger = logging.getLogger(__name__)

+ if TYPE_CHECKING:
+ from sglang.srt.layers.moe.topk import TopKOutput
+
+
+ def get_weight_perm(num_bits: int):
+ perm_list: List[int] = []
+ for i in range(32):
+ perm1: List[int] = []
+ col = i // 4
+ for block in [0, 1]:
+ for row in [
+ 2 * (i % 4),
+ 2 * (i % 4) + 1,
+ 2 * (i % 4 + 4),
+ 2 * (i % 4 + 4) + 1,
+ ]:
+ perm1.append(16 * row + col + 8 * block)
+ for j in range(4):
+ perm_list.extend([p + 256 * j for p in perm1])
+
+ perm = np.array(perm_list)
+
+ if num_bits == 4:
+ interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+ elif num_bits == 8:
+ interleave = np.array([0, 2, 1, 3])
+ else:
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
+ perm = torch.from_numpy(perm)
+ return perm
+

  class MoeWNA16Config(QuantizationConfig):
  """Config class for MOE WNA16 (W8A16/W4A16) quantization."""
@@ -88,7 +124,7 @@ class MoeWNA16Config(QuantizationConfig):
  raise NotImplementedError

  @classmethod
- def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
+ def from_config(cls, config: Dict[str, Any]) -> MoeWNA16Config:
  quant_method = cls.get_from_keys(config, ["quant_method"])
  weight_bits = cls.get_from_keys(config, ["bits"])
  group_size = cls.get_from_keys(config, ["group_size"])
@@ -147,8 +183,9 @@ class MoeWNA16Config(QuantizationConfig):

  def get_quant_method(
  self, layer: torch.nn.Module, prefix: str
- ) -> Optional["QuantizeMethodBase"]:
+ ) -> Optional[QuantizeMethodBase]:
  # avoid circular import
+ from sglang.srt.layers.linear import LinearBase
  from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

  if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
@@ -179,32 +216,13 @@ def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
  return any(module_name in prefix for module_name in modules_to_not_convert)


- class MoeWNA16Method:
+ class MoeWNA16Method(FusedMoEMethodBase):
  """Linear method for MOE WNA16 (W8A16/W4A16) quantization.

  Args:
  quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
  """

- def __new__(cls, *args, **kwargs):
- # avoid circular import
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
- if not hasattr(cls, "_initialized"):
- original_init = cls.__init__
- new_cls = type(
- cls.__name__,
- (FusedMoEMethodBase,),
- {
- "__init__": original_init,
- **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
- },
- )
- obj = super(new_cls, new_cls).__new__(new_cls)
- obj.__init__(*args, **kwargs)
- return obj
- return super().__new__(cls)
-
  def __init__(self, quant_config: MoeWNA16Config):
  self.quant_config = quant_config

@@ -334,15 +352,8 @@ class MoeWNA16Method:
  self,
  layer: torch.nn.Module,
  x: torch.Tensor,
- router_logits: torch.Tensor,
- top_k: int,
- renormalize: bool,
- use_grouped_topk: bool = False,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- num_fused_shared_experts: int = 0,
- custom_routing_function: Optional[Callable] = None,
- correction_bias: Optional[torch.Tensor] = None,
+ topk_output: TopKOutput,
+ *,
  activation: str = "silu",
  apply_router_weight_on_input: bool = False,
  inplace: bool = True,
@@ -351,22 +362,8 @@ class MoeWNA16Method:
  ) -> torch.Tensor:
  # avoid circular import
  from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
- from sglang.srt.layers.moe.topk import select_experts

  assert activation == "silu", "Only SiLU activation is supported."
- topk_weights, topk_ids = select_experts(
- hidden_states=x,
- router_logits=router_logits,
- top_k=top_k,
- use_grouped_topk=use_grouped_topk,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- num_fused_shared_experts=num_fused_shared_experts,
- custom_routing_function=custom_routing_function,
- correction_bias=correction_bias,
- routed_scaling_factor=routed_scaling_factor,
- )

  weight_bits = self.quant_config.weight_bits
  has_zp = self.quant_config.has_zp
@@ -375,8 +372,7 @@ class MoeWNA16Method:
  x,
  layer.w13_qweight,
  layer.w2_qweight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
+ topk_output=topk_output,
  inplace=inplace,
  apply_router_weight_on_input=apply_router_weight_on_input,
  use_int4_w4a16=weight_bits == 4,
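
Note: get_weight_perm, added to moe_wna16.py above, appears to build the interleaved weight permutation used when repacking 4-bit or 8-bit weights. From its nested loops (32 x 8 x 4 entries) it yields every index in 0..1023 exactly once, which makes for a cheap sanity check. The snippet below assumes an environment where the sglang package imports cleanly; the import path simply mirrors the file shown in this diff.

from sglang.srt.layers.quantization.moe_wna16 import get_weight_perm

# Both bit widths produce a permutation of 0..1023; only the interleave differs.
perm4 = get_weight_perm(num_bits=4)
perm8 = get_weight_perm(num_bits=8)
assert sorted(perm4.tolist()) == list(range(1024))
assert sorted(perm8.tolist()) == list(range(1024))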
sglang/srt/layers/quantization/petit.py (new file)

@@ -0,0 +1,252 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
+
+
+ import logging
+ from typing import Any, Callable, Dict, List, Optional
+
+ import regex as re
+ import torch
+ from torch.nn.parameter import Parameter
+
+ from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+ from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+ from sglang.srt.layers.quantization.base_config import (
+ LinearMethodBase,
+ QuantizationConfig,
+ QuantizeMethodBase,
+ )
+ from sglang.srt.layers.quantization.petit_utils import (
+ apply_petit_nvfp4_linear,
+ prepare_nvfp4_layer_for_petit,
+ verify_petit_nvfp4_supported,
+ )
+ from sglang.srt.layers.quantization.utils import is_layer_skipped
+ from sglang.srt.utils import is_hip
+
+ _is_hip = is_hip()
+
+ # Initialize logger for the module
+ logger = logging.getLogger(__name__)
+
+
+ # Configuration class to support the NVFP4 quantized model generated by the ModelOpt quantization tool
+ class PetitNvFp4Config(QuantizationConfig):
+ """Config class for Petit FP4."""
+
+ def __init__(
+ self,
+ is_checkpoint_nvfp4_serialized: bool = False,
+ kv_cache_quant_algo: str = None,
+ group_size: int = None,
+ exclude_modules: List[str] = None,
+ ) -> None:
+ self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
+ if is_checkpoint_nvfp4_serialized:
+ logger.warning(
+ "Detected nvfp4 checkpoint. Please note that the "
+ "format is experimental and subject to change."
+ )
+ self.group_size = group_size
+ self.kv_cache_quant_algo = kv_cache_quant_algo
+ self.exclude_modules = exclude_modules
+
+ @classmethod
+ def get_name(cls) -> str:
+ return "petit_nvfp4"
+
+ @classmethod
+ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+ return [torch.bfloat16, torch.half]
+
+ @classmethod
+ def get_min_capability(cls) -> int:
+ # Petit supports the gfx90a and gfx942 GPUs
+ return 90
+
+ @classmethod
+ def get_config_filenames(cls) -> List[str]:
+ return ["hf_quant_config.json"]
+
+ @classmethod
+ def from_config(cls, config: Dict[str, Any]) -> "PetitNvFp4Config":
+ quant_config = cls.get_from_keys(config, ["quantization"])
+ quant_method = quant_config["quant_algo"]
+ group_size = quant_config.get("group_size", None)
+ verify_petit_nvfp4_supported(quant_method, group_size)
+
+ is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
+ kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
+ if not kv_cache_quant_algo:
+ kv_cache_quant_algo = "auto"
+ exclude_modules = quant_config.get("exclude_modules", None)
+ if not (group_size and kv_cache_quant_algo and (exclude_modules is not None)):
+ logger.warning(
+ f"group_size: {group_size},"
+ f"kv_cache_quant_algo: {kv_cache_quant_algo},"
+ f"exclude_modules: {exclude_modules}"
+ )
+ raise ValueError(
+ "NVFP4 quantization requires group size and "
+ "kv_cache_quant_algo specified in "
+ "hf_quant_config.json"
+ )
+ return cls(
+ is_checkpoint_nvfp4_serialized,
+ kv_cache_quant_algo,
+ group_size,
+ exclude_modules,
+ )
+
+ @classmethod
+ def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+ can_convert = cls.is_petit_nvfp4_compatible(hf_quant_cfg)
+ if can_convert:
+ return cls.get_name()
+ return None
+
+ @classmethod
+ def is_petit_nvfp4_compatible(cls, quant_config: Dict[str, Any]) -> bool:
+ quant_method = quant_config.get("quant_method", "").lower()
+ return _is_hip and quant_method == "modelopt"
+
+ def is_layer_excluded(self, prefix: str, exclude_modules: list):
+ for pattern in exclude_modules:
+ regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+ if re.fullmatch(regex_str, prefix):
+ return True
+ return False
+
+ def get_quant_method(
+ self, layer: torch.nn.Module, prefix: str
+ ) -> Optional["QuantizeMethodBase"]:
+ if isinstance(layer, LinearBase):
+ if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
+ prefix, self.exclude_modules
+ ):
+ return UnquantizedLinearMethod()
+ return PetitNvFp4LinearMethod(self)
+ return None
+
+ def get_scaled_act_names(self) -> List[str]:
+ return []
+
+
+ class PetitNvFp4LinearMethod(LinearMethodBase):
+ """Linear method for NVFP4.
+ Supports loading NVFP4 checkpoints with the following structure:
+
+ |Tensor Name | datatype | shape |
+ |----------------------------------------------------|
+ |input_scale | torch.float32 | scalar |
+ |weight | NVFP4(SE2M1) | [1, X, y/2] |
+ |weight_scale | FP8-E4M3 | [X, Y] |
+ |weight_scale_2 | torch.float32 | scalar |
+
+ The weights are quantized per block of 16 elements.
+ Args: quant_config: The ModelOpt quantization config.
+ """
+
+ def __init__(self, quant_config: PetitNvFp4Config):
+ self.quant_config = quant_config
+
+ def create_weights(
+ self,
+ layer: torch.nn.Module,
+ input_size_per_partition: int,
+ output_partition_sizes: List[int],
+ input_size: int,
+ output_size: int,
+ params_dtype: torch.dtype,
+ **extra_weight_attrs,
+ ):
+ del input_size, output_size
+ if not self.quant_config.is_checkpoint_nvfp4_serialized:
+ raise ValueError(
+ "NVFP4 quantization was selected, "
+ " dynamic quantization is not supported."
+ )
+
+ output_size_per_partition = sum(output_partition_sizes)
+ weight_loader = extra_weight_attrs.get("weight_loader")
+
+ layer.logical_widths = output_partition_sizes
+
+ layer.input_size_per_partition = input_size_per_partition
+ layer.output_size_per_partition = output_size_per_partition
+ if input_size_per_partition % 16 != 0:
+ raise ValueError(
+ "Unsupported model when in features size is " "not multiple of 16"
+ )
+
+ weight_dtype = (
+ torch.float8_e4m3fn
+ if self.quant_config.is_checkpoint_nvfp4_serialized
+ else params_dtype
+ )
+
+ weight = ModelWeightParameter(
+ data=torch.empty(
+ # 2 fp4 data is packed in one uint8 in the input dimension
+ output_size_per_partition,
+ input_size_per_partition // 2,
+ dtype=torch.uint8,
+ ),
+ input_dim=1,
+ output_dim=0,
+ weight_loader=weight_loader,
+ )
+ layer.register_parameter("weight", weight)
+
+ input_scale = PerTensorScaleParameter(
+ data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+ weight_loader=weight_loader,
+ )
+
+ layer.register_parameter("input_scale", input_scale)
+
+ weight_scale_2 = PerTensorScaleParameter(
+ data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+ weight_loader=weight_loader,
+ )
+ layer.register_parameter("weight_scale_2", weight_scale_2)
+
+ weight_scale = ModelWeightParameter(
+ data=torch.empty(
+ output_size_per_partition,
+ input_size_per_partition // self.quant_config.group_size,
+ dtype=weight_dtype,
+ ),
+ input_dim=1,
+ output_dim=0,
+ weight_loader=weight_loader,
+ )
+
+ layer.register_parameter("weight_scale", weight_scale)
+
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ input_scale_2 = layer.input_scale.max().to(torch.float32)
+ weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
+ layer.input_scale = Parameter(input_scale_2, requires_grad=False)
+ layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
+ layer.alpha = Parameter(
+ layer.input_scale * layer.weight_scale_2, requires_grad=False
+ )
+
+ prepare_nvfp4_layer_for_petit(layer)
+ del layer.input_scale
+
+ def apply(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ return apply_petit_nvfp4_linear(
+ input=x,
+ weight=layer.weight,
+ weight_scale=layer.weight_scale,
+ weight_scale_2=layer.weight_scale_2,
+ size_n=layer.output_size_per_partition,
+ size_k=layer.input_size_per_partition,
+ bias=bias,
+ )
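
Note: petit.py is a new backend for running ModelOpt NVFP4 checkpoints on ROCm GPUs (gfx90a/gfx942, per get_min_capability above). PetitNvFp4Config.from_config reads the "quantization" section of hf_quant_config.json. The dict below is an illustrative payload matching the keys that parser expects; the concrete values are made up, not taken from a real checkpoint.

# Illustrative hf_quant_config.json content for PetitNvFp4Config.from_config().
example_hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",           # must contain "NVFP4" for a serialized checkpoint
        "group_size": 16,                # NVFP4 weights are quantized per 16-element block
        "kv_cache_quant_algo": "auto",   # empty/None falls back to "auto"
        "exclude_modules": ["lm_head"],  # layers routed to UnquantizedLinearMethod
    }
}
# On a ROCm build where verify_petit_nvfp4_supported() accepts this combination,
# SGLang would construct the config roughly as:
#   config = PetitNvFp4Config.from_config(example_hf_quant_config)
#   assert config.is_checkpoint_nvfp4_serialized and config.group_size == 16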