sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp4.py (deleted)
@@ -1,557 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-import fnmatch
-import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast
-
-import aiter
-import torch
-import torch.nn.functional as F
-from aiter import ActivationType, QuantType, dtypes
-from aiter.fused_moe import fused_moe
-from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
-from aiter.ops.gemm_op_a4w4 import gemm_a4w4
-from aiter.ops.quant import get_torch_quant
-from aiter.ops.shuffle import shuffle_weight
-from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
-from aiter.ops.triton.quant import dynamic_mxfp4_quant
-from aiter.utility.fp4_utils import e8m0_shuffle
-from torch.nn import Module
-
-from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-from sglang.srt.layers.parameter import ModelWeightParameter
-from sglang.srt.layers.quantization.base_config import (
-    FusedMoEMethodBase,
-    LinearMethodBase,
-    QuantizationConfig,
-    QuantizeMethodBase,
-)
-from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
-from sglang.srt.layers.quantization.quark.schemes import QuarkScheme, QuarkW4A4MXFP4
-from sglang.srt.layers.quantization.quark.utils import deep_compare, should_ignore_layer
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import (
-    get_bool_env_var,
-    get_device_capability,
-    log_info_on_rank0,
-    mxfp_supported,
-    set_weight_attrs,
-)
-
-if TYPE_CHECKING:
-    from sglang.srt.layers.moe.topk import TopKOutput
-
-logger = logging.getLogger(__name__)
-
-use_dynamic_mxfp4_linear = get_bool_env_var("SGLANG_USE_DYNAMIC_MXFP4_linear")
-
-OCP_MX_BLOCK_SIZE = 32
-
-
-class Mxfp4Config(QuantizationConfig):
-
-    def __init__(self, ignored_layers: Optional[list[str]] = None):
-        super().__init__()
-        self.ignored_layers = ignored_layers
-
-    @classmethod
-    def from_config(cls, config):
-        return cls()
-
-    @classmethod
-    def get_min_capability(cls) -> int:
-        return 80
-
-    @classmethod
-    def get_name(cls) -> QuantizationMethods:
-        return "mxfp4"
-
-    @classmethod
-    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
-        return [torch.bfloat16]
-
-    @classmethod
-    def get_config_filenames(cls) -> list[str]:
-        return []
-
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention # Avoid circular import
-
-        if isinstance(layer, LinearBase):
-            if self.ignored_layers and is_layer_skipped(
-                prefix=prefix,
-                ignored_layers=self.ignored_layers,
-                fused_mapping=self.packed_modules_mapping,
-            ):
-                return UnquantizedLinearMethod()
-            raise NotImplementedError("Mxfp4 linear layer is not implemented")
-        elif isinstance(layer, FusedMoE):
-            return Mxfp4MoEMethod(layer.moe_config)
-        elif isinstance(layer, Attention):
-            raise NotImplementedError("Mxfp4 attention layer is not implemented")
-        return None
-
-
-class MxFp4LinearMethod(LinearMethodBase):
-
-    def __init__(self, quantization_config: MxFp4Config):
-        self.quantization_config = quantization_config
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        return
-        # if self.quantization_config.is_checkpoint_fp4_serialized:
-        #     layer.scheme.process_weights_after_loading(layer)
-        # else:
-        #     #w, w_scales = dynamic_mxfp4_quant(layer.weight.data)
-        #     ##log_info_on_rank0(logger, f"w.shape: {w.shape}")
-
-        #     #wshuffle = w#shuffle_weight(w, layout=(16, 16))
-        #     #w_scales_shuffle = w_scales#e8m0_shuffle(w_scales).view(dtypes.fp8_e8m0)
-
-        #     quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
-
-        #     w, w_scales_shuffle = quant_func(layer.weight.data, shuffle=True)
-
-        #     wshuffle = shuffle_weight(w, layout=(16, 16))
-
-        #     layer.weight = torch.nn.Parameter(wshuffle,
-        #                                       requires_grad=False)
-        #     layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
-        #                                             requires_grad=False)
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: list[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        """
-        Use the CompressedTensorsScheme associated with each layer to create
-        the necessary parameters for the layer. See LinearMethodBase for param
-        details
-        """
-        weight_loader = extra_weight_attrs.get("weight_loader")
-
-        if self.quantization_config.is_checkpoint_fp4_serialized:
-            layer.scheme.create_weights(
-                layer=layer,
-                input_size=input_size,
-                input_size_per_partition=input_size_per_partition,
-                output_partition_sizes=output_partition_sizes,
-                output_size=output_size,
-                params_dtype=params_dtype,
-                weight_loader=weight_loader,
-            )
-        else:
-            output_size_per_partition = sum(output_partition_sizes)
-            layer.logical_widths = output_partition_sizes
-            layer.input_size_per_partition = input_size_per_partition
-            layer.output_size_per_partition = output_size_per_partition
-            layer.orig_dtype = params_dtype
-
-            weight_dtype = params_dtype
-
-            weight = ModelWeightParameter(
-                data=torch.empty(
-                    output_size_per_partition,
-                    input_size_per_partition,
-                    dtype=weight_dtype,
-                ),
-                input_dim=1,
-                output_dim=0,
-                weight_loader=weight_loader,
-            )
-
-            layer.register_parameter("weight", weight)
-            layer.register_parameter("weight_scale", None)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ):
-        """
-        Use the output of create_weights and the CompressedTensorsScheme
-        associated with the layer to apply the forward pass with the
-        layer input. See LinearMethodBase for param details
-
-        """
-        if self.quantization_config.is_checkpoint_fp4_serialized:
-            scheme = layer.scheme
-            if scheme is None:
-                raise ValueError("A scheme must be defined for each layer")
-            return scheme.apply_weights(layer, x, bias=bias)
-        else:
-            out_dtype = x.dtype
-
-            # ck or asm implement
-            # M = x.shape[0]
-            # N = layer.weight.shape[0]
-
-            # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
-
-            # x, x_scales_shuffle = quant_func(x, shuffle=True)
-
-            # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=out_dtype)
-
-            # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
-
-            # return out[:M]
-
-            # triton implement
-            x_q, x_s = dynamic_mxfp4_quant(x)
-            y = torch.empty(
-                x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
-            )
-
-            out = gemm_afp4wfp4(
-                x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y
-            )
-
-            return out
-
-
-class MxFp4MoEMethod:
-    def __new__(cls, *args, **kwargs):
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
-    @staticmethod
-    def get_moe_method(
-        quant_config: "MxFp4Config", # type: ignore # noqa E501 # noqa F821
-        module: torch.nn.Module,
-        layer_name: str,
-    ) -> "MxFp4MoEMethod":
-
-        if quant_config.is_checkpoint_fp4_serialized:
-            layer_quant_config = quant_config._find_matched_config(layer_name, module)
-
-            if layer_quant_config.get("output_tensors") or layer_quant_config.get(
-                "bias"
-            ):
-                raise NotImplementedError(
-                    "Currently, Quark models with "
-                    "output_tensors and bias "
-                    "quantized are not supported"
-                )
-            weight_config = layer_quant_config.get("weight")
-            input_config = layer_quant_config.get("input_tensors")
-
-            if quant_config._is_mx_fp4(weight_config, input_config):
-                return W4A4MXFp4MoEStaticMethod(weight_config, input_config)
-            else:
-                raise RuntimeError("Unsupported FusedMoe scheme")
-        else:
-            return W4A4MXFp4MoEDynamicMethod(quant_config)
-
-
-class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
-    def __init__(self, quant_config):
-        self.quant_config = quant_config
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size_per_partition: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                2 * intermediate_size_per_partition,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                hidden_size,
-                intermediate_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
-
-        layer.w13_input_scale = None
-        layer.w2_input_scale = None
-
-    def mxfp4_quantize(self, w):
-        w_shape = w.shape
-        w_need_reshape = True if w.dim() != 2 else False
-
-        if w_need_reshape:
-            w_last_dim_size = w_shape[-1]
-            w = w.view(-1, w_last_dim_size)
-
-        # log_info_on_rank0(logger, f"[Pre-quant] w.shape: {w.shape}")
-        w, mx_scales = dynamic_mxfp4_quant(w)
-        # log_info_on_rank0(logger, f"[Post-quant] w.shape: {w.shape} mx_scales.shape: {mx_scales.shape}")
-
-        if w_need_reshape:
-            w_new_shape = w_shape[:-1] + (w.shape[-1],)
-            w = w.view(w_new_shape)
-
-        # log_info_on_rank0(logger, f"[re-shape] w.shape: {w.shape} mx_scales.shape: {mx_scales.shape}")
-
-        mx_scales = e8m0_shuffle(mx_scales)
-
-        return w, mx_scales
-
-    def process_weights_after_loading(self, layer: Module) -> None:
-        w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
-        w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
-
-        layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
-        layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)
-
-        layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
-        layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-    ) -> torch.Tensor:
-        topk_weights, topk_ids, _ = topk_output
-
-        return fused_moe(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights,
-            topk_ids,
-            quant_type=QuantType.per_1x32,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            activation=(
-                ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-            ),
-            doweight_stage1=False,
-        )
-
-
-class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
-
-    def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]):
-        self.weight_quant = weight_config
-        self.input_quant = input_config
-
-        weight_qscheme = self.weight_quant.get("qscheme")
-        input_qscheme = self.input_quant.get("qscheme")
-        if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
-            raise ValueError(
-                "For MX(FP4) Fused MoE layers, only per-group scales "
-                "for weights and activations are supported. Found "
-                f"{weight_qscheme=}, {input_qscheme=}"
-            ) # noqa E501
-
-        self.static_input_scales = not self.input_quant.get("is_dynamic")
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size_per_partition: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
-        )
-
-        params_dtype = torch.uint8
-
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                2 * intermediate_size_per_partition,
-                hidden_size // 2,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                hidden_size,
-                intermediate_size_per_partition // 2,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # WEIGHT_SCALES
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(
-                num_experts,
-                2 * intermediate_size_per_partition,
-                hidden_size // OCP_MX_BLOCK_SIZE,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(
-                num_experts,
-                hidden_size,
-                intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        float_dtype = torch.get_default_dtype()
-
-        # Pre-shuffle weight scales
-        s0, s1, _ = layer.w13_weight_scale.shape
-        w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
-        w13_weight_scale = e8m0_shuffle(w13_weight_scale)
-        layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
-
-        s0, s1, _ = layer.w2_weight_scale.shape
-        w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
-        w2_weight_scale = e8m0_shuffle(w2_weight_scale)
-        layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-    ) -> torch.Tensor:
-        topk_weights, topk_ids, _ = topk_output
-
-        return fused_moe(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights,
-            topk_ids,
-            quant_type=QuantType.per_1x32,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            activation=(
-                ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-            ),
-            doweight_stage1=False,
-        )
-
-
-class MxFp4KVCacheMethod(BaseKVCacheMethod):
-    """
-    Supports loading kv-cache scaling factors from quark checkpoints.
-    """
-
-    def __init__(self, quant_config: MxFp4Config):
-        self.validate_kv_cache_config(quant_config.kv_cache_config)
-        super().__init__(quant_config)
-
-    @staticmethod
-    def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
-        """
-        Validator for the kv cache configuration. Useful for controlling the
-        kv cache quantization schemes, that are being supported in vLLM
-        :param kv_cache_config: the quark kv cache scheme
-        """
-        if kv_cache_config is None:
-            return
-
-        dtype = kv_cache_config.get("dtype")
-        if dtype != "fp8_e4m3":
-            raise NotImplementedError(
-                "Currently supported kv cache quantization is "
-                f"dtype=fp8_e4m3, however received {dtype}"
-            )
-
-        qscheme = kv_cache_config.get("qscheme")
-        if qscheme != "per_tensor":
-            raise NotImplementedError(
-                "Only support per-tensor scaling factor "
-                "for quark KV cache. "
-                f"Expected qscheme: per_tensor, found qscheme: {qscheme}"
-            )