sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/mxfp4.py (new file)
@@ -0,0 +1,654 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ from __future__ import annotations
+
+ import importlib.util
+ import logging
+ from typing import TYPE_CHECKING, List, Optional
+
+ import torch
+ import triton.language as tl
+ from torch.nn.parameter import Parameter
+
+ from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
+     QuantizationConfig,
+     QuantizeMethodBase,
+ )
+ from sglang.srt.layers.quantization.utils import is_layer_skipped
+ from sglang.srt.layers.utils import is_sm100_supported
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.utils import (
+     direct_register_custom_op,
+     get_bool_env_var,
+     is_cuda,
+     is_flashinfer_available,
+     is_hip,
+     is_triton_kernels_available,
+     log_info_on_rank0,
+     next_power_of_2,
+     round_up,
+     set_weight_attrs,
+ )
+
+ _is_sm100_supported = is_cuda() and is_sm100_supported()
+ has_triton_kernels = is_triton_kernels_available()
+
+
+ if is_flashinfer_available():
+     from flashinfer import (
+         mxfp8_quantize,
+         shuffle_matrix_a,
+         shuffle_matrix_sf_a,
+         trtllm_fp4_block_scale_moe,
+     )
+
+ logger = logging.getLogger(__name__)
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import TopKOutput
+
+ OCP_MX_BLOCK_SIZE = 32
+
+
+ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
+     """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
+     import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
+     from triton_kernels.numerics import InFlexData
+     from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
+     from triton_kernels.tensor_details import layout
+
+     value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
+         mx_axis=1
+     )
+     scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
+         mx_axis=1, num_warps=num_warps
+     )
+     if _is_sm100_supported:
+         constraints = {
+             "is_persistent": True,
+             "epilogue_subtile": 1,
+         }
+         opt_flags.update_opt_flags_constraints(constraints)
+     # transpose the tensor so that the quantization axis is on dim1
+     quant_tensor = quant_tensor.transpose(-2, -1)
+     scale = scale.transpose(-2, -1)
+     quant_tensor = convert_layout(
+         wrap_torch_tensor(quant_tensor, dtype=FP4), value_layout, **value_layout_opts
+     )
+     scale = convert_layout(wrap_torch_tensor(scale), scale_layout, **scale_layout_opts)
+     return quant_tensor, InFlexData(), scale
+
+
+ def _dequant_mxfp4(
+     x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
+ ) -> torch.Tensor:
+     try:
+         from quark.torch.kernel import mx
+     except ImportError as err:
+         raise ImportError(
+             "The package `amd-quark` is required to use "
+             "MX-FP4 models. Please install it with `pip install "
+             "amd-quark`."
+         ) from err
+
+     return mx.dq_mxfp4(x, scale, float_dtype)
+
+
+ def _dequant_mxfp4_fake(
+     x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
+ ) -> torch.Tensor:
+     return torch.empty(
+         (*x.shape[:-1], x.shape[-1] * 2), dtype=float_dtype, device=x.device
+     )
+
+
+ def _quant_dequant_mxfp4(
+     x: torch.Tensor, scale_calculation_mode: str = "even"
+ ) -> torch.Tensor:
+     try:
+         from quark.torch.kernel import mx
+     except ImportError as err:
+         raise ImportError(
+             "The package `amd-quark` is required to use "
+             "MX-FP4 models. Please install it with `pip install "
+             "amd-quark`."
+         ) from err
+
+     return mx.qdq_mxfp4(x, scale_calculation_mode)
+
+
+ def _quant_dequant_mxfp4_fake(
+     x: torch.Tensor, scale_calculation_mode: str = "even"
+ ) -> torch.Tensor:
+     return torch.empty_like(x)
+
+
+ try:
+     direct_register_custom_op(
+         op_name="dequant_mxfp4",
+         op_func=_dequant_mxfp4,
+         mutates_args=[],
+         fake_impl=_dequant_mxfp4_fake,
+     )
+     dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
+ except AttributeError as error:
+     raise error
+
+ try:
+     direct_register_custom_op(
+         op_name="quant_dequant_mxfp4",
+         op_func=_quant_dequant_mxfp4,
+         mutates_args=[],
+         fake_impl=_quant_dequant_mxfp4_fake,
+     )
+     quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
+ except AttributeError as error:
+     raise error
+
+
+ class Mxfp4Config(QuantizationConfig):
+
+     def __init__(self, ignored_layers: Optional[list[str]] = None):
+         super().__init__()
+         self.ignored_layers = ignored_layers
+
+     @classmethod
+     def from_config(cls, config):
+         return cls()
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 80
+
+     @classmethod
+     def get_name(cls) -> str:
+         return "mxfp4"
+
+     @classmethod
+     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+         return [torch.bfloat16, torch.float16]
+
+     @classmethod
+     def get_config_filenames(cls) -> list[str]:
+         return []
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional["QuantizeMethodBase"]:
+
+         from sglang.srt.layers.linear import LinearBase
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+         from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+
+         if isinstance(layer, LinearBase):
+             if self.ignored_layers and is_layer_skipped(
+                 prefix=prefix,
+                 ignored_layers=self.ignored_layers,
+                 fused_mapping=self.packed_modules_mapping,
+             ):
+                 return UnquantizedLinearMethod()
+         elif isinstance(layer, FusedMoE):
+             return Mxfp4MoEMethod(prefix)
+         else:
+             raise NotImplementedError("Mxfp4 attention layer is not implemented")
+         return None
+
+     def get_scaled_act_names(self) -> List[str]:
+         return []
+
+
+ class Mxfp4MoEMethod(FusedMoEMethodBase):
+
+     def __init__(
+         self,
+         prefix: str,
+     ):
+         from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+         super().__init__()
+
+         self.topk_indices_dtype = None
+         self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+         self.with_bias = False
+         self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
+
+         self.triton_kernel_moe_forward = None
+         self.triton_kernel_moe_with_bias_forward = None
+         if torch.cuda.is_available() and has_triton_kernels:
+             from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+                 triton_kernel_moe_forward as _tk_forward,
+             )
+             from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+                 triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
+             )
+
+             self.triton_kernel_moe_forward = _tk_forward
+             self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         num_experts: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: torch.dtype,
+         with_bias: bool = False,
+         **extra_weight_attrs,
+     ):
+         self.num_experts = num_experts
+         weight_dtype = torch.uint8
+         scale_dtype = torch.uint8
+         self.with_bias = with_bias
+         mxfp4_block = 32
+
+         # pad the intermediate size to be a multiple of 2 * mxfp4_block
+         # for to hold non-uniform sharded tensor as well as swizzling
+         intermediate_size_per_partition_after_pad = intermediate_size
+         if _is_sm100_supported:
+             if self.use_flashinfer:
+                 intermediate_size_per_partition_after_pad = round_up(
+                     intermediate_size, 256
+                 )
+                 hidden_size = round_up(hidden_size, 256)
+             else:
+                 intermediate_size_per_partition_after_pad = round_up(
+                     intermediate_size, 64
+                 )
+
+         self.intermediate_size = intermediate_size_per_partition_after_pad
+
+         self.hidden_size = hidden_size
+         # Fused gate_up_proj (column parallel)
+         w13_weight = torch.nn.Parameter(
+             torch.zeros(
+                 layer.num_local_experts,
+                 2 * intermediate_size_per_partition_after_pad,
+                 hidden_size // 2,
+                 dtype=weight_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight", w13_weight)
+         set_weight_attrs(w13_weight, extra_weight_attrs)
+
+         w13_weight_scale = torch.nn.Parameter(
+             torch.zeros(
+                 layer.num_local_experts,
+                 2 * intermediate_size_per_partition_after_pad,
+                 hidden_size // mxfp4_block,
+                 dtype=scale_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight_scale", w13_weight_scale)
+         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+         w13_weight_bias = torch.nn.Parameter(
+             torch.zeros(
+                 layer.num_local_experts,
+                 2 * intermediate_size_per_partition_after_pad,
+                 dtype=torch.bfloat16,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight_bias", w13_weight_bias)
+         set_weight_attrs(w13_weight_bias, extra_weight_attrs)
+
+         # down_proj (row parallel)
+         w2_weight = torch.nn.Parameter(
+             torch.zeros(
+                 layer.num_local_experts,
+                 hidden_size,
+                 intermediate_size_per_partition_after_pad // 2,
+                 dtype=weight_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight", w2_weight)
+         set_weight_attrs(w2_weight, extra_weight_attrs)
+
+         w2_weight_scale = torch.nn.Parameter(
+             torch.zeros(
+                 layer.num_local_experts,
+                 hidden_size,
+                 intermediate_size_per_partition_after_pad // mxfp4_block,
+                 dtype=scale_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight_scale", w2_weight_scale)
+         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+         w2_weight_bias = torch.nn.Parameter(
+             torch.zeros(layer.num_local_experts, hidden_size, dtype=torch.bfloat16),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight_bias", w2_weight_bias)
+         set_weight_attrs(w2_weight_bias, extra_weight_attrs)
+
+     def process_weights_after_loading(self, layer):
+         if self.use_flashinfer:
+             log_info_on_rank0(
+                 logger,
+                 "Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...",
+             )
+             layer.gemm1_alpha = Parameter(
+                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
+                 requires_grad=False,
+             )
+             layer.gemm1_beta = Parameter(
+                 torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
+                 requires_grad=False,
+             )
+             layer.gemm1_clamp_limit = Parameter(
+                 torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
+                 requires_grad=False,
+             )
+             sf_block_size = 32  # mxfp4 block size
+
+             assert (
+                 layer.w13_weight.dim() == 3
+                 and layer.w13_weight.shape[0] == self.num_experts
+                 and layer.w13_weight.shape[1] == self.intermediate_size * 2
+                 and layer.w13_weight.shape[2] == self.hidden_size // 2
+             )
+             assert (
+                 layer.w13_weight_scale.dim() == 3
+                 and layer.w13_weight_scale.shape[0] == self.num_experts
+                 and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
+                 and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
+             )
+             assert (
+                 layer.w2_weight.dim() == 3
+                 and layer.w2_weight.shape[0] == self.num_experts
+                 and layer.w2_weight.shape[1] == self.hidden_size
+                 and layer.w2_weight.shape[2] == self.intermediate_size // 2
+             )
+             assert (
+                 layer.w2_weight_scale.dim() == 3
+                 and layer.w2_weight_scale.shape[1] == self.hidden_size
+                 and layer.w2_weight_scale.shape[2]
+                 == self.intermediate_size // sf_block_size
+             )
+             assert (
+                 layer.w13_weight_bias.dim() == 2
+                 and layer.w13_weight_bias.shape[0] == self.num_experts
+                 and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
+             )
+             assert (
+                 layer.w2_weight_bias.dim() == 2
+                 and layer.w2_weight_bias.shape[0] == self.num_experts
+                 and layer.w2_weight_bias.shape[1] == self.hidden_size
+             )
+
+             w13_weight_scale = layer.w13_weight_scale.data
+             w2_weight_scale = layer.w2_weight_scale.data
+             w13_weight = layer.w13_weight.data
+             w2_weight = layer.w2_weight.data
+             w13_bias = layer.w13_weight_bias.data.to(torch.float32)
+             w2_bias = layer.w2_weight_bias.data.to(torch.float32)
+
+             # Swap w1 and w3 as the definition of
+             # swiglu is different in the trtllm-gen
+             def swap_every_two_rows(x, axis=-1):
+                 shape = x.shape
+                 if axis < 0:
+                     axis = len(shape) + axis
+
+                 # Create a new shape with pairs swapped along specified axis
+                 new_shape = list(shape)
+                 new_shape[axis] = shape[axis] // 2
+                 new_shape.insert(axis + 1, 2)
+
+                 # Reshape to expose pairs, swap them, and reshape back
+                 x = x.reshape(*new_shape)
+                 x = x.flip(axis + 1)
+                 new_shape = list(shape)
+                 return x.reshape(*new_shape)
+
+             w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
+             w13_weight = swap_every_two_rows(w13_weight, -2)
+             w13_bias = swap_every_two_rows(w13_bias, -1)
+
+             # Shuffle weights and scaling factors for transposed mma output
+             gemm1_weights_mxfp4_shuffled = []
+             gemm1_scales_mxfp4_shuffled = []
+             gemm2_weights_mxfp4_shuffled = []
+             gemm2_scales_mxfp4_shuffled = []
+             gemm1_bias_shuffled = []
+             gemm2_bias_shuffled = []
+             epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+             for i in range(self.num_experts):
+                 gemm1_weights_mxfp4_shuffled.append(
+                     shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
+                 )
+                 gemm1_scales_mxfp4_shuffled.append(
+                     shuffle_matrix_sf_a(
+                         w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
+                     )
+                 )
+                 gemm1_bias_shuffled.append(
+                     shuffle_matrix_a(
+                         w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m
+                     )
+                 )
+
+                 gemm2_weights_mxfp4_shuffled.append(
+                     shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
+                 )
+                 gemm2_scales_mxfp4_shuffled.append(
+                     shuffle_matrix_sf_a(
+                         w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
+                     )
+                 )
+                 gemm2_bias_shuffled.append(
+                     shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m)
+                 )
+
+             w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
+             w13_weight_scale = (
+                 torch.stack(gemm1_scales_mxfp4_shuffled)
+                 .reshape(
+                     self.num_experts,
+                     2 * self.intermediate_size,
+                     self.hidden_size // sf_block_size,
+                 )
+                 .view(torch.float8_e4m3fn)
+             )
+
+             w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
+             w2_weight_scale = (
+                 torch.stack(gemm2_scales_mxfp4_shuffled)
+                 .reshape(
+                     self.num_experts,
+                     self.hidden_size,
+                     self.intermediate_size // sf_block_size,
+                 )
+                 .view(torch.float8_e4m3fn)
+             )
+
+             layer.w13_weight = Parameter(w13_weight, requires_grad=False)
+             layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
+             layer.w2_weight = Parameter(w2_weight, requires_grad=False)
+             layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
+             layer.w13_weight_bias = Parameter(
+                 torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
+                 requires_grad=False,
+             )
+             layer.w2_weight_bias = Parameter(
+                 torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
+                 requires_grad=False,
+             )
+             return
+
+         if self.use_triton_kernels:
+
+             from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
+
+             w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
+             w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
+
+             layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
+             layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
+
+             num_warps = 8
+
+             w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
+                 layer.w13_weight, layer.w13_weight_scale, num_warps
+             )
+             w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
+                 layer.w2_weight, layer.w2_weight_scale, num_warps
+             )
+
+             self.w13_precision_config = PrecisionConfig(
+                 weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
+             )
+             self.w2_precision_config = PrecisionConfig(
+                 weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
+             )
+
+             self.w13_weight_triton_tensor = w13_weight
+             self.w2_weight_triton_tensor = w2_weight
+             del layer.w13_weight
+             del layer.w2_weight
+         else:
+             from triton_kernels.numerics_details.mxfp import upcast_from_mxfp
+
+             w13_weight = upcast_from_mxfp(
+                 layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
+             )
+             w2_weight = upcast_from_mxfp(
+                 layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
+             )
+             del layer.w13_weight
+             del layer.w2_weight
+             del layer.w13_weight_scale
+             del layer.w2_weight_scale
+             layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
+             layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
+             torch.cuda.empty_cache()
+
+     def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
+         # Number of tokens in the input tensor.
+         num_tokens = x.shape[0]
+         # Factor to account for the imbalance of the experts.
+         # factor equals to the
+         # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
+         # - 1.0 means perfect expert distribution.
+         # - > 1.0 means some experts have more
+         #   tokens than the perfect distribution.
+         # - < 1.0 does not make sense.
+         imbalance_factor = 1.3
+         # Calculate the number of tokens per expert
+         # assuming perfect distribution.
+         num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
+         # Apply the imbalance factor.
+         num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
+         # And pad the number to the next power of 2.
+         tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+         # Cap to 8-64 tokens per CTA tile
+         # as it's the range supported by the kernel.
+         tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+
+         return tile_tokens_dim
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         topk_output: TopKOutput,
+         *,
+         activation: str = "silu",
+         apply_router_weight_on_input: bool = False,
+         inplace: bool = True,
+         no_combine: bool = False,
+         routed_scaling_factor: Optional[float] = None,
+         activation_alpha: Optional[float] = None,
+         swiglu_limit: Optional[float] = None,
+     ) -> torch.Tensor:
+         if self.use_flashinfer:
+             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
+             x_quant, x_scale = mxfp8_quantize(
+                 x, False, alignment=self.hidden_size
+             )  # to mxfp8
+             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+             assert x_quant.shape[-1] == self.hidden_size
+
+             top_k, router_logits = topk_output
+
+             trtllm_gen_output = trtllm_fp4_block_scale_moe(
+                 router_logits.to(torch.bfloat16),
+                 None,  # routing_bias
+                 x_quant,
+                 x_scale,
+                 layer.w13_weight,  # uint8 (e2m1 x 2)
+                 layer.w13_weight_scale,  # uint8 (e4m3 x 2)
+                 layer.w13_weight_bias,  # fp32 per expert per channel
+                 layer.gemm1_alpha,  # fp32 per expert
+                 layer.gemm1_beta,  # fp32 per expert
+                 layer.gemm1_clamp_limit,  # fp32 per expert
+                 layer.w2_weight,  # uint8 (e2m1 x 2)
+                 layer.w2_weight_scale,  # ue8m0
+                 layer.w2_weight_bias,  # fp32 per expert per channel
+                 None,  # output1_scale_scalar
+                 None,  # output1_scale_gate_scalar
+                 None,  # output2_scale_scalar
+                 layer.num_experts,
+                 top_k,
+                 None,  # n_group
+                 None,  # topk_group
+                 self.intermediate_size,  # padded to multiple of 256
+                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
+                 layer.num_local_experts,  # local num experts
+                 None,
+                 self._get_tile_tokens_dim(x, top_k),
+                 1,  # routing_method_type, renormalize
+                 True,  # do finalize
+             )[0]
+             return trtllm_gen_output
+
+         if self.use_triton_kernels:
+             assert (
+                 layer.moe_ep_size == 1
+             ), "Expert parallel is not supported when using triton kernels"
+             if self.with_bias:
+                 return self.triton_kernel_moe_with_bias_forward(
+                     hidden_states=x,
+                     w1=self.w13_weight_triton_tensor,
+                     w1_pcg=self.w13_precision_config,
+                     w2=self.w2_weight_triton_tensor,
+                     w2_pcg=self.w2_precision_config,
+                     b1=layer.w13_weight_bias,
+                     b2=layer.w2_weight_bias,
+                     topk_output=topk_output,
+                     activation=activation,
+                     activation_alpha=activation_alpha,
+                     swiglu_limit=swiglu_limit,
+                 )
+             else:
+                 return self.triton_kernel_moe_forward(
+                     hidden_states=x,
+                     w1=layer.w13_weight,
+                     w2=layer.w2_weight,
+                     topk_output=topk_output,
+                 )
+         else:
+             from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+             return fused_experts(
+                 hidden_states=x,
+                 w1=layer.w13_weight,
+                 w2=layer.w2_weight,
+                 topk_output=topk_output,
+                 b1=layer.w13_weight_bias,
+                 b2=layer.w2_weight_bias,
+                 inplace=inplace,
+                 activation=activation,
+                 apply_router_weight_on_input=apply_router_weight_on_input,
+                 no_combine=no_combine,
+                 routed_scaling_factor=routed_scaling_factor,
+                 activation_alpha=activation_alpha,
+                 swiglu_limit=swiglu_limit,
+             )
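
The shape arithmetic in create_weights follows from the MXFP4 storage format: two e2m1 values are packed into each uint8 byte, and one e8m0 scale covers each 32-value block (mxfp4_block = 32), which is where the "// 2" and "// mxfp4_block" factors come from. A minimal sketch of the resulting parameter shapes, using made-up example dimensions and a local round_up stand-in for the imported helper:

# Illustrative sketch only; dimensions are invented, not from any real model config.
num_local_experts = 32
intermediate_size = 2880   # per-partition size before padding (example value)
hidden_size = 2880         # example value
mxfp4_block = 32

def round_up(x: int, m: int) -> int:
    # local stand-in for sglang.srt.utils.round_up
    return (x + m - 1) // m * m

# FlashInfer path on SM100 pads both dimensions to multiples of 256.
intermediate_padded = round_up(intermediate_size, 256)   # 3072
hidden_padded = round_up(hidden_size, 256)               # 3072

w13_weight_shape = (num_local_experts, 2 * intermediate_padded, hidden_padded // 2)
w13_scale_shape = (num_local_experts, 2 * intermediate_padded, hidden_padded // mxfp4_block)
w2_weight_shape = (num_local_experts, hidden_padded, intermediate_padded // 2)
w2_scale_shape = (num_local_experts, hidden_padded, intermediate_padded // mxfp4_block)

print(w13_weight_shape, w13_scale_shape)  # (32, 6144, 1536) (32, 6144, 96)
print(w2_weight_shape, w2_scale_shape)    # (32, 3072, 1536) (32, 3072, 96)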
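
The swap_every_two_rows helper in process_weights_after_loading flips each adjacent pair of rows, swapping the interleaved w1/w3 rows of the fused gate_up projection to match the trtllm-gen swiglu convention. A toy tensor (illustrative only, not real weights) makes the reordering visible:

import torch

# Toy tensor: rows 0..3, two columns; mirrors the axis=-2 case used for w13_weight.
x = torch.arange(8).reshape(4, 2)
shape = x.shape
new_shape = [shape[0] // 2, 2, shape[1]]     # expose adjacent row pairs on a new axis
y = x.reshape(*new_shape).flip(1).reshape(*shape)
print(y[:, 0].tolist())                      # [2, 0, 6, 4] -> row order 1, 0, 3, 2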
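
The _get_tile_tokens_dim heuristic can also be traced by hand. The token count, top-k, and expert count below are arbitrary illustrative values, and next_power_of_2 is a local stand-in for the imported sglang.srt.utils helper:

def next_power_of_2(n: int) -> int:
    # local stand-in for sglang.srt.utils.next_power_of_2
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

num_tokens, top_k, num_experts = 512, 4, 128                       # illustrative values
num_tokens_per_expert = (num_tokens * top_k) // num_experts        # 16
num_tokens_per_expert = int(num_tokens_per_expert * 1.3)           # 20 (imbalance factor)
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)           # 32
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)                 # clamp to [8, 64]
print(tile_tokens_dim)  # 32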