sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/conversation.py +0 -112
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  7. sglang/srt/disaggregation/launch_lb.py +5 -20
  8. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  9. sglang/srt/disaggregation/prefill.py +1 -0
  10. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  11. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  12. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  13. sglang/srt/distributed/parallel_state.py +11 -0
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +35 -15
  16. sglang/srt/eplb/expert_distribution.py +4 -2
  17. sglang/srt/hf_transformers_utils.py +25 -10
  18. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  19. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  20. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  21. sglang/srt/layers/attention/utils.py +6 -1
  22. sglang/srt/layers/attention/vision.py +27 -10
  23. sglang/srt/layers/communicator.py +14 -4
  24. sglang/srt/layers/linear.py +7 -1
  25. sglang/srt/layers/logits_processor.py +9 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +29 -68
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
  29. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  30. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  31. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  32. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  33. sglang/srt/layers/moe/utils.py +43 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  35. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  36. sglang/srt/layers/quantization/fp8.py +57 -1
  37. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  38. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  39. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  40. sglang/srt/lora/lora_registry.py +7 -0
  41. sglang/srt/managers/cache_controller.py +43 -39
  42. sglang/srt/managers/data_parallel_controller.py +52 -2
  43. sglang/srt/managers/io_struct.py +6 -1
  44. sglang/srt/managers/schedule_batch.py +3 -2
  45. sglang/srt/managers/schedule_policy.py +3 -1
  46. sglang/srt/managers/scheduler.py +145 -6
  47. sglang/srt/managers/template_manager.py +25 -22
  48. sglang/srt/managers/tokenizer_manager.py +114 -62
  49. sglang/srt/managers/utils.py +45 -1
  50. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  51. sglang/srt/mem_cache/hicache_storage.py +13 -12
  52. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  53. sglang/srt/mem_cache/memory_pool.py +15 -118
  54. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  55. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  56. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  57. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  58. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
  59. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
  60. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
  61. sglang/srt/model_executor/cuda_graph_runner.py +42 -4
  62. sglang/srt/model_executor/forward_batch_info.py +13 -3
  63. sglang/srt/model_executor/model_runner.py +13 -1
  64. sglang/srt/model_loader/weight_utils.py +2 -0
  65. sglang/srt/models/deepseek_v2.py +28 -23
  66. sglang/srt/models/glm4_moe.py +85 -22
  67. sglang/srt/models/grok.py +3 -3
  68. sglang/srt/models/llama4.py +13 -2
  69. sglang/srt/models/mixtral.py +3 -3
  70. sglang/srt/models/mllama4.py +428 -19
  71. sglang/srt/models/qwen2_moe.py +1 -4
  72. sglang/srt/models/qwen3_moe.py +7 -8
  73. sglang/srt/models/step3_vl.py +1 -4
  74. sglang/srt/multimodal/processors/base_processor.py +4 -3
  75. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  76. sglang/srt/operations_strategy.py +1 -1
  77. sglang/srt/server_args.py +115 -21
  78. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  79. sglang/srt/two_batch_overlap.py +6 -4
  80. sglang/srt/utils.py +4 -24
  81. sglang/srt/weight_sync/utils.py +1 -1
  82. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  83. sglang/test/runners.py +2 -2
  84. sglang/test/test_utils.py +3 -3
  85. sglang/version.py +1 -1
  86. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  87. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
  88. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  89. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  90. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  91. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  92. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
@@ -1,20 +1,25 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

+import importlib.util
 import logging
 from enum import Enum
+from functools import lru_cache
 from typing import List, Optional, Tuple

 import torch
+from packaging import version as pkg_version

 from sglang.srt.distributed import (
     get_moe_expert_parallel_rank,
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import StandardTopKOutput
 from sglang.srt.layers.quantization.base_config import (
@@ -33,6 +38,15 @@ _is_cpu = is_cpu()
 logger = logging.getLogger(__name__)


+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
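The gate added above is wrapped in `lru_cache(maxsize=1)`, so the server flag and flashinfer version are checked once per process. For reference, this is how the `packaging` comparison used in it treats the 0.2.9rc1 cutoff (an illustrative snippet, not part of the diff):

from packaging import version as pkg_version

# 0.2.9rc1 is a pre-release of 0.2.9: a final 0.2.9 install satisfies the gate,
# while 0.2.8 does not.
assert pkg_version.parse("0.2.9") >= pkg_version.parse("0.2.9rc1")
assert not (pkg_version.parse("0.2.8") >= pkg_version.parse("0.2.9rc1"))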
@@ -82,7 +96,6 @@ class FusedMoE(torch.nn.Module):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
         enable_flashinfer_cutlass_moe: Optional[bool] = False,
-        enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()

@@ -100,7 +113,6 @@ class FusedMoE(torch.nn.Module):
         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
             enable_flashinfer_cutlass_moe = False
-            enable_ep_moe = False

         self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
         self.moe_ep_size = get_moe_expert_parallel_world_size()
@@ -109,7 +121,7 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
-        if enable_ep_moe:
+        if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
             self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
@@ -119,7 +131,8 @@ class FusedMoE(torch.nn.Module):
                 * self.num_local_experts : (self.moe_ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-            self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
+            if not self.enable_flashinfer_cutlass_moe:
+                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")

         self.routed_scaling_factor = routed_scaling_factor
         assert intermediate_size % self.moe_tp_size == 0
@@ -454,7 +467,7 @@ class FusedMoE(torch.nn.Module):
             )

         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if getattr(self, "use_flashinfer_trtllm_moe", False):
+        if should_use_flashinfer_trtllm_moe():
             shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]

         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
@@ -617,24 +630,27 @@ class FusedMoE(torch.nn.Module):
         )

         # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
-            layer=self,
-            x=hidden_states,
-            topk_output=topk_output,
-            activation=self.activation,
-            apply_router_weight_on_input=self.apply_router_weight_on_input,
-            routed_scaling_factor=self.routed_scaling_factor,
-            **(
-                dict(
-                    tp_rank=self.moe_tp_rank,
-                    tp_size=self.moe_tp_size,
-                    ep_rank=self.moe_ep_rank,
-                    ep_size=self.moe_ep_size,
-                )
-                if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
-                else {}
-            ),
-        )
+        with use_symmetric_memory(get_tp_group()) as sm:
+            final_hidden_states = self.quant_method.apply(
+                layer=self,
+                x=hidden_states,
+                topk_output=topk_output,
+                activation=self.activation,
+                apply_router_weight_on_input=self.apply_router_weight_on_input,
+                routed_scaling_factor=self.routed_scaling_factor,
+                **(
+                    dict(
+                        tp_rank=self.moe_tp_rank,
+                        tp_size=self.moe_tp_size,
+                        ep_rank=self.moe_ep_rank,
+                        ep_size=self.moe_ep_size,
+                    )
+                    if self.quant_method.__class__.__name__
+                    == "ModelOptNvFp4FusedMoEMethod"
+                    else {}
+                ),
+            )
+            sm.tag(final_hidden_states)

         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
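This hunk and the `VocabParallelEmbedding` hunk further below follow the same pattern: run the op that produces the all-reduce input inside the new `use_symmetric_memory` context manager, then tag the resulting tensor. A minimal sketch of that pattern, assuming the imports behave as shown in the first hunk; `compute_partial` is a hypothetical stand-in for the wrapped computation:

from sglang.srt.distributed import get_tp_group, tensor_model_parallel_all_reduce
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)


def reduce_with_symmetric_memory(compute_partial, *args):
    # Allocate the partial result while the symmetric-memory allocator is active,
    # then tag it so the buffer is registered for the following tensor-parallel
    # all-reduce.
    with use_symmetric_memory(get_tp_group()) as sm:
        partial = compute_partial(*args)
        sm.tag(partial)
    return tensor_model_parallel_all_reduce(partial)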
@@ -686,3 +702,44 @@ class FusedMoE(torch.nn.Module):
             for expert_id in range(num_experts)
             for shard_id in ["w1", "w2", "w3"]
         ]
+
+
+class FlashInferFusedMoE(FusedMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply_with_router_logits(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            activation=self.activation,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
+
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
@@ -146,34 +146,3 @@ def triton_kernel_fused_experts(
     )

     return intermediate_cache3
-
-
-def triton_kernel_moe_forward_fake(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    inplace: bool = False,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    use_fp8_w8a8: bool = False,
-    per_channel_quant: bool = False,
-    global_num_experts: int = -1,
-    expert_map: Optional[torch.Tensor] = None,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[list[int]] = None,
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
-    op_name="forward_cuda_triton",
-    op_func=triton_kernel_moe_forward,
-    mutates_args=[],
-    fake_impl=triton_kernel_moe_forward_fake,
-)
@@ -0,0 +1,23 @@
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    BaseDispatcher,
+    BaseDispatcherConfig,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
+from sglang.srt.layers.moe.token_dispatcher.deepep import (
+    DeepEPConfig,
+    DeepEPDispatcher,
+    DeepEPLLOutput,
+    DeepEPNormalOutput,
+)
+
+__all__ = [
+    "BaseDispatcher",
+    "BaseDispatcherConfig",
+    "DispatchOutput",
+    "DispatchOutputFormat",
+    "DeepEPConfig",
+    "DeepEPDispatcher",
+    "DeepEPNormalOutput",
+    "DeepEPLLOutput",
+]
@@ -2,11 +2,22 @@ from __future__ import annotations

 from abc import ABC, abstractmethod
 from enum import Enum, auto
-from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable

 import torch


+class MoEA2ABackend(Enum):
+    none = "none"
+    deepep = "deepep"
+
+    def is_none(self):
+        return self == MoEA2ABackend.none
+
+    def is_deepep(self):
+        return self == MoEA2ABackend.deepep
+
+
 class DispatchOutputFormat(Enum):
     standard = auto()
     deepep_normal = auto()
@@ -1,5 +1,3 @@
-# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
-
 from __future__ import annotations

 import logging
@@ -22,15 +20,10 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import (
-    DeepEPMode,
-    get_bool_env_var,
-    get_int_env_var,
-    is_hip,
-    load_json_config,
-)
+from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config

 try:
     from deep_ep import Buffer, Config
@@ -150,9 +143,9 @@ class DeepEPBuffer:
             num_rdma_bytes,
         )

-        if deepep_mode == DeepEPMode.normal:
+        if deepep_mode == DeepEPMode.NORMAL:
             num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
-        elif deepep_mode in [DeepEPMode.low_latency, DeepEPMode.auto]:
+        elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
             num_qps_per_rank = num_experts // group.size()
         else:
             raise NotImplementedError
@@ -161,7 +154,7 @@ class DeepEPBuffer:
             device="cuda"
         ).multi_processor_count
         if (
-            (deepep_mode != DeepEPMode.low_latency)
+            (deepep_mode != DeepEPMode.LOW_LATENCY)
             and not global_server_args_dict["enable_two_batch_overlap"]
             and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
         ):
@@ -611,7 +604,7 @@ class DeepEPDispatcher(BaseDispatcher):
         num_local_experts: int = None,
         hidden_size: int = None,
         params_dtype: torch.dtype = None,
-        deepep_mode: DeepEPMode = DeepEPMode.auto,
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
         async_finish: bool = False,
         return_recv_hook: bool = False,
     ):
@@ -697,9 +690,9 @@ class DeepEPDispatcher(BaseDispatcher):
         resolved_deepep_mode = self.deepep_mode.resolve(
             forward_batch.is_extend_in_batch
         )
-        if resolved_deepep_mode == DeepEPMode.normal:
+        if resolved_deepep_mode == DeepEPMode.NORMAL:
             return self._normal_dispatcher
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
+        elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
             return self._low_latency_dispatcher
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
@@ -0,0 +1,43 @@
+from enum import Enum
+
+
+class MoeA2ABackend(Enum):
+
+    STANDARD = ("standard", "none")
+    DEEPEP = "deepep"
+
+    @classmethod
+    def _missing_(cls, value):
+        if value is None:
+            return cls.STANDARD
+        for member in cls:
+            if value in member.value:
+                return member
+        raise ValueError(f"No {cls.__name__} member for value {value}")
+
+    def is_deepep(self):
+        return self == MoeA2ABackend.DEEPEP
+
+    def is_standard(self):
+        return self == MoeA2ABackend.STANDARD
+
+
+class DeepEPMode(Enum):
+    NORMAL = "normal"
+    LOW_LATENCY = "low_latency"
+    AUTO = "auto"
+
+    def enable_normal(self):
+        return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
+
+    def enable_low_latency(self):
+        return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
+
+    def resolve(self, is_extend_in_batch: bool):
+        if self != DeepEPMode.AUTO:
+            return self
+
+        if is_extend_in_batch:
+            return DeepEPMode.NORMAL
+        else:
+            return DeepEPMode.LOW_LATENCY
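A short sketch of how the two enums added in the new `sglang/srt/layers/moe/utils.py` behave, based only on the code in the hunk above:

from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend

# _missing_ maps legacy values (None, "none") onto the new members.
assert MoeA2ABackend(None) is MoeA2ABackend.STANDARD
assert MoeA2ABackend("none") is MoeA2ABackend.STANDARD
assert MoeA2ABackend("deepep").is_deepep()

# AUTO resolves per batch: extend (prefill) batches use NORMAL, others LOW_LATENCY.
assert DeepEPMode.AUTO.resolve(is_extend_in_batch=True) is DeepEPMode.NORMAL
assert DeepEPMode.AUTO.resolve(is_extend_in_batch=False) is DeepEPMode.LOW_LATENCY
assert DeepEPMode.NORMAL.enable_normal() and not DeepEPMode.NORMAL.enable_low_latency()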
@@ -23,6 +23,7 @@ from sglang.srt.layers.quantization.utils import (
 from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,
@@ -189,7 +190,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_input_scale = None
         layer.w2_input_scale = None

-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+    def process_weights_after_loading(self, layer: FusedMoE) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
@@ -246,7 +247,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.local_num_experts):
+            for expert_id in range(layer.num_local_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(
@@ -148,7 +148,7 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
         "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
         "N": n,
         "K": k,
-        "NUM_GROUPS": 1,
+        "NUM_GROUPS": num_groups,
         "BLOCK_M": block_m,
         "BLOCK_N": block_n,
         "BLOCK_K": block_k,
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
     is_hip,
     is_npu,
     log_info_on_rank0,
+    next_power_of_2,
     print_warning_once,
     set_weight_attrs,
     use_intel_amx_backend,
@@ -490,6 +491,16 @@ class Fp8LinearMethod(LinearMethodBase):
         )


+def get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
 class Fp8MoEMethod(FusedMoEMethodBase):
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
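Worked examples of the tile-size heuristic added above, assuming `get_tile_tokens_dim` is imported from the module this hunk modifies (`sglang/srt/layers/quantization/fp8.py`); the batch shapes are hypothetical:

from sglang.srt.layers.quantization.fp8 import get_tile_tokens_dim

# 512 tokens routed to top_k=8 of 256 experts -> 16 tokens per expert on average;
# 16 is already a power of two and inside the 8-64 range, so the tile stays 16.
assert get_tile_tokens_dim(num_tokens=512, top_k=8, num_experts=256) == 16

# 4096 tokens with top_k=8 over 16 experts -> 2048 tokens per expert,
# which the clamp caps at the kernel's 64-token maximum.
assert get_tile_tokens_dim(num_tokens=4096, top_k=8, num_experts=16) == 64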
@@ -1028,7 +1039,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

             topk_weights, topk_ids, _ = topk_output
-            return cutlass_fused_experts_fp8(
+            output = cutlass_fused_experts_fp8(
                 x,
                 layer.w13_weight.transpose(1, 2),
                 layer.w2_weight.transpose(1, 2),
@@ -1051,6 +1062,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.problem_sizes2,
                 use_fp8_blockscale=True,
             )
+            # TODO: Fuse into select_experts
+            if routed_scaling_factor is not None:
+                output *= routed_scaling_factor
+            return output
         # Expert fusion with FP8 quantization
         return fused_experts(
             x,
1076
1091
  routed_scaling_factor=routed_scaling_factor,
1077
1092
  )
1078
1093
 
1094
+ def apply_with_router_logits(
1095
+ self,
1096
+ layer: torch.nn.Module,
1097
+ x: torch.Tensor,
1098
+ router_logits: torch.Tensor,
1099
+ *,
1100
+ activation: str = "silu",
1101
+ routed_scaling_factor: Optional[float] = None,
1102
+ ) -> torch.Tensor:
1103
+ assert (
1104
+ activation == "silu"
1105
+ ), "Only silu is supported for flashinfer blockscale fp8 moe"
1106
+ a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
1107
+ # NOTE: scales of hidden states have to be transposed!
1108
+ a_sf_t = a_sf.t().contiguous()
1109
+ from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
1110
+
1111
+ return trtllm_fp8_block_scale_moe(
1112
+ routing_logits=router_logits.to(torch.float32),
1113
+ routing_bias=layer.correction_bias.to(x.dtype),
1114
+ hidden_states=a_q,
1115
+ hidden_states_scale=a_sf_t,
1116
+ gemm1_weights=layer.w13_weight,
1117
+ gemm1_weights_scale=layer.w13_weight_scale_inv,
1118
+ gemm2_weights=layer.w2_weight,
1119
+ gemm2_weights_scale=layer.w2_weight_scale_inv,
1120
+ num_experts=layer.num_experts,
1121
+ top_k=layer.top_k,
1122
+ n_group=layer.num_expert_group,
1123
+ topk_group=layer.topk_group,
1124
+ intermediate_size=layer.w2_weight.shape[2],
1125
+ local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
1126
+ local_num_experts=layer.num_local_experts,
1127
+ routed_scaling_factor=routed_scaling_factor,
1128
+ tile_tokens_dim=get_tile_tokens_dim(
1129
+ x.shape[0], layer.top_k, layer.num_experts
1130
+ ),
1131
+ routing_method_type=2, # DeepSeek-styled routing method
1132
+ use_shuffled_weight=False,
1133
+ )
1134
+
1079
1135
  def maybe_apply_hip_fused_experts(
1080
1136
  self,
1081
1137
  layer: torch.nn.Module,
@@ -354,10 +354,6 @@ def sglang_per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    if scale_ue8m0:
-        # TODO: handle this case by fixing the (token=4, dim=256, group_size=128) UT case
-        assert x.shape[-1] % (group_size * 4) == 0
-
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
         x_shape=x.shape,
@@ -231,7 +231,10 @@ class W8A8Int8Config(QuantizationConfig):

     @classmethod
     def get_config_filenames(cls) -> List[str]:
-        return []
+        filenames = []
+        if _is_npu:
+            filenames.append("quant_model_description.json")
+        return filenames

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
@@ -464,7 +468,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self, masked_input.long())
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.embedding(self, masked_input.long())
+            sm.tag(output_parallel)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
@@ -186,3 +186,10 @@ class LoRARegistry:
         self._registry[lora_ref.lora_name] = lora_ref
         self._counters[lora_ref.lora_id] = ConcurrentCounter()
         return lora_ref
+
+    @property
+    def num_registered_loras(self) -> int:
+        """
+        Returns the total number of LoRA adapters currently registered.
+        """
+        return len(self._registry)