sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
@@ -413,18 +413,37 @@ def fused_moe_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens
 
-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (
         offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
     )
 
-    off_experts = tl.load(expert_ids_ptr + pid_m)
     b_ptrs = (
         b_ptr
         + off_experts * stride_be
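
Note on the hunk above: after moe_align_block_size, blocks routed to experts owned by another expert-parallel rank carry an expert id of -1, and the kernel now zero-fills those output rows via write_zeros_to_output instead of computing with weights it does not hold (the added .to(tl.int64) casts keep the pointer offsets in 64-bit arithmetic). A minimal pure-PyTorch sketch of the same rule, illustrative only; the helper name and shapes below are assumptions, not the kernel API:

import torch

def moe_rows_with_filtered_experts(x, expert_ids, expert_weights):
    # x: [num_tokens, hidden]; expert_ids: [num_tokens], where -1 marks experts
    # that live on another expert-parallel rank;
    # expert_weights: [num_experts, hidden, hidden]
    out = torch.zeros_like(x)
    for e in expert_ids[expert_ids >= 0].unique():
        rows = (expert_ids == e).nonzero(as_tuple=True)[0]
        out[rows] = x[rows] @ expert_weights[e]
    # Rows routed to non-local experts keep their zero output, which is what
    # write_zeros_to_output does block-wise inside the Triton kernel.
    return out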
@@ -497,7 +516,6 @@ def fused_moe_kernel(
 
                 accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
             else:
-                # fix out of shared memory issue
                 if use_fp8_w8a8:
                     accumulator = tl.dot(a, b, acc=accumulator)
                 else:
@@ -568,7 +586,7 @@ def moe_align_block_size(
     - The padding ensures that the total number of tokens is now divisible
       by block_size for proper block matrix operations.
     """
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
@@ -578,13 +596,9 @@ def moe_align_block_size(
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
 
+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
     cumsum_buffer = torch.empty(
-        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-    )
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts,
-        dtype=torch.int32,
-        device=topk_ids.device,
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
     )
 
     # Threshold based on benchmark results
@@ -594,12 +608,11 @@ def moe_align_block_size(
 
     sgl_moe_align_block_size(
         topk_ids,
-        num_experts,
+        num_experts + 1,
         block_size,
         sorted_ids,
         expert_ids,
         num_tokens_post_pad,
-        token_cnts_buffer,
         cumsum_buffer,
         fuse_sorted_ids_padding,
     )
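
Note on the moe_align_block_size hunks above: tokens routed to filtered (-1) experts now form one extra bucket that is also padded up to a block boundary, so the worst-case padding gains one more (block_size - 1) term and the cumsum buffer gains one more slot. A small worked example of the bound, with illustrative values:

# topk_ids.numel() == num_tokens * top_k
num_tokens, top_k, num_experts, block_size = 8, 2, 4, 16
numel = num_tokens * top_k                                    # 16
old_bound = numel + num_experts * (block_size - 1)            # 16 + 4 * 15 = 76
# The extra "+ 1" covers the bucket of tokens whose expert was filtered out
# by expert parallelism (expert id -1).
new_bound = numel + (num_experts + 1) * (block_size - 1)      # 16 + 5 * 15 = 91
print(old_bound, new_bound)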
@@ -1,17 +1,25 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
+import importlib.util
 import logging
 from enum import Enum
+from functools import lru_cache
 from typing import List, Optional, Tuple
 
 import torch
+from packaging import version as pkg_version
 
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_rank,
+    get_moe_expert_parallel_world_size,
+    get_moe_tensor_parallel_rank,
+    get_moe_tensor_parallel_world_size,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -28,6 +36,15 @@ _is_cpu = is_cpu()
 logger = logging.getLogger(__name__)
 
 
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
@@ -62,8 +79,9 @@ class FusedMoE(torch.nn.Module):
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         top_k: Optional[int] = None,
-        layer_id: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
@@ -77,21 +95,19 @@ class FusedMoE(torch.nn.Module):
         routed_scaling_factor: Optional[float] = None,
         enable_flashinfer_cutlass_moe: Optional[bool] = False,
         enable_ep_moe: Optional[bool] = False,
-        skip_quant: Optional[bool] = False,
     ):
         super().__init__()
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
+        self.layer_id = layer_id
         self.top_k = top_k
         self.hidden_size = hidden_size
-        self.tp_size = (
-            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
-        )
-        self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
-        self.expert_map = None
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.expert_map_cpu = None
+        self.expert_map_gpu = None
 
         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
@@ -99,28 +115,28 @@ class FusedMoE(torch.nn.Module):
             enable_ep_moe = False
 
         self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
+        self.moe_ep_size = get_moe_expert_parallel_world_size()
+        self.moe_ep_rank = get_moe_expert_parallel_rank()
+        self.moe_tp_size = get_moe_tensor_parallel_world_size()
+        self.moe_tp_rank = get_moe_tensor_parallel_rank()
+        assert num_experts % self.moe_ep_size == 0
+        self.num_local_experts = num_experts // self.moe_ep_size
         if enable_ep_moe:
-            self.ep_size = self.tp_size
-            self.ep_rank = self.tp_rank
-            self.tp_size = 1
-            self.tp_rank = 0
+            # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
-            self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
+            self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
             # Create a expert map for the local experts
-            assert num_experts % self.ep_size == 0
-            self.num_local_experts = num_experts // self.ep_size
-            self.expert_map[
-                self.ep_rank
-                * self.num_local_experts : (self.ep_rank + 1)
+            self.expert_map_cpu[
+                self.moe_ep_rank
+                * self.num_local_experts : (self.moe_ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-        else:
-            self.ep_size = 1
-            self.ep_rank = 0
-            self.num_local_experts = num_experts
+            if not self.enable_flashinfer_cutlass_moe:
+                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
+
         self.routed_scaling_factor = routed_scaling_factor
-        assert intermediate_size % self.tp_size == 0
-        self.intermediate_size_per_partition = intermediate_size // self.tp_size
+        assert intermediate_size % self.moe_tp_size == 0
+        self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
        self.reduce_results = reduce_results
        self.activation = activation
        self.apply_router_weight_on_input = apply_router_weight_on_input
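
Note on the hunk above: the expert map translates global expert ids into local slots on the owning expert-parallel rank and -1 everywhere else; the CPU copy drives weight loading, while the CUDA copy lets forward() remap topk_ids with a single gather. A standalone sketch with toy values (tensor names follow the diff, the surrounding plumbing is assumed):

import torch

num_experts, moe_ep_size, moe_ep_rank = 8, 2, 1
num_local_experts = num_experts // moe_ep_size                # 4

expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map[
    moe_ep_rank * num_local_experts : (moe_ep_rank + 1) * num_local_experts
] = torch.arange(0, num_local_experts, dtype=torch.int32)
# rank 1 owns global experts 4..7 -> tensor([-1, -1, -1, -1, 0, 1, 2, 3])

topk_ids = torch.tensor([[0, 5], [4, 7]])
local_topk_ids = expert_map[topk_ids]
# tensor([[-1, 1], [0, 3]]): non-local experts become -1 and are zero-filled
# or skipped by the fused kernel.
print(expert_map, local_topk_ids)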
@@ -132,9 +148,6 @@ class FusedMoE(torch.nn.Module):
             not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
         )
 
-        if skip_quant:
-            return
-
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                 self.use_triton_kernels
@@ -363,9 +376,9 @@ class FusedMoE(torch.nn.Module):
             expert_data.copy_(loaded_weight)
 
     def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
-        if self.expert_map is None:
+        if self.expert_map_cpu is None:
             return expert_id
-        return self.expert_map[expert_id].item()
+        return self.expert_map_cpu[expert_id].item()
 
     def weight_loader(
         self,
@@ -375,10 +388,48 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+
+        global_expert_location_metadata = get_global_expert_location_metadata()
+        if global_expert_location_metadata is None:
+            self._weight_loader_impl(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=expert_id,
+            )
+            return
+
+        if expert_id >= self.num_experts - self.num_fused_shared_experts:
+            # This is a shared expert.
+            physical_expert_ids = [expert_id]
+        else:
+            physical_expert_ids = (
+                global_expert_location_metadata.logical_to_all_physical(
+                    self.layer_id, expert_id
+                )
+            )
+
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             return
-
         self._weight_loader_impl(
             param=param,
             loaded_weight=loaded_weight,
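
Note on the weight_loader hunk above: with EPLB metadata present, one logical expert may be replicated as several physical experts, and the loader writes the same checkpoint tensor into every replica (replicas on other EP ranks are skipped later via the -1 map). A toy sketch of that fan-out; the mapping dict below is hypothetical and stands in for logical_to_all_physical:

logical_to_physical = {0: [0, 5], 1: [1], 2: [2, 6, 7], 3: [3]}  # toy EPLB layout

def load_logical_expert(weights_by_physical_id, logical_id, weight):
    # Every physical replica of the logical expert gets an identical copy.
    for physical_id in logical_to_physical[logical_id]:
        weights_by_physical_id[physical_id] = weight

store = {}
load_logical_expert(store, 2, "w13_of_logical_expert_2")
print(store)  # {2: '...', 6: '...', 7: '...'}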
@@ -396,8 +447,7 @@ class FusedMoE(torch.nn.Module):
         expert_id: int,
     ) -> None:
 
-        # TP rank is set to 0 if EP is enabled
-        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+        tp_rank = self.moe_tp_rank
 
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
@@ -417,7 +467,7 @@ class FusedMoE(torch.nn.Module):
         )
 
         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if getattr(self, "use_flashinfer_trtllm_moe", False):
+        if should_use_flashinfer_trtllm_moe():
             shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
 
         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
@@ -571,9 +621,14 @@ class FusedMoE(torch.nn.Module):
            )
            return
 
-    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
+    def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
         assert self.quant_method is not None
 
+        if self.expert_map_gpu is not None:
+            topk_output = topk_output._replace(
+                topk_ids=self.expert_map_gpu[topk_output.topk_ids]
+            )
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
@@ -584,17 +639,17 @@ class FusedMoE(torch.nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
             **(
                 dict(
-                    tp_rank=self.tp_rank,
-                    tp_size=self.tp_size,
-                    ep_rank=self.ep_rank,
-                    ep_size=self.ep_size,
+                    tp_rank=self.moe_tp_rank,
+                    tp_size=self.moe_tp_size,
+                    ep_rank=self.moe_ep_rank,
+                    ep_size=self.moe_ep_size,
                 )
                 if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
                 else {}
             ),
         )
 
-        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states
@@ -627,3 +682,61 @@ class FusedMoE(torch.nn.Module):
                 ("w3", ckpt_up_proj_name),
             ]
         ]
+
+    @classmethod
+    def make_expert_input_scale_params_mapping(
+        cls,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+        # (param_name, weight_name, expert_id, shard_id)
+        return [
+            (
+                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
+                f"experts.{expert_id}.{shard_id}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id in ["w1", "w2", "w3"]
+        ]
+
+
+class FlashInferFusedMoE(FusedMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply_with_router_logits(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            activation=self.activation,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
+
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
     is_hip,
     is_npu,
     log_info_on_rank0,
+    next_power_of_2,
     print_warning_once,
     set_weight_attrs,
     use_intel_amx_backend,
@@ -172,7 +173,6 @@ class Fp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
@@ -181,8 +181,6 @@ class Fp8Config(QuantizationConfig):
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
-        elif isinstance(layer, EPMoE):
-            return Fp8EPMoEMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -493,6 +491,16 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
 
+def get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
 class Fp8MoEMethod(FusedMoEMethodBase):
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
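
Note on get_tile_tokens_dim above: it estimates tokens per expert under a uniform-routing assumption, rounds up to the next power of two, and clamps to the 8-64 range the kernel supports. A worked example; next_power_of_2 is re-implemented below only to keep the snippet self-contained (the diff imports it from sglang.srt.utils):

def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def tile_tokens_dim(num_tokens, top_k, num_experts):
    per_expert = (num_tokens * top_k) // num_experts   # assume uniform routing
    return min(max(next_power_of_2(per_expert), 8), 64)

print(tile_tokens_dim(4, 8, 256))     # 0 tokens/expert -> clamped up to 8
print(tile_tokens_dim(512, 8, 256))   # 16 tokens/expert -> 16
print(tile_tokens_dim(8192, 8, 256))  # 256 tokens/expert -> capped at 64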
@@ -984,23 +992,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
-        if isinstance(layer, EPMoE):
-            layer.w13_weight_scale = (
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            )
-            layer.w2_weight_scale = (
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            )
-            return layer.run_moe(
-                hidden_states=x,
-                topk_output=topk_output,
-            )
-
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
@@ -1094,6 +1087,47 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             routed_scaling_factor=routed_scaling_factor,
         )
 
+    def apply_with_router_logits(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        assert (
+            activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=layer.correction_bias.to(x.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=layer.w13_weight,
+            gemm1_weights_scale=layer.w13_weight_scale_inv,
+            gemm2_weights=layer.w2_weight,
+            gemm2_weights_scale=layer.w2_weight_scale_inv,
+            num_experts=layer.num_experts,
+            top_k=layer.top_k,
+            n_group=layer.num_expert_group,
+            topk_group=layer.topk_group,
+            intermediate_size=layer.w2_weight.shape[2],
+            local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
+            local_num_experts=layer.num_local_experts,
+            routed_scaling_factor=routed_scaling_factor,
+            tile_tokens_dim=get_tile_tokens_dim(
+                x.shape[0], layer.top_k, layer.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
+
     def maybe_apply_hip_fused_experts(
         self,
         layer: torch.nn.Module,
@@ -204,14 +204,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
 
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-
-        if isinstance(layer, EPMoE):
-            return layer.run_moe(
-                hidden_states=x,
-                topk_output=topk_output,
-            )
-
         return self.forward(
             x=x,
             layer=layer,
@@ -276,6 +276,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer: EPMoE,
         hidden_states: torch.Tensor,
         topk_output: TopKOutput,
+        **kwargs,
     ) -> torch.Tensor:
 
         # TODO(ch-wan): move it out of this class
@@ -231,7 +231,10 @@ class W8A8Int8Config(QuantizationConfig):
 
     @classmethod
     def get_config_filenames(cls) -> List[str]:
-        return []
+        filenames = []
+        if _is_npu:
+            filenames.append("quant_model_description.json")
+        return filenames
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config: