sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,9 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

+ from __future__ import annotations
+
  import logging
- from typing import Any, Callable, Dict, List, Optional, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

  import torch
  import torch.nn.functional as F
@@ -28,17 +30,14 @@ except ImportError:

  from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
- from sglang.srt.layers.linear import (
-     LinearBase,
-     LinearMethodBase,
-     UnquantizedLinearMethod,
- )
  from sglang.srt.layers.parameter import (
      BlockQuantScaleParameter,
      ModelWeightParameter,
      PerTensorScaleParameter,
  )
  from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
+     LinearMethodBase,
      QuantizationConfig,
      QuantizeMethodBase,
  )
@@ -56,6 +55,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
      normalize_e4m3fn_to_e4m3fnuz,
  )
  from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
  from sglang.srt.layers.quantization.utils import (
      all_close_1d,
      convert_to_channelwise,
@@ -77,6 +77,10 @@ from sglang.srt.utils import (
      use_intel_amx_backend,
  )

+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import TopKOutput
+     from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
+
  _is_hip = is_hip()
  _is_cuda = is_cuda()
  _is_npu = is_npu()
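
Note on the import changes in the `from __future__ import annotations` hunk at the top of the file and the `TYPE_CHECKING` block above: with postponed annotations, type hints in this module are no longer evaluated at runtime, so `TopKOutput` and `W4AFp8Config` can be imported for type checkers only and referenced in signatures without string quotes. A minimal sketch of that pattern, using placeholder names rather than sglang modules:

    # Sketch of the postponed-annotations pattern; `heavy_module` and
    # `HeavyType` are placeholders, not names from this package.
    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by static type checkers, never at runtime, so it
        # cannot introduce an import cycle or slow down module import.
        from heavy_module import HeavyType


    def process(value: HeavyType) -> HeavyType:  # no quotes required
        return value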
@@ -91,10 +95,9 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
  if _is_hip and (_use_aiter or _use_hip_int4):
      from aiter import ActivationType, QuantType
      from aiter.fused_moe import fused_moe
-     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
      from aiter.ops.shuffle import shuffle_weight

- if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+ if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
      from vllm._custom_ops import scaled_fp8_quant


@@ -152,7 +155,7 @@ class Fp8Config(QuantizationConfig):
          return []

      @classmethod
-     def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
+     def from_config(cls, config: Dict[str, Any]) -> Fp8Config:
          quant_method = cls.get_from_keys(config, ["quant_method"])
          is_checkpoint_fp8_serialized = "fp8" in quant_method
          activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
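
For reference, `from_config` above reads `quant_method` and `activation_scheme` from the checkpoint's quantization config, while `weight_block_size` drives the block-quant paths used later in this file. A hedged sketch of the kind of dict it consumes (only the first two keys appear in this hunk; the `weight_block_size` key name is an assumption based on the attribute used elsewhere):

    # Illustrative config dict only; the commented call requires sglang.
    quantization_config = {
        "quant_method": "fp8",
        "activation_scheme": "dynamic",
        "weight_block_size": [128, 128],  # assumed key for block-wise FP8
    }

    # from sglang.srt.layers.quantization.fp8 import Fp8Config
    # cfg = Fp8Config.from_config(quantization_config)
    # cfg.is_checkpoint_fp8_serialized is True because "fp8" is in quant_method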
@@ -167,7 +170,8 @@ class Fp8Config(QuantizationConfig):

      def get_quant_method(
          self, layer: torch.nn.Module, prefix: str
-     ) -> Optional["QuantizeMethodBase"]:
+     ) -> Optional[QuantizeMethodBase]:
+         from sglang.srt.layers.linear import LinearBase
          from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

          if isinstance(layer, LinearBase):
@@ -200,7 +204,7 @@ class Fp8LinearMethod(LinearMethodBase):
          quant_config: The quantization config.
      """

-     def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
+     def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]):
          self.quant_config = quant_config
          self.cutlass_fp8_supported = cutlass_fp8_supported()

@@ -486,7 +490,7 @@ class Fp8LinearMethod(LinearMethodBase):
          )


- class Fp8MoEMethod:
+ class Fp8MoEMethod(FusedMoEMethodBase):
      """MoE method for FP8.
      Supports loading FP8 checkpoints with static weight scale and
      dynamic/static activation scale.
@@ -499,25 +503,7 @@ class Fp8MoEMethod:
          quant_config: The quantization config.
      """

-     def __new__(cls, *args, **kwargs):
-         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-         if not hasattr(cls, "_initialized"):
-             original_init = cls.__init__
-             new_cls = type(
-                 cls.__name__,
-                 (FusedMoEMethodBase,),
-                 {
-                     "__init__": original_init,
-                     **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                 },
-             )
-             obj = super(new_cls, new_cls).__new__(new_cls)
-             obj.__init__(*args, **kwargs)
-             return obj
-         return super().__new__(cls)
-
-     def __init__(self, quant_config):
+     def __init__(self, quant_config: Fp8Config):
          self.quant_config = quant_config
          self.block_quant = self.quant_config.weight_block_size is not None
          self.cutlass_fp8_supported = cutlass_fp8_supported()
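
The two hunks above replace the `__new__` hook, which re-based `Fp8MoEMethod` onto `FusedMoEMethodBase` at construction time via `type()` so the base did not have to be imported from `fused_moe_triton` at class-definition time, with ordinary inheritance now that `FusedMoEMethodBase` is imported from `base_config` (see the import hunk above). A minimal sketch contrasting the two patterns with stand-in class names, not the sglang implementation:

    # Stand-in classes only; this contrasts the two patterns.
    class Base:
        def apply(self):
            raise NotImplementedError


    # Old pattern (removed above): rebuild the class at instantiation time
    # with Base injected as a parent.
    class DynamicMethod:
        def __new__(cls, *args, **kwargs):
            namespace = {
                k: v
                for k, v in cls.__dict__.items()
                if k not in ("__dict__", "__weakref__")
            }
            new_cls = type(cls.__name__, (Base,), namespace)
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj

        def __init__(self):
            self.ready = True


    # New pattern (adopted above): inherit directly from the base class,
    # which now lives in a module that is safe to import here.
    class StaticMethod(Base):
        def __init__(self):
            self.ready = True


    assert isinstance(DynamicMethod(), Base)
    assert isinstance(StaticMethod(), Base)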
@@ -985,15 +971,8 @@ class Fp8MoEMethod:
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         num_fused_shared_experts: int = 0,
-         custom_routing_function: Optional[Callable] = None,
-         correction_bias: Optional[torch.Tensor] = None,
+         topk_output: TopKOutput,
+         *,
          activation: str = "silu",
          apply_router_weight_on_input: bool = False,
          inplace: bool = True,
@@ -1001,24 +980,15 @@ class Fp8MoEMethod:
          routed_scaling_factor: Optional[float] = None,
      ) -> torch.Tensor:
          from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-         from sglang.srt.layers.moe.topk import select_experts
-
-         # Expert selection
-         topk_weights, topk_ids = select_experts(
-             hidden_states=x,
-             router_logits=router_logits,
-             use_grouped_topk=use_grouped_topk,
-             top_k=top_k,
-             renormalize=renormalize,
-             topk_group=topk_group,
-             num_expert_group=num_expert_group,
-             num_fused_shared_experts=num_fused_shared_experts,
-             custom_routing_function=custom_routing_function,
-             correction_bias=correction_bias,
-             routed_scaling_factor=routed_scaling_factor,
-         )

          if use_intel_amx_backend(layer):
+             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+             topk_weights, topk_ids, _ = topk_output
+             x, topk_weights = apply_topk_weights_cpu(
+                 apply_router_weight_on_input, topk_weights, x
+             )
+
              return torch.ops.sgl_kernel.fused_experts_cpu(
                  x,
                  layer.w13_weight,
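
Across the `apply` hunks above, the individual routing arguments (`router_logits`, `top_k`, `renormalize`, the grouped-top-k options, `correction_bias`, and so on) are replaced by a single `topk_output` value that the caller computes once and that the quant method unpacks as a 3-tuple (`topk_weights, topk_ids, _ = topk_output`). The real `TopKOutput` type lives in `sglang.srt.layers.moe.topk` and is not shown in this diff; the sketch below is a stand-in that only mirrors the 3-field unpacking, with assumed field names beyond the first two:

    # Illustrative stand-in for the new calling convention.
    from typing import NamedTuple

    import torch


    class FakeTopKOutput(NamedTuple):
        topk_weights: torch.Tensor   # [num_tokens, top_k] routing weights
        topk_ids: torch.Tensor       # [num_tokens, top_k] expert indices
        router_logits: torch.Tensor  # third field, discarded as `_` above


    def fake_select_experts(router_logits: torch.Tensor, top_k: int) -> FakeTopKOutput:
        # Toy softmax + top-k routing, standing in for the real topk module.
        weights = torch.softmax(router_logits, dim=-1)
        topk_weights, topk_ids = torch.topk(weights, top_k, dim=-1)
        return FakeTopKOutput(topk_weights, topk_ids, router_logits)


    # The caller routes once and hands the bundled result to the quant method,
    # instead of threading every routing knob through apply():
    logits = torch.randn(4, 8)  # 4 tokens, 8 experts
    topk_output = fake_select_experts(logits, top_k=2)
    topk_weights, topk_ids, _ = topk_output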
@@ -1040,8 +1010,7 @@ class Fp8MoEMethod:
              ret = self.maybe_apply_hip_fused_experts(
                  layer,
                  x,
-                 topk_weights,
-                 topk_ids,
+                 topk_output,
                  activation,
                  no_combine,
              )
@@ -1056,6 +1025,7 @@ class Fp8MoEMethod:
          ):
              from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

+             topk_weights, topk_ids, _ = topk_output
              return cutlass_fused_experts_fp8(
                  x,
                  layer.w13_weight.transpose(1, 2),
@@ -1084,8 +1054,7 @@ class Fp8MoEMethod:
              x,
              layer.w13_weight,
              layer.w2_weight,
-             topk_weights=topk_weights,
-             topk_ids=topk_ids,
+             topk_output=topk_output,
              inplace=inplace and not no_combine,
              activation=activation,
              apply_router_weight_on_input=apply_router_weight_on_input,
@@ -1109,11 +1078,11 @@ class Fp8MoEMethod:
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
-         topk_weights: torch.Tensor,
-         topk_ids: torch.Tensor,
+         topk_output: TopKOutput,
          activation: str = "silu",
          no_combine: bool = False,
      ) -> Optional[torch.Tensor]:
+         topk_weights, topk_ids, _ = topk_output
          if _use_hip_int4:
              # TODO: add triton kernel and add check _use_aiter
              assert not no_combine, f"{no_combine=} is not supported."
@@ -1169,6 +1138,248 @@ class Fp8MoEMethod:
          return None


+ class Fp8EPMoEMethod(Fp8MoEMethod):
+     """MoE method for FP8.
+     Supports loading FP8 checkpoints with static weight scale and
+     dynamic/static activation scale.
+
+     Args:
+         quant_config: The quantization config.
+     """
+
+     def __init__(self, quant_config: Fp8Config):
+         self.quant_config = quant_config
+         self.block_quant = self.quant_config.weight_block_size is not None
+
+     def create_weights(
+         self,
+         layer: Module,
+         num_experts_per_partition: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+         if self.quant_config.is_checkpoint_fp8_serialized:
+             params_dtype = torch.float8_e4m3fn
+
+         tp_size = get_tensor_model_parallel_world_size()
+         if self.block_quant:
+             block_n, block_k = (
+                 self.quant_config.weight_block_size[0],
+                 self.quant_config.weight_block_size[1],
+             )
+             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+             # Required by column parallel or enabling merged weights
+             if intermediate_size % block_n != 0:
+                 raise ValueError(
+                     f"The output_size of gate's and up's weight = "
+                     f"{intermediate_size} is not divisible by "
+                     f"weight quantization block_n = {block_n}."
+                 )
+             if tp_size > 1:
+                 # Required by row parallel
+                 if intermediate_size % block_k != 0:
+                     raise ValueError(
+                         f"The input_size of down's weight = "
+                         f"{intermediate_size} is not divisible by "
+                         f"weight quantization block_k = {block_k}."
+                     )
+
+         # WEIGHTS
+         w13_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts_per_partition,
+                 2 * intermediate_size,
+                 hidden_size,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight", w13_weight)
+         set_weight_attrs(w13_weight, extra_weight_attrs)
+
+         w2_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts_per_partition,
+                 hidden_size,
+                 intermediate_size,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight", w2_weight)
+         set_weight_attrs(w2_weight, extra_weight_attrs)
+
+         # WEIGHT_SCALES
+         if self.block_quant:
+             w13_weight_scale = torch.nn.Parameter(
+                 torch.ones(
+                     num_experts_per_partition,
+                     2 * ((intermediate_size + block_n - 1) // block_n),
+                     (hidden_size + block_k - 1) // block_k,
+                     dtype=torch.float32,
+                 ),
+                 requires_grad=False,
+             )
+             w2_weight_scale = torch.nn.Parameter(
+                 torch.ones(
+                     num_experts_per_partition,
+                     (hidden_size + block_n - 1) // block_n,
+                     (intermediate_size + block_k - 1) // block_k,
+                     dtype=torch.float32,
+                 ),
+                 requires_grad=False,
+             )
+             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+             assert self.quant_config.activation_scheme == "dynamic"
+         else:
+             # WEIGHT_SCALES
+             # Allocate 2 scales for w1 and w3 respectively.
+             w13_weight_scale = torch.nn.Parameter(
+                 torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
+                 requires_grad=False,
+             )
+             layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+             w2_weight_scale = torch.nn.Parameter(
+                 torch.ones(num_experts_per_partition, dtype=torch.float32),
+                 requires_grad=False,
+             )
+             layer.register_parameter("w2_weight_scale", w2_weight_scale)
+         # Add the quantization method used (per tensor/grouped/channel)
+         # to ensure the weight scales are loaded in properly
+         extra_weight_attrs.update(
+             {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+             if self.block_quant
+             else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+         )
+         # If loading fp8 checkpoint, pass the weight loaders.
+         # If loading an fp16 checkpoint, do not (we will quantize in
+         # process_weights_after_loading()
+         if self.quant_config.is_checkpoint_fp8_serialized:
+             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+             set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+         # INPUT_SCALES
+         if self.quant_config.activation_scheme == "static":
+             if not self.quant_config.is_checkpoint_fp8_serialized:
+                 raise ValueError(
+                     "Found static activation scheme for checkpoint that "
+                     "was not serialized fp8."
+                 )
+
+             w13_input_scale = torch.nn.Parameter(
+                 torch.ones(num_experts_per_partition, dtype=torch.float32),
+                 requires_grad=False,
+             )
+             layer.register_parameter("w13_input_scale", w13_input_scale)
+             set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+             w2_input_scale = torch.nn.Parameter(
+                 torch.ones(num_experts_per_partition, dtype=torch.float32),
+                 requires_grad=False,
+             )
+             layer.register_parameter("w2_input_scale", w2_input_scale)
+             set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+         else:
+             layer.w13_input_scale = None
+             layer.w2_input_scale = None
+
+     def process_weights_after_loading(self, layer: Module) -> None:
+
+         # If checkpoint is fp16, quantize in place.
+         if not self.quant_config.is_checkpoint_fp8_serialized:
+             # If rocm, use float8_e4m3fnuz as dtype
+             fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
+
+             layer.w13_weight_scale = torch.nn.Parameter(
+                 torch.ones(
+                     layer.num_experts_per_partition,
+                     dtype=torch.float32,
+                     device=w13_weight.device,
+                 ),
+                 requires_grad=False,
+             )
+
+             for expert in range(layer.num_experts_per_partition):
+                 w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                     scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                 )
+                 w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                     scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                 )
+             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+             return
+
+         # If checkpoint is fp8, we need to handle that the
+         # MoE kernels require single activation scale and single weight
+         # scale for w13 per expert.
+         else:
+             if self.quant_config.activation_scheme == "static":
+                 if layer.w13_input_scale is None or layer.w2_input_scale is None:
+                     raise ValueError(
+                         "QuantConfig has static quantization, but found "
+                         "activation scales are None."
+                     )
+                 layer.w13_weight_scale = torch.nn.Parameter(
+                     torch.max(layer.w13_weight_scale, dim=1).values,
+                     requires_grad=False,
+                 )
+             if self.block_quant:
+                 # If ROCm, normalize the weights and scales to e4m3fnuz
+                 if _is_fp8_fnuz:
+                     # activation_scheme: dynamic
+                     w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                         weight=layer.w13_weight,
+                         weight_scale=layer.w13_weight_scale_inv,
+                         input_scale=None,
+                     )
+                     w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                         weight=layer.w2_weight,
+                         weight_scale=layer.w2_weight_scale_inv,
+                         input_scale=None,
+                     )
+                     # Reset the parameter
+                     layer.w13_weight = torch.nn.Parameter(
+                         w13_weight, requires_grad=False
+                     )
+                     layer.w13_weight_scale_inv = torch.nn.Parameter(
+                         w13_weight_scale, requires_grad=False
+                     )
+                     layer.w13_input_scale = None
+                     layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                     layer.w2_weight_scale_inv = torch.nn.Parameter(
+                         w2_weight_scale, requires_grad=False
+                     )
+                     layer.w2_input_scale = None
+                 if _use_aiter:
+                     layer.w13_weight = torch.nn.Parameter(
+                         shuffle_weight(layer.w13_weight.data, (16, 16)),
+                         requires_grad=False,
+                     )
+                     layer.w2_weight = torch.nn.Parameter(
+                         shuffle_weight(layer.w2_weight.data, (16, 16)),
+                         requires_grad=False,
+                     )
+                 return
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         hidden_states: torch.Tensor,
+         topk_output: TopKOutput,
+     ) -> torch.Tensor:
+         raise NotImplementedError
+
+
  class Fp8KVCacheMethod(BaseKVCacheMethod):
      """
      Supports loading kv-cache scaling factors from FP8 checkpoints.
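
As a quick sanity check of the scale shapes allocated in `Fp8EPMoEMethod.create_weights` above: with block-wise quantization, every `block_n x block_k` tile of a weight matrix gets one float32 scale, so the scale tensors use ceil-division of each weight dimension by the block size. A small worked sketch with assumed example sizes (the dimensions are illustrative, not taken from any specific model):

    # Worked example of the ceil-division scale shapes used above; the sizes
    # below are assumptions for illustration only.
    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b


    num_experts_per_partition = 8
    hidden_size = 4096
    intermediate_size = 1536
    block_n, block_k = 128, 128  # weight_block_size = [128, 128]

    # w13_weight is [E, 2 * intermediate_size, hidden_size]
    w13_scale_shape = (
        num_experts_per_partition,
        2 * ceil_div(intermediate_size, block_n),  # 2 * 12 = 24
        ceil_div(hidden_size, block_k),            # 32
    )

    # w2_weight is [E, hidden_size, intermediate_size]
    w2_scale_shape = (
        num_experts_per_partition,
        ceil_div(hidden_size, block_n),        # 32
        ceil_div(intermediate_size, block_k),  # 12
    )

    print(w13_scale_shape)  # (8, 24, 32)
    print(w2_scale_shape)   # (8, 32, 12)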