sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -1,13 +1,14 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

- import importlib.util
+ import datetime
+ import glob
  import logging
+ import os
+ import sys
  from enum import Enum
- from functools import lru_cache
  from typing import List, Optional, Tuple

  import torch
- from packaging import version as pkg_version

  from sglang.srt.distributed import (
      get_moe_expert_parallel_rank,
@@ -22,6 +23,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
  )
  from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
  from sglang.srt.layers.moe.topk import StandardTopKOutput
+ from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
  from sglang.srt.layers.quantization.base_config import (
      QuantizationConfig,
      QuantizeMethodBase,
@@ -29,22 +31,59 @@ from sglang.srt.layers.quantization.base_config import (
  from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
- from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
+ from sglang.srt.utils import (
+     cpu_has_amx_support,
+     get_bool_env_var,
+     is_cpu,
+     is_flashinfer_available,
+     is_hip,
+     next_power_of_2,
+     round_up,
+ )
+
+ if is_flashinfer_available():
+     from flashinfer import (
+         RoutingMethodType,
+         fp4_quantize,
+         reorder_rows_for_gated_act_gemm,
+         shuffle_matrix_a,
+         shuffle_matrix_sf_a,
+     )

  _is_hip = is_hip()
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()

+
+ # Try to import FP4 TRTLLM function if flashinfer is available
+ trtllm_fp4_block_scale_moe = None
+ if should_use_flashinfer_trtllm_moe():
+     try:
+         from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
+     except ImportError:
+         trtllm_fp4_block_scale_moe = None
+
  logger = logging.getLogger(__name__)


- @lru_cache(maxsize=1)
- def should_use_flashinfer_trtllm_moe():
-     return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
-         not importlib.util.find_spec("flashinfer")
-         or pkg_version.parse(__import__("flashinfer").__version__)
-         >= pkg_version.parse("0.2.9rc1")
-     )
+ def _is_fp4_quantization_enabled():
+     """Check if ModelOpt FP4 quantization is enabled."""
+     try:
+         # Use the same simple check that works for class selection
+         quantization = global_server_args_dict.get("quantization")
+         return quantization == "modelopt_fp4"
+     except:
+         return False
+
+
+ def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+     # Guess tokens per expert assuming perfect expert distribution first.
+     num_tokens_per_expert = (num_tokens * top_k) // num_experts
+     # And pad the number to the next power of 2.
+     tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+     # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+     tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+     return tile_tokens_dim


  class FusedMoeWeightScaleSupported(Enum):
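
The `_get_tile_tokens_dim` helper added above is plain integer arithmetic, so it can be checked in isolation. A minimal sketch, with `next_power_of_2` re-implemented locally as a stand-in for `sglang.srt.utils.next_power_of_2` (assumed to round up to the nearest power of two):

    # Stand-in for sglang.srt.utils.next_power_of_2 (assumption: rounds up).
    def next_power_of_2(n: int) -> int:
        return 1 if n <= 0 else 1 << (n - 1).bit_length()

    def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
        per_expert = (num_tokens * top_k) // num_experts  # assume perfect balance
        return min(max(next_power_of_2(per_expert), 8), 64)  # clamp to the 8-64 range

    print(tile_tokens_dim(512, 8, 256))   # 16 tokens per CTA tile
    print(tile_tokens_dim(1, 8, 256))     # 8  (clamped up for tiny batches)
    print(tile_tokens_dim(8192, 8, 32))   # 64 (clamped down for huge batches)
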
@@ -96,6 +135,10 @@ class FusedMoE(torch.nn.Module):
          no_combine: bool = False,
          routed_scaling_factor: Optional[float] = None,
          enable_flashinfer_cutlass_moe: Optional[bool] = False,
+         activation_alpha: Optional[float] = None,
+         swiglu_limit: Optional[float] = None,
+         use_weight_loader_fused: bool = False,
+         with_bias=False,
      ):
          super().__init__()

@@ -104,12 +147,15 @@ class FusedMoE(torch.nn.Module):

          self.layer_id = layer_id
          self.top_k = top_k
-         self.hidden_size = hidden_size
          self.num_experts = num_experts
          self.num_fused_shared_experts = num_fused_shared_experts
          self.expert_map_cpu = None
          self.expert_map_gpu = None

+         # For activation
+         self.activation_alpha = activation_alpha
+         self.swiglu_limit = swiglu_limit
+
          if enable_flashinfer_cutlass_moe and quant_config is None:
              logger.warning("Disable flashinfer MoE when quantization config is None.")
              enable_flashinfer_cutlass_moe = False
@@ -124,15 +170,18 @@ class FusedMoE(torch.nn.Module):
          if self.moe_ep_size > 1:
              # TODO(ch-wan): support shared experts fusion
              # Create a tensor of size num_experts filled with -1
-             self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
+             self.expert_map_cpu = torch.full(
+                 (self.num_experts,), -1, dtype=torch.int32, device="cpu"
+             )
+             self.expert_map_cpu = torch.full(
+                 (self.num_experts,), -1, dtype=torch.int32, device="cpu"
+             )
              # Create a expert map for the local experts
              self.expert_map_cpu[
                  self.moe_ep_rank
                  * self.num_local_experts : (self.moe_ep_rank + 1)
                  * self.num_local_experts
              ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-             if not self.enable_flashinfer_cutlass_moe:
-                 self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")

          self.routed_scaling_factor = routed_scaling_factor
          assert intermediate_size % self.moe_tp_size == 0
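
For reference, a small standalone sketch of what the expert map built in this hunk ends up looking like (illustrative sizes, not taken from the diff): with 8 experts split across 4 EP ranks, rank 1 owns global experts 2-3 and maps them to local ids 0-1, while every other global expert maps to -1.

    import torch

    num_experts, num_local_experts, moe_ep_rank = 8, 2, 1  # illustrative sizes
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
    expert_map[
        moe_ep_rank * num_local_experts : (moe_ep_rank + 1) * num_local_experts
    ] = torch.arange(0, num_local_experts, dtype=torch.int32)
    print(expert_map)  # tensor([-1, -1,  0,  1, -1, -1, -1, -1], dtype=torch.int32)
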
@@ -154,13 +203,19 @@ class FusedMoE(torch.nn.Module):
              )
          else:
              self.quant_method = quant_config.get_quant_method(self, prefix)
-             if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                 self.quant_method.enable_flashinfer_cutlass_moe = (
-                     self.enable_flashinfer_cutlass_moe
-                 )
          assert self.quant_method is not None

          self.quant_config = quant_config
+         self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
+             "enable_flashinfer_mxfp4_moe", False
+         )
+         if (
+             self.quant_config is not None
+             and self.quant_config.get_name() == "mxfp4"
+             and self.use_enable_flashinfer_mxfp4_moe
+         ):
+             hidden_size = round_up(hidden_size, 256)
+         self.hidden_size = hidden_size
          self.quant_method.create_weights(
              layer=self,
              num_experts=self.num_local_experts,
@@ -169,7 +224,12 @@ class FusedMoE(torch.nn.Module):
              intermediate_size=self.intermediate_size_per_partition,
              intermediate_size_per_partition=self.intermediate_size_per_partition,
              params_dtype=params_dtype,
-             weight_loader=self.weight_loader,
+             weight_loader=(
+                 self.weight_loader
+                 if not use_weight_loader_fused
+                 else self.weight_loader_fused
+             ),
+             with_bias=with_bias,
          )

      def _load_per_tensor_weight_scale(
@@ -197,6 +257,7 @@ class FusedMoE(torch.nn.Module):
          shard_id: str,
          loaded_weight: torch.Tensor,
          tp_rank: int,
+         is_bias: bool = False,
      ):
          # Load grouped weight scales for group quantization
          # or model weights
@@ -207,14 +268,16 @@ class FusedMoE(torch.nn.Module):
                  loaded_weight=loaded_weight,
                  expert_data=expert_data,
                  tp_rank=tp_rank,
+                 is_bias=is_bias,
              )
-         elif shard_id in ("w1", "w3"):
+         elif shard_id in ("w1", "w3", "w13"):
              self._load_w13(
                  shard_id=shard_id,
                  shard_dim=shard_dim,
                  loaded_weight=loaded_weight,
                  expert_data=expert_data,
                  tp_rank=tp_rank,
+                 is_bias=is_bias,
              )

      def _load_per_channel_weight_scale(
@@ -244,17 +307,30 @@ class FusedMoE(torch.nn.Module):
          shard_id: str,
          loaded_weight: torch.Tensor,
          tp_rank: int,
+         is_bias: bool = False,
      ):

          # Index the loaded weight for tp sharding.
          # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-         shard_size = expert_data.shape[shard_dim] // 2
+         assert shard_id in {"w1", "w3", "w13"}
+
+         if is_bias:
+             # if this weight is a bias, the last dimension must be the sharded dimension
+             shard_dim = -1
+
+         if shard_id in {"w1", "w3"}:
+             # non-fused version
+             shard_size = expert_data.shape[shard_dim] // 2
+         elif shard_id in {"w13"}:
+             # fused version
+             shard_size = expert_data.shape[shard_dim]
+         else:
+             raise NotImplementedError

          # Narrow parameter and load.
          # w1, gate_proj: Load into first logical weight of w13.
          # w3, up_proj: Load into second logical weight of w13.
          # trtllm cutlass kernel assumes differently
-         assert shard_id in ("w1", "w3")
          switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
          if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
              start = shard_size
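
A toy illustration of the shard-size rule above (assumed shapes, not from the diff): for the per-projection shards "w1"/"w3" the w13 parameter holds both projections along the sharded dim, so each shard covers half of it, while a fused "w13" shard covers the whole dim; for a bias the sharded dimension is forced to the last axis.

    import torch

    w13_param = torch.empty(8, 16)  # [2 * intermediate_per_partition, hidden], assumed shape
    shard_dim = 0
    shard_size_w1_or_w3 = w13_param.shape[shard_dim] // 2  # 4, half of the fused dim
    shard_size_w13 = w13_param.shape[shard_dim]            # 8, the whole fused dim
    w13_bias = torch.empty(32, 8)                          # bias: sharded on the last dim
    bias_shard_size = w13_bias.shape[-1] // 2              # 4 per "w1"/"w3" bias shard
    assert (shard_size_w1_or_w3, shard_size_w13, bias_shard_size) == (4, 8, 4)
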
@@ -273,7 +349,8 @@ class FusedMoE(torch.nn.Module):
              )
          else:
              if not self.use_presharded_weights:
-                 if self.use_triton_kernels:
+                 if not is_bias and self.use_triton_kernels:
+                     # do not transpose for bias
                      loaded_weight = loaded_weight.transpose(-2, -1)
                  loaded_weight = loaded_weight.narrow(
                      shard_dim, shard_size * tp_rank, shard_size
@@ -289,6 +366,7 @@ class FusedMoE(torch.nn.Module):
          shard_id: str,
          loaded_weight: torch.Tensor,
          tp_rank: int,
+         is_bias: bool = False,
      ):
          """Load w2 weights for down projection.

@@ -319,7 +397,14 @@ class FusedMoE(torch.nn.Module):
          # Index the loaded weight for tp sharding.
          # down_proj: "RowParallel" so tp sharding on input_dim
          # Narrow parameter and load.
-         shard_size = expert_data.shape[shard_dim]
+         if is_bias:
+             # this expert_data is a bias, not weight,
+             # for w2_weight_bias in TP, it does not need to be sharded
+             shard_size = expert_data.shape[-1]
+         else:
+             # this parameter is a weight matrix
+             # for w2 in TP, it shards the input_features, i.e., shard_dim=2
+             shard_size = expert_data.shape[shard_dim]

          if _is_cpu:
              expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
@@ -332,13 +417,9 @@ class FusedMoE(torch.nn.Module):
                  not self.use_presharded_weights,
              )
          else:
-             if not self.use_presharded_weights:
+             if not is_bias and not self.use_presharded_weights:
                  if self.use_triton_kernels:
                      loaded_weight = loaded_weight.transpose(-2, -1)
-                 if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
-                     raise ValueError(
-                         f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
-                     )
                  loaded_weight = loaded_weight.narrow(
                      shard_dim, shard_size * tp_rank, shard_size
                  )
@@ -386,9 +467,25 @@ class FusedMoE(torch.nn.Module):
          loaded_weight: torch.Tensor,
          weight_name: str,
          shard_id: str,
-         expert_id: int,
+         expert_id: Optional[int],
      ) -> None:

+         # if expert_id is None, then
+         # all the experts are loaded at the same time
+         if (
+             not expert_id
+             and self.quant_config is not None
+             and self.quant_config.get_name() == "mxfp4"
+         ):
+             if "bias" in weight_name:
+                 dim1 = loaded_weight.shape[1]
+                 param.data[:, :dim1].copy_(loaded_weight)
+             else:
+                 dim1 = loaded_weight.shape[1]
+                 dim2 = loaded_weight.shape[2]
+                 param.data[:, :dim1, :dim2].copy_(loaded_weight)
+             return
+
          global_expert_location_metadata = get_global_expert_location_metadata()
          if global_expert_location_metadata is None:
              self._weight_loader_impl(
@@ -427,6 +524,7 @@ class FusedMoE(torch.nn.Module):
          shard_id: str,
          expert_id: int,
      ) -> None:
+
          expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
          if expert_id == -1:
              return
@@ -621,16 +719,111 @@ class FusedMoE(torch.nn.Module):
              )
              return

+     def weight_loader_fused(
+         self,
+         param: torch.nn.Parameter,
+         loaded_weight: torch.Tensor,
+         weight_name: str,
+         shard_id: str,
+     ) -> None:
+         tp_rank = self.moe_tp_rank
+
+         if self.quant_config is not None and self.quant_config.get_name() == "mxfp4":
+             if "bias" in weight_name:
+                 dim1 = loaded_weight.shape[1]
+                 param.data[:, :dim1].copy_(loaded_weight)
+             elif "scale" in weight_name:
+                 param.data.copy_(loaded_weight)
+             else:
+                 dim1 = loaded_weight.shape[1]
+                 dim2 = loaded_weight.shape[2]
+                 param.data[:, :dim1, :dim2].copy_(loaded_weight)
+             return
+
+         # compressed-tensors checkpoints with packed weights are stored flipped
+         # TODO: check self.quant_method.quant_config.quant_format
+         # against known CompressionFormat enum values that have this quality
+         loaded_weight = (
+             loaded_weight.t().contiguous()
+             if (
+                 self.quant_method.__class__.__name__
+                 == "CompressedTensorsWNA16MoEMethod"
+             )
+             else loaded_weight
+         )
+
+         if shard_id not in ("w13", "w2"):
+             raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.")
+
+         # Fetch the dim to shard the parameter/loaded weight
+         # based on the shard id. This will be whatever
+         # dimension intermediate_size is used.
+         SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
+         SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}
+
+         expert_data = param.data
+         is_bias = expert_data.dim() == 2
+
+         # is_transposed: if the dim to shard the weight
+         # should be flipped. Required by GPTQ, compressed-tensors
+         # should be whatever dimension intermediate_size is
+         is_transposed = getattr(param, "is_transposed", False)
+
+         if self.use_triton_kernels:
+             is_transposed = True
+         shard_dim = (
+             SHARD_ID_TO_SHARDED_DIM[shard_id]
+             if not is_transposed
+             else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
+         )
+
+         # Case model weights
+         if "weight" in weight_name:
+             self._load_model_weight_or_group_weight_scale(
+                 shard_id=shard_id,
+                 shard_dim=shard_dim,
+                 loaded_weight=loaded_weight,
+                 expert_data=expert_data,
+                 tp_rank=tp_rank,
+                 is_bias=is_bias,
+             )
+             return
+         else:
+             logging.warning(
+                 f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
+             )
+
      def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
+         origin_hidden_states_dim = hidden_states.shape[-1]
+         if self.hidden_size != origin_hidden_states_dim:
+             hidden_states = torch.nn.functional.pad(
+                 hidden_states,
+                 (0, self.hidden_size - origin_hidden_states_dim),
+                 mode="constant",
+                 value=0.0,
+             )
          assert self.quant_method is not None

-         if self.expert_map_gpu is not None:
+         if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
+             if self.expert_map_cpu is not None and self.expert_map_gpu is None:
+                 # If we are in EP mode, we need to move the expert map to GPU.
+                 self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
+
+         if self.expert_map_gpu is not None and isinstance(
+             topk_output, StandardTopKOutput
+         ):
              topk_output = topk_output._replace(
                  topk_ids=self.expert_map_gpu[topk_output.topk_ids]
              )

          # Matrix multiply.
          with use_symmetric_memory(get_tp_group()) as sm:
+             kwargs = {}
+             if self.activation_alpha is not None:
+                 kwargs["activation_alpha"] = self.activation_alpha
+             if self.swiglu_limit is not None:
+                 kwargs["swiglu_limit"] = self.swiglu_limit
+
              final_hidden_states = self.quant_method.apply(
                  layer=self,
                  x=hidden_states,
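
The padding added at the top of forward() pairs with the slice in the return statement in the next hunk. A standalone sketch of that round trip, with illustrative sizes only (e.g. a 2880-wide activation rounded up to 3072, the next multiple of 256):

    import torch
    import torch.nn.functional as F

    origin_dim, padded_dim = 2880, 3072  # illustrative; padded_dim = round_up(origin_dim, 256)
    x = torch.randn(4, origin_dim)
    x_padded = F.pad(x, (0, padded_dim - origin_dim), mode="constant", value=0.0)
    assert x_padded.shape == (4, padded_dim)
    y = x_padded  # stand-in for the MoE output, which keeps the padded width
    assert torch.equal(y[..., :origin_dim].contiguous(), x)
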
@@ -649,13 +842,14 @@ class FusedMoE(torch.nn.Module):
                      == "ModelOptNvFp4FusedMoEMethod"
                      else {}
                  ),
+                 **kwargs,
              )
              sm.tag(final_hidden_states)

          if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
              final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

-         return final_hidden_states
+         return final_hidden_states[..., :origin_hidden_states_dim].contiguous()

      @classmethod
      def make_expert_params_mapping(
@@ -686,6 +880,52 @@ class FusedMoE(torch.nn.Module):
              ]
          ]

+     @classmethod
+     def make_expert_params_mapping_fused(
+         cls,
+         ckpt_gate_up_proj_name: str,
+         ckpt_down_proj_name: str,
+         ckpt_gate_up_proj_bias_name: str,
+         ckpt_down_proj_bias_name: str,
+     ):
+         return [
+             ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+             (
+                 "experts.w13_weight_bias",
+                 f"experts.{ckpt_gate_up_proj_bias_name}",
+                 "w13",
+             ),
+             ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+             ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+         ]
+
+     @classmethod
+     def make_expert_params_mapping_fused_mxfp4(
+         cls,
+         ckpt_gate_up_proj_name: str,
+         ckpt_down_proj_name: str,
+         ckpt_gate_up_proj_bias_name: str,
+         ckpt_down_proj_bias_name: str,
+         ckpt_gate_up_proj_scale_name: str,
+         ckpt_down_proj_scale_name: str,
+     ):
+         return [
+             ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+             (
+                 "experts.w13_weight_bias",
+                 f"experts.{ckpt_gate_up_proj_bias_name}",
+                 "w13",
+             ),
+             ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+             ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+             (
+                 "experts.w13_weight_scale",
+                 f"experts.{ckpt_gate_up_proj_scale_name}",
+                 "w13",
+             ),
+             ("experts.w2_weight_scale", f"experts.{ckpt_down_proj_scale_name}", "w2"),
+         ]
+
      @classmethod
      def make_expert_input_scale_params_mapping(
          cls,
@@ -721,8 +961,13 @@ class FlashInferFusedMoE(FusedMoE):
          self.num_expert_group = num_expert_group
          self.topk_group = topk_group
          self.correction_bias = correction_bias
+         self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()

-     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+     def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
+         assert self.use_flashinfer_trtllm_moe
+         assert (
+             self.activation == "silu"
+         ), "Only silu is supported for flashinfer blockscale fp8 moe"
          assert self.quant_method is not None
          assert (
              self.renormalize
@@ -730,6 +975,14 @@ class FlashInferFusedMoE(FusedMoE):
          assert (
              self.num_fused_shared_experts == 0
          ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+
+         # TRTLLM mode expects (TopK_config, router_logits) tuple
+         if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+             raise ValueError(
+                 f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+             )
+         _, router_logits = topk_output
+
          # Matrix multiply.
          final_hidden_states = self.quant_method.apply_with_router_logits(
              layer=self,
@@ -739,7 +992,135 @@ class FlashInferFusedMoE(FusedMoE):
              routed_scaling_factor=self.routed_scaling_factor,
          )

-         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
              final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

          return final_hidden_states
+
+
+ class FlashInferFP4MoE(FusedMoE):
+     """FP4 TRTLLM MoE implementation using FlashInfer."""
+
+     def __init__(self, *args, **kwargs):
+         # Extract DeepSeek-specific parameters
+         renormalize = kwargs.pop("renormalize", True)
+         num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+         use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+         num_expert_group = kwargs.pop("num_expert_group", None)
+         topk_group = kwargs.pop("topk_group", None)
+         correction_bias = kwargs.pop("correction_bias", None)
+
+         # Extract additional TopK parameters that were previously extracted in forward
+         routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
+
+         super().__init__(*args, **kwargs)
+
+         # Store DeepSeek parameters
+         self.renormalize = renormalize
+         self.num_fused_shared_experts = num_fused_shared_experts
+         self.use_grouped_topk = use_grouped_topk
+         self.num_expert_group = num_expert_group
+         self.topk_group = topk_group
+         self.correction_bias = correction_bias
+         self.routed_scaling_factor = routed_scaling_factor
+
+     # ---------------------------------------------------------------------
+     # Helper: quantize hidden states to FP4 each forward pass
+     # ---------------------------------------------------------------------
+     def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
+         """
+         Quantize hidden states using global scale factor from quantization method.
+
+         Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
+         Only block scales are computed at runtime for efficiency.
+
+         Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime, global_scale_float32)
+         """
+
+         # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
+         # Only the block scales are computed at runtime
+         hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
+             hidden_states,
+             self.w13_input_scale_quant,
+             16,  # sf_vec_size
+             False,  # use_ue8m0
+             False,  # is_sf_swizzled_layout
+         )
+
+         hs_fp4 = hs_fp4_bytes.reshape(
+             hidden_states.shape[0], hidden_states.shape[1] // 2
+         )
+         hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
+
+         return hs_fp4, hs_sf
+
+     def forward(self, hidden_states: torch.Tensor, topk_output):
+         """Forward pass using FP4 TRTLLM kernel.
+
+         Args:
+             hidden_states: Input tensor
+             topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
+         """
+
+         # TRTLLM mode expects (TopK_config, router_logits) tuple
+         if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+             raise ValueError(
+                 f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+             )
+
+         _, router_logits = topk_output
+
+         hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
+
+         router_logits = router_logits.to(torch.float32)
+
+         result = trtllm_fp4_block_scale_moe(
+             routing_logits=router_logits,
+             routing_bias=self.correction_bias.to(hidden_states.dtype),
+             hidden_states=hs_fp4,
+             hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
+             gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
+             gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
+                 torch.float8_e4m3fn
+             ),
+             gemm1_bias=None,
+             gemm1_alpha=None,
+             gemm1_beta=None,
+             gemm1_clamp_limit=None,
+             gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
+             gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
+                 torch.float8_e4m3fn
+             ),
+             gemm2_bias=None,
+             output1_scale_scalar=self.g1_scale_c.data,
+             output1_scale_gate_scalar=self.g1_alphas.data,
+             output2_scale_scalar=self.g2_alphas.data,
+             num_experts=self.num_experts,
+             top_k=self.top_k,
+             n_group=self.num_expert_group,
+             topk_group=self.topk_group,
+             intermediate_size=self.intermediate_size_per_partition,
+             local_expert_offset=self.moe_ep_rank * self.num_local_experts,
+             local_num_experts=self.num_local_experts,
+             routed_scaling_factor=self.routed_scaling_factor,
+             tile_tokens_dim=_get_tile_tokens_dim(
+                 hidden_states.shape[0], self.top_k, self.num_local_experts
+             ),
+             routing_method_type=RoutingMethodType.DeepSeekV3,
+             do_finalize=True,
+         )[0]
+
+         return result
+
+
+ def get_fused_moe_impl_class():
+     """Factory function to get the appropriate FusedMoE implementation class."""
+     if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
+         # Use FP4 variant when FP4 quantization is enabled
+         return FlashInferFP4MoE
+     elif should_use_flashinfer_trtllm_moe():
+         # Use regular FlashInfer variant for non-FP4 FlashInfer cases
+         return FlashInferFusedMoE
+     else:
+         # Default case
+         return FusedMoE
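
To summarize the dispatch in the new `get_fused_moe_impl_class` factory, a self-contained sketch with the two predicates stubbed out as booleans (names are illustrative, not part of the module):

    def pick_impl(use_flashinfer_trtllm_moe: bool, fp4_enabled: bool) -> str:
        # Mirrors get_fused_moe_impl_class(): FP4 only matters when TRTLLM MoE is on.
        if use_flashinfer_trtllm_moe and fp4_enabled:
            return "FlashInferFP4MoE"
        elif use_flashinfer_trtllm_moe:
            return "FlashInferFusedMoE"
        return "FusedMoE"

    assert pick_impl(True, True) == "FlashInferFP4MoE"
    assert pick_impl(True, False) == "FlashInferFusedMoE"
    assert pick_impl(False, True) == "FusedMoE"  # FP4 alone does not change the class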