sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,14 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

- import importlib.util
+ import datetime
+ import glob
  import logging
+ import os
+ import sys
  from enum import Enum
- from functools import lru_cache
  from typing import List, Optional, Tuple

  import torch
- from packaging import version as pkg_version

  from sglang.srt.distributed import (
      get_moe_expert_parallel_rank,
@@ -22,6 +23,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
  )
  from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
  from sglang.srt.layers.moe.topk import StandardTopKOutput
+ from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
  from sglang.srt.layers.quantization.base_config import (
      QuantizationConfig,
      QuantizeMethodBase,
@@ -29,22 +31,59 @@ from sglang.srt.layers.quantization.base_config import (
  from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
- from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
+ from sglang.srt.utils import (
+     cpu_has_amx_support,
+     get_bool_env_var,
+     is_cpu,
+     is_flashinfer_available,
+     is_hip,
+     next_power_of_2,
+     round_up,
+ )
+
+ if is_flashinfer_available():
+     from flashinfer import (
+         RoutingMethodType,
+         fp4_quantize,
+         reorder_rows_for_gated_act_gemm,
+         shuffle_matrix_a,
+         shuffle_matrix_sf_a,
+     )

  _is_hip = is_hip()
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()

+
+ # Try to import FP4 TRTLLM function if flashinfer is available
+ trtllm_fp4_block_scale_moe = None
+ if should_use_flashinfer_trtllm_moe():
+     try:
+         from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
+     except ImportError:
+         trtllm_fp4_block_scale_moe = None
+
  logger = logging.getLogger(__name__)


- @lru_cache(maxsize=1)
- def should_use_flashinfer_trtllm_moe():
-     return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
-         not importlib.util.find_spec("flashinfer")
-         or pkg_version.parse(__import__("flashinfer").__version__)
-         >= pkg_version.parse("0.2.9rc1")
-     )
+ def _is_fp4_quantization_enabled():
+     """Check if ModelOpt FP4 quantization is enabled."""
+     try:
+         # Use the same simple check that works for class selection
+         quantization = global_server_args_dict.get("quantization")
+         return quantization == "modelopt_fp4"
+     except:
+         return False
+
+
+ def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+     # Guess tokens per expert assuming perfect expert distribution first.
+     num_tokens_per_expert = (num_tokens * top_k) // num_experts
+     # And pad the number to the next power of 2.
+     tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+     # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+     tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+     return tile_tokens_dim


  class FusedMoeWeightScaleSupported(Enum):
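
The tile sizing introduced above can be checked with a small standalone sketch; next_power_of_2 is re-implemented locally here only for illustration, since the diff imports it from sglang.srt.utils without showing it:

    # Standalone sketch of the _get_tile_tokens_dim logic added above.
    def next_power_of_2(n: int) -> int:
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
        num_tokens_per_expert = (num_tokens * top_k) // num_experts  # assume even routing
        return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)  # kernel supports 8-64

    assert get_tile_tokens_dim(512, 8, 256) == 16      # 16 tokens/expert -> tile of 16
    assert get_tile_tokens_dim(4, 8, 256) == 8         # tiny batches clamp up to 8
    assert get_tile_tokens_dim(100_000, 8, 256) == 64  # large batches clamp down to 64
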
@@ -96,6 +135,10 @@ class FusedMoE(torch.nn.Module):
          no_combine: bool = False,
          routed_scaling_factor: Optional[float] = None,
          enable_flashinfer_cutlass_moe: Optional[bool] = False,
+         activation_alpha: Optional[float] = None,
+         swiglu_limit: Optional[float] = None,
+         use_weight_loader_fused: bool = False,
+         with_bias=False,
      ):
          super().__init__()

@@ -110,6 +153,10 @@ class FusedMoE(torch.nn.Module):
          self.expert_map_cpu = None
          self.expert_map_gpu = None

+         # For activation
+         self.activation_alpha = activation_alpha
+         self.swiglu_limit = swiglu_limit
+
          if enable_flashinfer_cutlass_moe and quant_config is None:
              logger.warning("Disable flashinfer MoE when quantization config is None.")
              enable_flashinfer_cutlass_moe = False
@@ -124,15 +171,18 @@ class FusedMoE(torch.nn.Module):
          if self.moe_ep_size > 1:
              # TODO(ch-wan): support shared experts fusion
              # Create a tensor of size num_experts filled with -1
-             self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
+             self.expert_map_cpu = torch.full(
+                 (self.num_experts,), -1, dtype=torch.int32, device="cpu"
+             )
+             self.expert_map_cpu = torch.full(
+                 (self.num_experts,), -1, dtype=torch.int32, device="cpu"
+             )
              # Create a expert map for the local experts
              self.expert_map_cpu[
                  self.moe_ep_rank
                  * self.num_local_experts : (self.moe_ep_rank + 1)
                  * self.num_local_experts
              ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-             if not self.enable_flashinfer_cutlass_moe:
-                 self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")

          self.routed_scaling_factor = routed_scaling_factor
          assert intermediate_size % self.moe_tp_size == 0
@@ -154,13 +204,19 @@ class FusedMoE(torch.nn.Module):
              )
          else:
              self.quant_method = quant_config.get_quant_method(self, prefix)
-             if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                 self.quant_method.enable_flashinfer_cutlass_moe = (
-                     self.enable_flashinfer_cutlass_moe
-                 )
          assert self.quant_method is not None

          self.quant_config = quant_config
+         self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
+             "enable_flashinfer_mxfp4_moe", False
+         )
+         # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
+         if (
+             self.quant_config is not None
+             and self.quant_config.get_name() == "mxfp4"
+             and self.use_enable_flashinfer_mxfp4_moe
+         ):
+             hidden_size = round_up(hidden_size, 256)
          self.quant_method.create_weights(
              layer=self,
              num_experts=self.num_local_experts,
@@ -169,7 +225,12 @@ class FusedMoE(torch.nn.Module):
              intermediate_size=self.intermediate_size_per_partition,
              intermediate_size_per_partition=self.intermediate_size_per_partition,
              params_dtype=params_dtype,
-             weight_loader=self.weight_loader,
+             weight_loader=(
+                 self.weight_loader
+                 if not use_weight_loader_fused
+                 else self.weight_loader_fused
+             ),
+             with_bias=with_bias,
          )

      def _load_per_tensor_weight_scale(
@@ -197,6 +258,7 @@ class FusedMoE(torch.nn.Module):
          shard_id: str,
          loaded_weight: torch.Tensor,
          tp_rank: int,
+         is_bias: bool = False,
      ):
          # Load grouped weight scales for group quantization
          # or model weights
@@ -207,14 +269,16 @@ class FusedMoE(torch.nn.Module):
                  loaded_weight=loaded_weight,
                  expert_data=expert_data,
                  tp_rank=tp_rank,
+                 is_bias=is_bias,
              )
-         elif shard_id in ("w1", "w3"):
+         elif shard_id in ("w1", "w3", "w13"):
              self._load_w13(
                  shard_id=shard_id,
                  shard_dim=shard_dim,
                  loaded_weight=loaded_weight,
                  expert_data=expert_data,
                  tp_rank=tp_rank,
+                 is_bias=is_bias,
              )

      def _load_per_channel_weight_scale(
@@ -244,17 +308,30 @@ class FusedMoE(torch.nn.Module):
          shard_id: str,
          loaded_weight: torch.Tensor,
          tp_rank: int,
+         is_bias: bool = False,
      ):

          # Index the loaded weight for tp sharding.
          # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-         shard_size = expert_data.shape[shard_dim] // 2
+         assert shard_id in {"w1", "w3", "w13"}
+
+         if is_bias:
+             # if this weight is a bias, the last dimension must be the sharded dimension
+             shard_dim = -1
+
+         if shard_id in {"w1", "w3"}:
+             # non-fused version
+             shard_size = expert_data.shape[shard_dim] // 2
+         elif shard_id in {"w13"}:
+             # fused version
+             shard_size = expert_data.shape[shard_dim]
+         else:
+             raise NotImplementedError

          # Narrow parameter and load.
          # w1, gate_proj: Load into first logical weight of w13.
          # w3, up_proj: Load into second logical weight of w13.
          # trtllm cutlass kernel assumes differently
-         assert shard_id in ("w1", "w3")
          switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
          if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
              start = shard_size
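
To make the new fused/bias shard-size branch concrete, here is a standalone sketch that mirrors the logic added to _load_w13 above; the tensor shapes are illustrative only:

    import torch

    # Illustrative shapes: 4 experts, intermediate size 1024, hidden size 512.
    w13_weight = torch.empty(4, 2 * 1024, 512)   # fused gate+up weight, shard_dim=1
    w13_bias = torch.empty(4, 2 * 1024)          # fused gate+up bias

    def shard_size_for(expert_data: torch.Tensor, shard_id: str, shard_dim: int, is_bias: bool) -> int:
        # mirrors the branch added above
        if is_bias:
            shard_dim = -1                       # bias is sharded on its last dimension
        if shard_id in {"w1", "w3"}:             # unfused: each half of w13
            return expert_data.shape[shard_dim] // 2
        elif shard_id == "w13":                  # fused: the whole gate+up block
            return expert_data.shape[shard_dim]
        raise NotImplementedError

    print(shard_size_for(w13_weight, "w1", 1, False))   # 1024
    print(shard_size_for(w13_weight, "w13", 1, False))  # 2048
    print(shard_size_for(w13_bias, "w13", 1, True))     # 2048
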
@@ -273,7 +350,8 @@ class FusedMoE(torch.nn.Module):
              )
          else:
              if not self.use_presharded_weights:
-                 if self.use_triton_kernels:
+                 if not is_bias and self.use_triton_kernels:
+                     # do not transpose for bias
                      loaded_weight = loaded_weight.transpose(-2, -1)
                  loaded_weight = loaded_weight.narrow(
                      shard_dim, shard_size * tp_rank, shard_size
@@ -289,6 +367,7 @@ class FusedMoE(torch.nn.Module):
          shard_id: str,
          loaded_weight: torch.Tensor,
          tp_rank: int,
+         is_bias: bool = False,
      ):
          """Load w2 weights for down projection.

@@ -319,7 +398,14 @@ class FusedMoE(torch.nn.Module):
          # Index the loaded weight for tp sharding.
          # down_proj: "RowParallel" so tp sharding on input_dim
          # Narrow parameter and load.
-         shard_size = expert_data.shape[shard_dim]
+         if is_bias:
+             # this expert_data is a bias, not weight,
+             # for w2_weight_bias in TP, it does not need to be sharded
+             shard_size = expert_data.shape[-1]
+         else:
+             # this parameter is a weight matrix
+             # for w2 in TP, it shards the input_features, i.e., shard_dim=2
+             shard_size = expert_data.shape[shard_dim]

          if _is_cpu:
              expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
@@ -332,13 +418,9 @@ class FusedMoE(torch.nn.Module):
                  not self.use_presharded_weights,
              )
          else:
-             if not self.use_presharded_weights:
+             if not is_bias and not self.use_presharded_weights:
                  if self.use_triton_kernels:
                      loaded_weight = loaded_weight.transpose(-2, -1)
-                 if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
-                     raise ValueError(
-                         f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
-                     )
                  loaded_weight = loaded_weight.narrow(
                      shard_dim, shard_size * tp_rank, shard_size
                  )
@@ -386,9 +468,25 @@ class FusedMoE(torch.nn.Module):
          loaded_weight: torch.Tensor,
          weight_name: str,
          shard_id: str,
-         expert_id: int,
+         expert_id: Optional[int],
      ) -> None:

+         # if expert_id is None, then
+         # all the experts are loaded at the same time
+         if (
+             not expert_id
+             and self.quant_config is not None
+             and self.quant_config.get_name() == "mxfp4"
+         ):
+             if "bias" in weight_name:
+                 dim1 = loaded_weight.shape[1]
+                 param.data[:, :dim1].copy_(loaded_weight)
+             else:
+                 dim1 = loaded_weight.shape[1]
+                 dim2 = loaded_weight.shape[2]
+                 param.data[:, :dim1, :dim2].copy_(loaded_weight)
+             return
+
          global_expert_location_metadata = get_global_expert_location_metadata()
          if global_expert_location_metadata is None:
              self._weight_loader_impl(
@@ -427,6 +525,7 @@ class FusedMoE(torch.nn.Module):
          shard_id: str,
          expert_id: int,
      ) -> None:
+
          expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
          if expert_id == -1:
              return
@@ -621,16 +720,104 @@ class FusedMoE(torch.nn.Module):
              )
          return

+     def weight_loader_fused(
+         self,
+         param: torch.nn.Parameter,
+         loaded_weight: torch.Tensor,
+         weight_name: str,
+         shard_id: str,
+     ) -> None:
+         tp_rank = self.moe_tp_rank
+
+         if self.quant_config is not None and self.quant_config.get_name() == "mxfp4":
+             if "bias" in weight_name:
+                 dim1 = loaded_weight.shape[1]
+                 param.data[:, :dim1].copy_(loaded_weight)
+             elif "scale" in weight_name:
+                 param.data.copy_(loaded_weight)
+             else:
+                 dim1 = loaded_weight.shape[1]
+                 dim2 = loaded_weight.shape[2]
+                 param.data[:, :dim1, :dim2].copy_(loaded_weight)
+             return
+
+         # compressed-tensors checkpoints with packed weights are stored flipped
+         # TODO: check self.quant_method.quant_config.quant_format
+         # against known CompressionFormat enum values that have this quality
+         loaded_weight = (
+             loaded_weight.t().contiguous()
+             if (
+                 self.quant_method.__class__.__name__
+                 == "CompressedTensorsWNA16MoEMethod"
+             )
+             else loaded_weight
+         )
+
+         if shard_id not in ("w13", "w2"):
+             raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.")
+
+         # Fetch the dim to shard the parameter/loaded weight
+         # based on the shard id. This will be whatever
+         # dimension intermediate_size is used.
+         SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
+         SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}
+
+         expert_data = param.data
+         is_bias = expert_data.dim() == 2
+
+         # is_transposed: if the dim to shard the weight
+         # should be flipped. Required by GPTQ, compressed-tensors
+         # should be whatever dimension intermediate_size is
+         is_transposed = getattr(param, "is_transposed", False)
+
+         if self.use_triton_kernels:
+             is_transposed = True
+         shard_dim = (
+             SHARD_ID_TO_SHARDED_DIM[shard_id]
+             if not is_transposed
+             else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
+         )
+
+         # Case model weights
+         if "weight" in weight_name:
+             self._load_model_weight_or_group_weight_scale(
+                 shard_id=shard_id,
+                 shard_dim=shard_dim,
+                 loaded_weight=loaded_weight,
+                 expert_data=expert_data,
+                 tp_rank=tp_rank,
+                 is_bias=is_bias,
+             )
+             return
+         else:
+             logging.warning(
+                 f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
+             )
+
      def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
+         origin_hidden_states_dim = hidden_states.shape[-1]
          assert self.quant_method is not None

-         if self.expert_map_gpu is not None:
+         if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
+             if self.expert_map_cpu is not None and self.expert_map_gpu is None:
+                 # If we are in EP mode, we need to move the expert map to GPU.
+                 self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
+
+         if self.expert_map_gpu is not None and isinstance(
+             topk_output, StandardTopKOutput
+         ):
              topk_output = topk_output._replace(
                  topk_ids=self.expert_map_gpu[topk_output.topk_ids]
              )

          # Matrix multiply.
          with use_symmetric_memory(get_tp_group()) as sm:
+             kwargs = {}
+             if self.activation_alpha is not None:
+                 kwargs["activation_alpha"] = self.activation_alpha
+             if self.swiglu_limit is not None:
+                 kwargs["swiglu_limit"] = self.swiglu_limit
+
              final_hidden_states = self.quant_method.apply(
                  layer=self,
                  x=hidden_states,
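
The lazy expert-map move in forward() above remaps global expert ids to local ones before dispatch; a standalone illustration of that indexing (values are illustrative, this is not sglang API):

    import torch

    # EP rank 1 of 4 owns experts 16..31 out of 64 (num_local_experts = 16),
    # built the same way as expert_map_cpu in __init__ above.
    num_experts, num_local_experts, ep_rank = 64, 16, 1
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
    expert_map[ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts] = torch.arange(
        num_local_experts, dtype=torch.int32
    )

    topk_ids = torch.tensor([[3, 17, 31], [40, 16, 2]])
    local_ids = expert_map[topk_ids]  # same indexing as expert_map_gpu[topk_output.topk_ids]
    # -> [[-1, 1, 15], [-1, 0, -1]]; -1 marks experts owned by other EP ranks
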
@@ -649,9 +836,14 @@ class FusedMoE(torch.nn.Module):
                      == "ModelOptNvFp4FusedMoEMethod"
                      else {}
                  ),
+                 **kwargs,
              )
              sm.tag(final_hidden_states)

+         final_hidden_states = final_hidden_states[
+             ..., :origin_hidden_states_dim
+         ].contiguous()
+
          if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
              final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

@@ -686,6 +878,52 @@ class FusedMoE(torch.nn.Module):
              ]
          ]

+     @classmethod
+     def make_expert_params_mapping_fused(
+         cls,
+         ckpt_gate_up_proj_name: str,
+         ckpt_down_proj_name: str,
+         ckpt_gate_up_proj_bias_name: str,
+         ckpt_down_proj_bias_name: str,
+     ):
+         return [
+             ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+             (
+                 "experts.w13_weight_bias",
+                 f"experts.{ckpt_gate_up_proj_bias_name}",
+                 "w13",
+             ),
+             ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+             ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+         ]
+
+     @classmethod
+     def make_expert_params_mapping_fused_mxfp4(
+         cls,
+         ckpt_gate_up_proj_name: str,
+         ckpt_down_proj_name: str,
+         ckpt_gate_up_proj_bias_name: str,
+         ckpt_down_proj_bias_name: str,
+         ckpt_gate_up_proj_scale_name: str,
+         ckpt_down_proj_scale_name: str,
+     ):
+         return [
+             ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+             (
+                 "experts.w13_weight_bias",
+                 f"experts.{ckpt_gate_up_proj_bias_name}",
+                 "w13",
+             ),
+             ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+             ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+             (
+                 "experts.w13_weight_scale",
+                 f"experts.{ckpt_gate_up_proj_scale_name}",
+                 "w13",
+             ),
+             ("experts.w2_weight_scale", f"experts.{ckpt_down_proj_scale_name}", "w2"),
+         ]
+
      @classmethod
      def make_expert_input_scale_params_mapping(
          cls,
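
A hedged sketch of how a model's load_weights loop might consume the fused mapping above; the checkpoint tensor names (gate_up_proj, down_proj, *_bias), the weights iterable, and params_dict are placeholders, not taken from this diff:

    def load_experts_fused(params_dict: dict, weights) -> None:
        # Hypothetical consumer of make_expert_params_mapping_fused; checkpoint
        # names below stand in for whatever the model actually uses.
        expert_params_mapping = FusedMoE.make_expert_params_mapping_fused(
            ckpt_gate_up_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
            ckpt_down_proj_bias_name="down_proj_bias",
        )
        for ckpt_name, loaded_weight in weights:  # iterable of (name, tensor)
            for param_name, ckpt_substr, shard_id in expert_params_mapping:
                if ckpt_substr not in ckpt_name:
                    continue
                name = ckpt_name.replace(ckpt_substr, param_name)
                param = params_dict[name]
                # With use_weight_loader_fused=True this dispatches to weight_loader_fused,
                # which copies the whole fused w13/w2 tensor in one call (no expert_id).
                param.weight_loader(param, loaded_weight, name, shard_id)
                break
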
@@ -721,8 +959,13 @@ class FlashInferFusedMoE(FusedMoE):
          self.num_expert_group = num_expert_group
          self.topk_group = topk_group
          self.correction_bias = correction_bias
+         self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()

-     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+     def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
+         assert self.use_flashinfer_trtllm_moe
+         assert (
+             self.activation == "silu"
+         ), "Only silu is supported for flashinfer blockscale fp8 moe"
          assert self.quant_method is not None
          assert (
              self.renormalize
@@ -730,6 +973,14 @@ class FlashInferFusedMoE(FusedMoE):
          assert (
              self.num_fused_shared_experts == 0
          ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+
+         # TRTLLM mode expects (TopK_config, router_logits) tuple
+         if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+             raise ValueError(
+                 f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+             )
+         _, router_logits = topk_output
+
          # Matrix multiply.
          final_hidden_states = self.quant_method.apply_with_router_logits(
              layer=self,
@@ -739,7 +990,135 @@ class FlashInferFusedMoE(FusedMoE):
              routed_scaling_factor=self.routed_scaling_factor,
          )

-         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
              final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

          return final_hidden_states
+
+
+ class FlashInferFP4MoE(FusedMoE):
+     """FP4 TRTLLM MoE implementation using FlashInfer."""
+
+     def __init__(self, *args, **kwargs):
+         # Extract DeepSeek-specific parameters
+         renormalize = kwargs.pop("renormalize", True)
+         num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+         use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+         num_expert_group = kwargs.pop("num_expert_group", None)
+         topk_group = kwargs.pop("topk_group", None)
+         correction_bias = kwargs.pop("correction_bias", None)
+
+         # Extract additional TopK parameters that were previously extracted in forward
+         routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
+
+         super().__init__(*args, **kwargs)
+
+         # Store DeepSeek parameters
+         self.renormalize = renormalize
+         self.num_fused_shared_experts = num_fused_shared_experts
+         self.use_grouped_topk = use_grouped_topk
+         self.num_expert_group = num_expert_group
+         self.topk_group = topk_group
+         self.correction_bias = correction_bias
+         self.routed_scaling_factor = routed_scaling_factor
+
+     # ---------------------------------------------------------------------
+     # Helper: quantize hidden states to FP4 each forward pass
+     # ---------------------------------------------------------------------
+     def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
+         """
+         Quantize hidden states using global scale factor from quantization method.
+
+         Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
+         Only block scales are computed at runtime for efficiency.
+
+         Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime, global_scale_float32)
+         """
+
+         # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
+         # Only the block scales are computed at runtime
+         hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
+             hidden_states,
+             self.w13_input_scale_quant,
+             16,  # sf_vec_size
+             False,  # use_ue8m0
+             False,  # is_sf_swizzled_layout
+         )
+
+         hs_fp4 = hs_fp4_bytes.reshape(
+             hidden_states.shape[0], hidden_states.shape[1] // 2
+         )
+         hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
+
+         return hs_fp4, hs_sf
+
+     def forward(self, hidden_states: torch.Tensor, topk_output):
+         """Forward pass using FP4 TRTLLM kernel.
+
+         Args:
+             hidden_states: Input tensor
+             topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
+         """
+
+         # TRTLLM mode expects (TopK_config, router_logits) tuple
+         if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+             raise ValueError(
+                 f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+             )
+
+         _, router_logits = topk_output
+
+         hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
+
+         router_logits = router_logits.to(torch.float32)
+
+         result = trtllm_fp4_block_scale_moe(
+             routing_logits=router_logits,
+             routing_bias=self.correction_bias.to(hidden_states.dtype),
+             hidden_states=hs_fp4,
+             hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
+             gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
+             gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
+                 torch.float8_e4m3fn
+             ),
+             gemm1_bias=None,
+             gemm1_alpha=None,
+             gemm1_beta=None,
+             gemm1_clamp_limit=None,
+             gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
+             gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
+                 torch.float8_e4m3fn
+             ),
+             gemm2_bias=None,
+             output1_scale_scalar=self.g1_scale_c.data,
+             output1_scale_gate_scalar=self.g1_alphas.data,
+             output2_scale_scalar=self.g2_alphas.data,
+             num_experts=self.num_experts,
+             top_k=self.top_k,
+             n_group=self.num_expert_group,
+             topk_group=self.topk_group,
+             intermediate_size=self.intermediate_size_per_partition,
+             local_expert_offset=self.moe_ep_rank * self.num_local_experts,
+             local_num_experts=self.num_local_experts,
+             routed_scaling_factor=self.routed_scaling_factor,
+             tile_tokens_dim=_get_tile_tokens_dim(
+                 hidden_states.shape[0], self.top_k, self.num_local_experts
+             ),
+             routing_method_type=RoutingMethodType.DeepSeekV3,
+             do_finalize=True,
+         )[0]
+
+         return result
+
+
+ def get_fused_moe_impl_class():
+     """Factory function to get the appropriate FusedMoE implementation class."""
+     if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
+         # Use FP4 variant when FP4 quantization is enabled
+         return FlashInferFP4MoE
+     elif should_use_flashinfer_trtllm_moe():
+         # Use regular FlashInfer variant for non-FP4 FlashInfer cases
+         return FlashInferFusedMoE
+     else:
+         # Default case
+         return FusedMoE
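
Finally, a hedged usage sketch for the new get_fused_moe_impl_class() factory; the constructor keyword arguments and the quant_config/topk_output objects are illustrative placeholders, since real call sites pass model-specific configuration:

    # Hypothetical call site: pick the FusedMoE variant once, then construct it.
    def build_experts(quant_config):
        MoEImplClass = get_fused_moe_impl_class()  # FusedMoE, FlashInferFusedMoE, or FlashInferFP4MoE
        return MoEImplClass(
            num_experts=256,
            top_k=8,
            hidden_size=7168,
            intermediate_size=2048,
            quant_config=quant_config,
            prefix="mlp.experts",
        )

    # The second forward argument differs by class: StandardTopKOutput for FusedMoE,
    # a (TopK_config, router_logits) tuple for the FlashInfer variants:
    # out = experts(hidden_states, topk_output)
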