sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (84)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +4 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  16. sglang/srt/function_call/ebnf_composer.py +10 -3
  17. sglang/srt/function_call/function_call_parser.py +2 -0
  18. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  19. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  20. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  21. sglang/srt/layers/attention/vision.py +56 -8
  22. sglang/srt/layers/layernorm.py +26 -1
  23. sglang/srt/layers/logits_processor.py +14 -3
  24. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  27. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  28. sglang/srt/layers/moe/topk.py +84 -22
  29. sglang/srt/layers/multimodal.py +11 -8
  30. sglang/srt/layers/quantization/fp8.py +25 -247
  31. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  32. sglang/srt/layers/quantization/modelopt_quant.py +25 -10
  33. sglang/srt/layers/quantization/unquant.py +24 -76
  34. sglang/srt/layers/quantization/w4afp8.py +68 -17
  35. sglang/srt/lora/lora_registry.py +93 -29
  36. sglang/srt/managers/cache_controller.py +9 -7
  37. sglang/srt/managers/mm_utils.py +154 -35
  38. sglang/srt/managers/multimodal_processor.py +3 -14
  39. sglang/srt/managers/schedule_batch.py +14 -8
  40. sglang/srt/managers/scheduler.py +35 -1
  41. sglang/srt/managers/tokenizer_manager.py +37 -6
  42. sglang/srt/managers/tp_worker.py +3 -0
  43. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  44. sglang/srt/model_executor/model_runner.py +68 -14
  45. sglang/srt/models/deepseek_v2.py +62 -28
  46. sglang/srt/models/glm4_moe.py +1035 -0
  47. sglang/srt/models/glm4_moe_nextn.py +167 -0
  48. sglang/srt/models/interns1.py +328 -0
  49. sglang/srt/models/internvl.py +143 -47
  50. sglang/srt/models/llava.py +9 -5
  51. sglang/srt/models/minicpmo.py +4 -1
  52. sglang/srt/models/qwen2_moe.py +2 -2
  53. sglang/srt/models/qwen3_moe.py +5 -2
  54. sglang/srt/multimodal/processors/base_processor.py +20 -6
  55. sglang/srt/multimodal/processors/clip.py +2 -2
  56. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  57. sglang/srt/multimodal/processors/gemma3.py +2 -2
  58. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  59. sglang/srt/multimodal/processors/internvl.py +21 -8
  60. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  61. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  62. sglang/srt/multimodal/processors/llava.py +4 -4
  63. sglang/srt/multimodal/processors/minicpm.py +2 -3
  64. sglang/srt/multimodal/processors/mlama.py +2 -2
  65. sglang/srt/multimodal/processors/mllama4.py +18 -111
  66. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  67. sglang/srt/multimodal/processors/pixtral.py +2 -2
  68. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  69. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  70. sglang/srt/multimodal/processors/vila.py +3 -1
  71. sglang/srt/reasoning_parser.py +2 -1
  72. sglang/srt/server_args.py +57 -6
  73. sglang/srt/utils.py +96 -1
  74. sglang/srt/weight_sync/utils.py +119 -0
  75. sglang/test/runners.py +4 -0
  76. sglang/test/test_utils.py +65 -5
  77. sglang/utils.py +19 -0
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +4 -4
  80. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +83 -73
  81. sglang/srt/debug_utils.py +0 -74
  82. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py
@@ -30,13 +30,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
-from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
+from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -47,23 +47,33 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
+    next_power_of_2,
 )

 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+use_flashinfer_trtllm_moe = (
+    global_server_args_dict["enable_flashinfer_trtllm_moe"]
+    and global_server_args_dict["enable_ep_moe"]
+)

 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul

-    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
-
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight

+if use_flashinfer_trtllm_moe:
+    try:
+        import flashinfer.fused_moe as fi_fused_moe
+    except ImportError:
+        fi_fused_moe = None
+        use_flashinfer_trtllm_moe = False
+
 logger = logging.getLogger(__name__)


@@ -140,7 +150,17 @@ class GroupedGemmRunner(torch.nn.Module):
         return c


-class EPMoE(torch.nn.Module):
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
+class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl

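The new `_get_tile_tokens_dim` helper sizes the CTA tile from the expected per-expert token count. Below is a minimal standalone sketch of that arithmetic; `next_power_of_2` is reimplemented locally purely for illustration, since the real helper is imported from `sglang.srt.utils`.

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (treat n <= 1 as 1); stand-in for sglang.srt.utils.next_power_of_2.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Expected tokens per expert under perfectly balanced routing.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Round up to a power of two, then clamp to the kernel-supported 8..64 range.
    return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)

# 512 tokens routed top-8 over 256 experts -> 16 tokens per expert -> tile of 16.
assert get_tile_tokens_dim(512, 8, 256) == 16
# Tiny batches floor at 8; very large ones cap at 64.
assert get_tile_tokens_dim(1, 8, 256) == 8
assert get_tile_tokens_dim(65536, 8, 256) == 64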
@@ -162,51 +182,60 @@ class EPMoE(torch.nn.Module):
         routed_scaling_factor: Optional[float] = None,
         use_per_token_if_dynamic: bool = True,
     ):
-        super().__init__()
+        super().__init__(
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            top_k=top_k,
+            layer_id=layer_id,
+            params_dtype=params_dtype,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            prefix=prefix,
+            activation=activation,
+            routed_scaling_factor=routed_scaling_factor,
+            enable_ep_moe=True,
+            skip_quant=True,
+        )

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

-        self.tp_size = (
-            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
-        )
-        self.tp_rank = get_tensor_model_parallel_rank()
-
         self.layer_id = layer_id
-        self.num_experts = num_experts
-        assert self.num_experts % self.tp_size == 0
-        self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
-        self.start_expert_id = self.tp_rank * self.num_experts_per_partition
-        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
+        self.num_local_experts, self.expert_map = self.determine_expert_map()
+        self.start_expert_id = self.ep_rank * self.num_local_experts
+        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1

-        self.top_k = top_k
         self.intermediate_size = intermediate_size
-        self.activation = activation
-        self.routed_scaling_factor = routed_scaling_factor
         self.use_per_token_if_dynamic = use_per_token_if_dynamic

+        # TODO(ch-wan): move quant preparation to FusedMoE
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
+            self.quant_method: Optional[QuantizeMethodBase] = (
+                UnquantizedFusedMoEMethod()
+            )
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
-            self.use_w4afp8 = False
+            self.w13_input_scale = None
+            self.w2_input_scale = None
+            self.w13_weight_scale = None
+            self.w2_weight_scale = None
         elif isinstance(quant_config, W4AFp8Config):
             self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
                 quant_config
             )
-            self.use_w4afp8 = True
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.fp8_dtype = torch.float8_e4m3fn
+            self.w13_input_scale = None
+            self.w2_input_scale = None
             self.w13_weight_scale = None
             self.w2_weight_scale = None
             self.activation_scheme = quant_config.moe_activation_scheme
-        else:
-            self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
-                quant_config
-            )
+        elif isinstance(quant_config, Fp8Config):
+            self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
             self.use_fp8_w8a8 = True
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
@@ -216,11 +245,13 @@ class EPMoE(torch.nn.Module):
             )
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
-            self.use_w4afp8 = False
+        else:
+            raise ValueError(f"Unsupported quant_config: {quant_config}")

+        self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
-            num_experts_per_partition=self.num_experts_per_partition,
+            num_experts=self.num_local_experts,
             hidden_size=hidden_size,
             intermediate_size=self.intermediate_size,
             params_dtype=params_dtype,
@@ -229,19 +260,6 @@ class EPMoE(torch.nn.Module):

         self.grouped_gemm_runner = None

-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
-
     # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
     # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
     def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
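With `EPMoE` now deriving from `FusedMoE`, the quantization method is selected strictly by config type, and an unrecognized config raises instead of silently taking the FP8 path. A toy sketch of that dispatch shape, using stand-in classes rather than the real sglang config types:

class W4AFp8Config: ...
class Fp8Config: ...

def pick_quant_method(quant_config):
    # Mirrors the branch order in the constructor above:
    # unquantized, W4A FP8, plain FP8, otherwise fail fast.
    if quant_config is None:
        return "UnquantizedFusedMoEMethod"
    if isinstance(quant_config, W4AFp8Config):
        return "W4AFp8MoEMethod"
    if isinstance(quant_config, Fp8Config):
        return "Fp8MoEMethod"
    raise ValueError(f"Unsupported quant_config: {quant_config}")

assert pick_quant_method(None) == "UnquantizedFusedMoEMethod"
assert pick_quant_method(Fp8Config()) == "Fp8MoEMethod"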
@@ -260,8 +278,8 @@ class EPMoE(torch.nn.Module):
                 Contains global_num_experts for experts not assigned to the current rank.
                 Returns None if ep_size is 1.
         """
-        ep_size = self.tp_size
-        ep_rank = self.tp_rank
+        ep_size = self.ep_size
+        ep_rank = self.ep_rank
         global_num_experts = self.num_experts

         assert ep_size > 0
@@ -271,7 +289,7 @@ class EPMoE(torch.nn.Module):
         local_num_experts = global_num_experts // ep_size

         expert_map = torch.full(
-            (global_num_experts,), self.num_experts, dtype=torch.int32
+            (global_num_experts,), global_num_experts, dtype=torch.int32
         )
         if ep_rank < (ep_size - 1):
             expert_map[
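For orientation, a self-contained toy version of the expert map built by `determine_expert_map`: slots owned by the current rank are renumbered 0..local-1, and every other slot keeps the sentinel `global_num_experts`. This sketch assumes an even split; the real method also lets the last rank absorb any remainder.

import torch

def toy_expert_map(global_num_experts: int, ep_size: int, ep_rank: int):
    # Even split assumed for this sketch; the real code handles remainders on the last rank.
    local_num_experts = global_num_experts // ep_size
    expert_map = torch.full(
        (global_num_experts,), global_num_experts, dtype=torch.int32
    )
    start = ep_rank * local_num_experts
    expert_map[start : start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32
    )
    return local_num_experts, expert_map

# 8 experts over 4 EP ranks, seen from rank 1:
# prints tensor([8, 8, 0, 1, 8, 8, 8, 8], dtype=torch.int32)
print(toy_expert_map(8, 4, 1)[1])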
@@ -296,6 +314,20 @@ class EPMoE(torch.nn.Module):
         hidden_states: torch.Tensor,
         topk_output: TopKOutput,
     ):
+
+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
+
         assert self.quant_method is not None
         assert self.activation == "silu"
         hidden_states_shape = hidden_states.shape
@@ -435,7 +467,10 @@ class EPMoE(torch.nn.Module):
         return output

     def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        assert self.quant_method is not None
+        return self.quant_method.apply(self, hidden_states, topk_output)
+
+    def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
+
         topk_weights, topk_ids, _ = topk_output

         hidden_states_shape = hidden_states.shape
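The refactor above splits the old `forward_normal` in two: the layer now defers to its quant method's `apply`, while the original kernel body lives in `run_moe`. The exact wiring inside the quant method classes is not shown in this diff; the toy classes below only sketch the call shape under that assumption.

class ToyQuantMethod:
    def apply(self, layer, hidden_states, topk_output):
        # A real quant method would quantize/prepare inputs here first,
        # then hand control back to the layer's kernel path.
        return layer.run_moe(hidden_states, topk_output)

class ToyEPMoE:
    def __init__(self):
        self.quant_method = ToyQuantMethod()

    def forward_normal(self, hidden_states, topk_output):
        return self.quant_method.apply(self, hidden_states, topk_output)

    def run_moe(self, hidden_states, topk_output):
        return f"ran MoE on {hidden_states!r} with {topk_output!r}"

print(ToyEPMoE().forward_normal("h", "topk"))  # -> ran MoE on 'h' with 'topk'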
@@ -448,53 +483,11 @@ class EPMoE(torch.nn.Module):
             use_per_token_if_dynamic=self.use_per_token_if_dynamic,
         )

-        if self.use_w4afp8:
-            local_topk_ids = topk_ids
-            if self.expert_map is not None:
-                "Translate info from expert_map to topk_ids"
-                local_topk_ids = torch.where(
-                    self.expert_map[topk_ids] != self.num_experts,
-                    self.expert_map[topk_ids],
-                    self.num_experts,
-                )
-
-            output = cutlass_w4a8_moe(
-                self.start_expert_id,
-                self.end_expert_id,
-                self.num_experts,
-                hidden_states,
-                self.w13_weight,
-                self.w2_weight,
-                self.w13_weight_scale_inv,
-                self.w2_weight_scale_inv,
-                topk_weights,
-                topk_ids,
-                local_topk_ids,
-                self.quant_method.a_strides1,
-                self.quant_method.b_strides1,
-                self.quant_method.c_strides1,
-                self.quant_method.a_strides2,
-                self.quant_method.b_strides2,
-                self.quant_method.c_strides2,
-                self.quant_method.s_strides13,
-                self.quant_method.s_strides2,
-                self.quant_method.expert_offsets,
-                self.quant_method.problem_sizes1,
-                self.quant_method.problem_sizes2,
-                self.w13_input_scale,
-                self.w2_input_scale,
-            )
-            return output
-
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device,
-                use_flashinfer=False,  # TODO: use flashinfer
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
+        num_experts = self.num_experts

         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
-            topk_ids, self.num_experts
+            topk_ids,
+            num_experts,
         )

         gateup_input = torch.empty(
@@ -502,7 +495,7 @@ class EPMoE(torch.nn.Module):
             device=hidden_states.device,
             dtype=(
                 self.fp8_dtype
-                if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
+                if self.use_fp8_w8a8 and not self.use_block_quant
                 else hidden_states.dtype
             ),
         )
@@ -513,7 +506,7 @@ class EPMoE(torch.nn.Module):
         else:
             max_value = (
                 torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
+                .repeat(self.num_local_experts)
                 .to(torch.float32)
             )
             self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
@@ -554,7 +547,7 @@ class EPMoE(torch.nn.Module):
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
-            self.num_experts_per_partition,
+            self.num_local_experts,
             device=hidden_states_device,
             dtype=torch.int64,
         )
@@ -564,17 +557,13 @@ class EPMoE(torch.nn.Module):
             b=self.w13_weight,
             c=None,
             c_dtype=hidden_states_dtype,
-            batch_size=self.num_experts_per_partition,
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
             weight_indices=weight_indices_cur_rank,
             use_fp8_w8a8=self.use_fp8_w8a8,
             scale_a=self.w13_input_scale,
-            scale_b=(
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
+            scale_b=self.w13_weight_scale,
             block_shape=self.block_shape,
         )
         del gateup_input
@@ -631,7 +620,7 @@ class EPMoE(torch.nn.Module):
             down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
         else:
             self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
+                self.num_local_experts,
                 dtype=torch.float32,
                 device=hidden_states_device,
             )
@@ -647,17 +636,13 @@ class EPMoE(torch.nn.Module):
             a=down_input,
             b=self.w2_weight,
             c=down_output,
-            batch_size=self.num_experts_per_partition,
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
             weight_indices=weight_indices_cur_rank,
             use_fp8_w8a8=self.use_fp8_w8a8,
             scale_a=self.w2_input_scale,
-            scale_b=(
-                self.w2_weight_scale_inv
-                if self.use_block_quant
-                else self.w2_weight_scale
-            ),
+            scale_b=self.w2_weight_scale,
             block_shape=self.block_shape,
         )
         del down_input
@@ -760,95 +745,14 @@ class EPMoE(torch.nn.Module):
             return
         expert_id = expert_id - self.start_expert_id

-        if shard_id not in ("w1", "w2", "w3"):
-            raise ValueError(
-                f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
-            )
-
-        # Special case for fp8 scales.
-        if "scale" in weight_name:
-            self._load_fp8_scale(
-                param.data,
-                loaded_weight,
-                weight_name,
-                shard_id,
-                expert_id,
-            )
-            return
-
-        if shard_id == "w2":
-            param.data[expert_id] = loaded_weight
-        elif shard_id == "w1":
-            param.data[expert_id][: self.intermediate_size, :] = loaded_weight
-        elif shard_id == "w3":
-            param.data[expert_id][self.intermediate_size :, :] = loaded_weight
-        else:
-            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")
-
-    def _load_fp8_scale(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        param_data = param.data
-
-        # Input scales can be loaded directly and should be equal.
-        if "input_scale" in weight_name:
-            if self.use_w4afp8:
-                if shard_id == "w1":
-                    param_data[expert_id][0] = loaded_weight
-                elif shard_id == "w3":
-                    param_data[expert_id][1] = loaded_weight
-                else:
-                    param_data[expert_id] = loaded_weight
-                return
-
-            if (
-                (shard_id == "w1" or shard_id == "w3")
-                and param_data[expert_id] != 1
-                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
-            ):
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}"
-                )
-            param_data[expert_id] = loaded_weight
-        # Weight scales
-        elif "weight_scale" in weight_name:
-            if self.use_block_quant:
-                block_n, block_k = self.block_shape[0], self.block_shape[1]
-                if shard_id == "w1":
-                    param_data[expert_id][
-                        : (self.intermediate_size + block_n - 1) // block_n, :
-                    ] = loaded_weight
-                elif shard_id == "w3":
-                    param_data[expert_id][
-                        (self.intermediate_size + block_n - 1) // block_n :, :
-                    ] = loaded_weight
-                else:  # w2
-                    param_data[expert_id] = loaded_weight
-            elif self.use_w4afp8:
-                if shard_id == "w1":
-                    param_data[expert_id][: self.intermediate_size, :] = loaded_weight
-                elif shard_id == "w3":
-                    param_data[expert_id][self.intermediate_size :, :] = loaded_weight
-                else:
-                    param_data[expert_id] = loaded_weight
-        # If we are in merged column case (gate_up_proj)
-        else:
-            if shard_id in ("w1", "w3"):
-                # We have to keep the weight scales of w1 and w3 because
-                # we need to re-quantize w1/w3 weights after weight loading.
-                idx = 0 if shard_id == "w1" else 1
-                param_data[expert_id][idx] = loaded_weight
-
-            # If we are in the row parallel case (down_proj)
-            else:
-                param_data[expert_id] = loaded_weight
+        self._weight_loader_impl(
+            param=param,
+            loaded_weight=loaded_weight,
+            weight_name=weight_name,
+            shard_id=shard_id,
+            expert_id=expert_id,
+        )
+        return


 class DeepEPMoE(EPMoE):
@@ -898,13 +802,13 @@ class DeepEPMoE(EPMoE):
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
         if _use_aiter:
-            # expert_mask is of size (self.num_experts_per_partition + 1),
+            # expert_mask is of size (self.num_local_experts + 1),
             # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
             # for instance, if we have 4 experts on this rank, we would have a expert_mask like:
             #     self.expert_mask = [1, 1, 1, 1, 0]
             # idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
             self.expert_mask = torch.zeros(
-                (self.num_experts_per_partition + 1),
+                (self.num_local_experts + 1),
                 device=torch.cuda.current_device(),
                 dtype=torch.int,
             )
@@ -977,13 +881,13 @@ class DeepEPMoE(EPMoE):
         if self.activation_scheme == "dynamic" and not self.use_block_quant:
             max_value = (
                 torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
+                .repeat(self.num_local_experts)
                 .to(torch.float32)
             )
             self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
         weight_indices_cur_rank = torch.arange(
             0,
-            self.num_experts_per_partition,
+            self.num_local_experts,
             device=hidden_states.device,
             dtype=torch.int64,
         )
@@ -995,7 +899,7 @@ class DeepEPMoE(EPMoE):
             b=self.w13_weight,
             c=None,
             c_dtype=hidden_states.dtype,
-            batch_size=self.num_experts_per_partition,
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr,
             weight_indices=weight_indices_cur_rank,
@@ -1029,7 +933,7 @@ class DeepEPMoE(EPMoE):
             )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
+                self.num_local_experts,
                 dtype=torch.float32,
                 device=hidden_states_device,
             )
@@ -1042,7 +946,7 @@ class DeepEPMoE(EPMoE):
                 reorder_topk_ids,
                 self.w2_input_scale,
                 0,
-                self.num_experts_per_partition - 1,
+                self.num_local_experts - 1,
                 BLOCK_SIZE=512,
             )
         else:
@@ -1062,7 +966,7 @@ class DeepEPMoE(EPMoE):
             a=down_input,
             b=self.w2_weight,
             c=down_output,
-            batch_size=self.num_experts_per_partition,
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr,
             weight_indices=weight_indices_cur_rank,
@@ -1087,9 +991,9 @@ class DeepEPMoE(EPMoE):
             return hidden_states
         # in original deepep, idx == -1 meaning invalid and will not be processed.
        # aiter does not accept -1, we use a expert mask to make these idx invalid
-        # (idx == num_experts_per_partition) meaning not used in aiter fused_moe
+        # (idx == num_local_experts) meaning not used in aiter fused_moe
         topk_idx_copy = topk_idx.to(torch.int32)
-        topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
+        topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts

         return fused_moe(
             hidden_states,
@@ -1315,12 +1219,74 @@ class DeepEPMoE(EPMoE):
         return down_output


+class FlashInferEPMoE(EPMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+        self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        assert fi_fused_moe is not None
+        return fi_fused_moe.trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=self.w13_weight,
+            gemm1_weights_scale=self.w13_weight_scale_inv,
+            gemm2_weights=self.w2_weight,
+            gemm2_weights_scale=self.w2_weight_scale_inv,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.w2_weight.shape[2],
+            local_expert_offset=self.start_expert_id,
+            local_num_experts=self.num_experts_per_partition,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
+
+
 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
-    if global_server_args_dict["enable_flashinfer_moe"]:
+    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
+    if use_flashinfer_trtllm_moe:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FlashInferEPMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
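The selection order in `get_moe_impl_class` after this change: DeepEP first, then the FlashInfer CUTLASS path, then the new FlashInfer TRT-LLM path (active only when both `enable_flashinfer_trtllm_moe` and `enable_ep_moe` are set and the `flashinfer.fused_moe` import above succeeded), then `EPMoE`, and finally `FusedMoE`. A small sketch of that precedence, with a hypothetical flags dict standing in for `global_server_args_dict`:

def pick_moe_impl(flags: dict, flashinfer_available: bool = True) -> str:
    # Mirrors the precedence in get_moe_impl_class above; returns class names
    # as strings so the sketch stays self-contained.
    if flags.get("enable_deepep_moe"):
        return "DeepEPMoE"
    if flags.get("enable_flashinfer_cutlass_moe"):
        return "FusedMoE"
    if (
        flags.get("enable_flashinfer_trtllm_moe")
        and flags.get("enable_ep_moe")
        and flashinfer_available
    ):
        return "FlashInferEPMoE"
    if flags.get("enable_ep_moe"):
        return "EPMoE"
    return "FusedMoE"

assert pick_moe_impl({"enable_ep_moe": True}) == "EPMoE"
assert (
    pick_moe_impl({"enable_ep_moe": True, "enable_flashinfer_trtllm_moe": True})
    == "FlashInferEPMoE"
)
assert pick_moe_impl({}) == "FusedMoE"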