sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -58,7 +58,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 logger = logging.getLogger(__name__)
 
 
-class DeepEPNormalOutput(NamedTuple):
+class DeepEPNormalDispatchOutput(NamedTuple):
     """DeepEP normal dispatch output."""
 
     hidden_states: torch.Tensor
@@ -72,7 +72,7 @@ class DeepEPNormalOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_NORMAL
 
 
-class DeepEPLLOutput(NamedTuple):
+class DeepEPLLDispatchOutput(NamedTuple):
     """DeepEP low latency dispatch output."""
 
     hidden_states: torch.Tensor
@@ -87,14 +87,17 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_LL
 
 
-assert isinstance(DeepEPNormalOutput, DispatchOutput)
-assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(DeepEPNormalDispatchOutput, DispatchOutput)
+assert isinstance(DeepEPLLDispatchOutput, DispatchOutput)
 
 
 class DeepEPNormalCombineInput(NamedTuple):
     """DeepEP normal combine input."""
 
-    pass
+    hidden_states: torch.Tensor
+    topk_ids: torch.Tensor
+    topk_weights: torch.Tensor
+    overlap_args: Optional[CombineOverlapArgs] = None
 
     @property
     def format(self) -> CombineInputFormat:
@@ -104,7 +107,10 @@ class DeepEPNormalCombineInput(NamedTuple):
 class DeepEPLLCombineInput(NamedTuple):
     """DeepEP low latency combine input."""
 
-    pass
+    hidden_states: torch.Tensor
+    topk_ids: torch.Tensor
+    topk_weights: torch.Tensor
+    overlap_args: Optional[CombineOverlapArgs] = None
 
     @property
     def format(self) -> CombineInputFormat:
@@ -383,7 +389,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         else:
             hidden_states_scale = None
 
-        return DeepEPNormalOutput(
+        return DeepEPNormalDispatchOutput(
            hidden_states,
            hidden_states_scale,
            topk_ids,
@@ -562,7 +568,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         else:
             hidden_states_scale = None
 
-        deepep_output = DeepEPLLOutput(
+        deepep_output = DeepEPLLDispatchOutput(
            hidden_states,
            hidden_states_scale,
            topk_ids,
@@ -756,18 +762,16 @@ class DeepEPDispatcher(BaseDispatcher):
         del self._dispatch_intermediate_state
         return self._get_impl().dispatch_b(*inner_state)
 
-    def combine(self, *args, **kwargs) -> Tuple:
-        self.combine_a(*args, **kwargs)
+    def combine(self, combine_input: CombineInput) -> Tuple:
+        self.combine_a(combine_input)
         ret = self.combine_b()
         return ret
 
     def combine_a(
         self,
-        hidden_states: torch.Tensor,
-        topk_ids: torch.Tensor,
-        topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"] = None,
+        combine_input: CombineInput,
     ):
+        hidden_states, topk_ids, topk_weights, overlap_args = combine_input
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,
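
Note: the combine() signature change above is breaking for direct callers, which must now pass a single CombineInput instead of separate tensors. A minimal migration sketch, assuming the import path follows the file list above and a dispatcher constructed elsewhere (all variable names are illustrative):

    import torch

    from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPLLCombineInput

    def run_combine(dispatcher, hidden_states: torch.Tensor,
                    topk_ids: torch.Tensor, topk_weights: torch.Tensor):
        # 0.5.4:        dispatcher.combine(hidden_states, topk_ids, topk_weights)
        # 0.5.4.post1:  wrap the tensors in a CombineInput NamedTuple first.
        combine_input = DeepEPLLCombineInput(
            hidden_states=hidden_states,
            topk_ids=topk_ids,
            topk_weights=topk_weights,
            # overlap_args keeps its default of None
        )
        return dispatcher.combine(combine_input)
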
sglang/srt/layers/moe/token_dispatcher/standard.py
@@ -88,7 +88,7 @@ class StandardDispatcher(BaseDispatcher):
             topk_output = topk_output._replace(
                 topk_ids=self.local_expert_mapping[topk_output.topk_ids]
             )
-        elif TopKOutputChecker.format_is_triton_kernel(topk_output):
+        elif TopKOutputChecker.format_is_triton_kernels(topk_output):
             raise NotImplementedError()
 
         return StandardDispatchOutput(
sglang/srt/layers/moe/topk.py
@@ -111,10 +111,10 @@ class TopKOutputChecker:
        return topk_output.format.is_standard()
 
    @staticmethod
-    def format_is_triton_kernel(
+    def format_is_triton_kernels(
        topk_output: TopKOutput,
    ) -> TypeGuard[TritonKernelTopKOutput]:
-        return topk_output.format.is_triton_kernel()
+        return topk_output.format.is_triton_kernels()
 
    @staticmethod
    def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
@@ -129,7 +129,7 @@ class TopKOutputFormat(Enum):
    def is_standard(self) -> bool:
        return self == TopKOutputFormat.STANDARD
 
-    def is_triton_kernel(self) -> bool:
+    def is_triton_kernels(self) -> bool:
        return self == TopKOutputFormat.TRITON_KERNEL
 
    def is_bypassed(self) -> bool:
@@ -254,7 +254,7 @@ class TopK(CustomOp):
    ) -> TopKOutput:
        if self.topk_config.output_format is not None:
            output_format = self.topk_config.output_format
-        elif get_moe_runner_backend().is_triton_kernel():
+        elif get_moe_runner_backend().is_triton_kernels():
            output_format = TopKOutputFormat.TRITON_KERNEL
        elif (
            should_use_flashinfer_trtllm_moe()
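
For context on why the checker methods return TypeGuard[...]: a static type checker narrows the TopKOutput union inside the guarded branch. A hypothetical caller (the handle function is not in the diff; TopKOutputChecker and format_is_standard, whose body appears in the hunk context above, come from sglang/srt/layers/moe/topk.py):

    from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker

    def handle(topk_output: TopKOutput) -> None:
        if TopKOutputChecker.format_is_triton_kernels(topk_output):
            # Statically narrowed to TritonKernelTopKOutput in this branch.
            ...
        elif TopKOutputChecker.format_is_standard(topk_output):
            # Narrowed to the standard output type here.
            ...
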
sglang/srt/layers/moe/utils.py
@@ -51,7 +51,7 @@ class MoeRunnerBackend(Enum):
    AUTO = "auto"
    DEEP_GEMM = "deep_gemm"
    TRITON = "triton"
-    TRITON_KERNEL = "triton_kernel"
+    TRITON_KERNELS = "triton_kernel"
    FLASHINFER_TRTLLM = "flashinfer_trtllm"
    FLASHINFER_CUTLASS = "flashinfer_cutlass"
    FLASHINFER_MXFP4 = "flashinfer_mxfp4"
@@ -67,8 +67,8 @@ class MoeRunnerBackend(Enum):
    def is_triton(self):
        return self == MoeRunnerBackend.TRITON
 
-    def is_triton_kernel(self):
-        return self == MoeRunnerBackend.TRITON_KERNEL
+    def is_triton_kernels(self):
+        return self == MoeRunnerBackend.TRITON_KERNELS
 
    def is_flashinfer_trtllm(self):
        return self == MoeRunnerBackend.FLASHINFER_TRTLLM
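
Worth noting: only the Python-side names change here; the member's string value stays "triton_kernel", so any stored or CLI-provided backend string still resolves to the renamed member. A quick check (import path per the file list above):

    from sglang.srt.layers.moe.utils import MoeRunnerBackend

    # Value-based lookup is unchanged by the member rename:
    assert MoeRunnerBackend("triton_kernel") is MoeRunnerBackend.TRITON_KERNELS
    assert MoeRunnerBackend.TRITON_KERNELS.is_triton_kernels()
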
@@ -152,7 +152,6 @@ def initialize_moe_config(server_args: ServerArgs):
 def get_moe_a2a_backend() -> MoeA2ABackend:
     global MOE_A2A_BACKEND
     if MOE_A2A_BACKEND is None:
-        logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
         MOE_A2A_BACKEND = MoeA2ABackend.NONE
     return MOE_A2A_BACKEND
 
sglang/srt/layers/quantization/__init__.py
@@ -12,7 +12,6 @@ try:
    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
@@ -32,9 +31,7 @@ except ImportError as e:
 
    AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
        ExpertsInt8Config
-    ) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
-        DummyConfig
-    )
+    ) = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
 
 
 from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
@@ -45,6 +42,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
+from sglang.srt.layers.quantization.gguf import GGUFConfig
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
    ModelOptFp4Config,
@@ -75,6 +73,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "w8a8_fp8": W8A8Fp8Config,
    "awq": AWQConfig,
    "awq_marlin": AWQMarlinConfig,
+    "gguf": GGUFConfig,
    "gptq": GPTQConfig,
    "gptq_marlin": GPTQMarlinConfig,
    "moe_wna16": MoeWNA16Config,
@@ -108,7 +107,6 @@ VLLM_QUANTIZATION_METHODS = {
    "deepspeedfp": DeepSpeedFPConfig,
    "tpu_int8": Int8TpuConfig,
    "marlin": MarlinConfig,
-    "gguf": GGUFConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    "bitsandbytes": BitsAndBytesConfig,
    "qqq": QQQConfig,
sglang/srt/layers/quantization/awq.py
@@ -840,12 +840,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."
 
-        # The input must currently be float16
        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
-
        orig_dtype = x.dtype
-        x = x.half()
 
        topk_weights, topk_ids, router_logits = topk_output
 
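
The removed lines mean the AWQ MoE path no longer forces activations to float16; orig_dtype is still captured, presumably to cast the result back afterwards. A toy illustration of the behavioral difference (standalone, not sglang code):

    import torch

    x = torch.randn(4, 8, dtype=torch.bfloat16)
    orig_dtype = x.dtype
    # 0.5.4 forced x = x.half(), so kernels always saw float16.
    # 0.5.4.post1 leaves x in its original dtype:
    assert x.dtype is orig_dtype
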
sglang/srt/layers/quantization/base_config.py
@@ -179,6 +179,13 @@ class QuantizationConfig(ABC):
        elif "NVFP4" in quant_algo or "FP4" in quant_algo:
            return "modelopt_fp4"
 
+        # The hf_quant_config may be a parsed quant config, so we need to check the
+        # quant_method.
+        if hf_quant_config.get("quant_method", "") == "modelopt_fp8":
+            return "modelopt_fp8"
+        elif hf_quant_config.get("quant_method", "") == "modelopt_fp4":
+            return "modelopt_fp4"
+
        return None
 
    @staticmethod
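
The added fallback matters for checkpoints whose quantization config has already been parsed into a dict carrying quant_method rather than a raw quant_algo string. The detection order, reduced to a standalone sketch (the function name and the quant_algo parameter are illustrative; only the quant_method fallback is verbatim from the diff):

    def detect_modelopt_method(hf_quant_config: dict, quant_algo: str):
        # Existing check on the raw algorithm string (hunk context above):
        if "NVFP4" in quant_algo or "FP4" in quant_algo:
            return "modelopt_fp4"
        # New in 0.5.4.post1: the config may be pre-parsed, so also accept
        # an explicit quant_method key.
        method = hf_quant_config.get("quant_method", "")
        if method in ("modelopt_fp8", "modelopt_fp4"):
            return method
        return None
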
sglang/srt/layers/quantization/fp8.py
@@ -33,6 +33,7 @@ from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
@@ -525,12 +526,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None
        self.cutlass_fp8_supported = cutlass_fp8_supported()
-        self.use_cutlass_fused_experts_fp8 = (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        )
 
    def create_weights(
        self,
@@ -638,58 +633,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"
-            if self.use_cutlass_fused_experts_fp8:
-                self.ab_strides1 = torch.full(
-                    (num_experts,),
-                    hidden_size,
-                    device=w13_weight.device,
-                    dtype=torch.int64,
-                )
-                self.c_strides1 = torch.full(
-                    (num_experts,),
-                    2 * intermediate_size_per_partition,
-                    device=w13_weight.device,
-                    dtype=torch.int64,
-                )
-                self.ab_strides2 = torch.full(
-                    (num_experts,),
-                    intermediate_size_per_partition,
-                    device=w2_weight.device,
-                    dtype=torch.int64,
-                )
-                self.c_strides2 = torch.full(
-                    (num_experts,),
-                    hidden_size,
-                    device=w2_weight.device,
-                    dtype=torch.int64,
-                )
-                self.workspace = torch.empty(
-                    90000, device=w13_weight.device, dtype=torch.uint8
-                )
-                self.a_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.b_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.out_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.a_scales_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.b_scales_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.expert_offsets = torch.empty(
-                    num_experts + 1, device=w13_weight.device, dtype=torch.int32
-                )
-                self.problem_sizes1 = torch.empty(
-                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
-                )
-                self.problem_sizes2 = torch.empty(
-                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
-                )
+            if self._should_use_cutlass_fused_experts():
+                self._ensure_cutlass_buffers_initialized(layer)
 
        else:
            # Allocate 2 scales for w1 and w3 respectively.
@@ -1039,13 +984,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
        x = dispatch_output.hidden_states
-        topk_output = dispatch_output.topk_output
        moe_runner_config = self.moe_runner_config
 
        if use_intel_amx_backend(layer):
            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
            x, topk_weights = apply_topk_weights_cpu(
                moe_runner_config.apply_router_weight_on_input, topk_weights, x
            )
@@ -1072,17 +1016,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            ret = self.maybe_apply_hip_fused_experts(
                layer,
                x,
-                topk_output,
+                dispatch_output.topk_output,
                moe_runner_config.activation,
                moe_runner_config.no_combine,
            )
            if ret is not None:
                return StandardCombineInput(hidden_states=ret)
 
-        if self.use_cutlass_fused_experts_fp8:
+        if self._should_use_cutlass_fused_experts():
            from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 
-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
            output = cutlass_fused_experts_fp8(
                x,
                layer.w13_weight.transpose(1, 2),
@@ -1171,6 +1115,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
        return self.runner.run(dispatch_output, quant_info)
 
+    def _should_use_cutlass_fused_experts(self) -> bool:
+        """Decide whether to use Cutlass FP8 fused-experts path based on moe runner backend,
+        with env var override via `SGLANG_CUTLASS_MOE`.
+        """
+        backend = get_moe_runner_backend()
+        env_force = get_bool_env_var("SGLANG_CUTLASS_MOE")
+        # TODO: remove env var in the future, it should be handled by moe runner backend
+        if env_force:
+            return True
+        return (
+            backend.is_flashinfer_cutlass()
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and (is_sm100_supported() or is_sm90_supported())
+        )
+
+    def _ensure_cutlass_buffers_initialized(self, layer: Module) -> None:
+        if getattr(self, "_cutlass_buffers_ready", False):
+            return
+
+        device = layer.w13_weight.device
+        num_experts = layer.w13_weight.shape[0]
+        hidden_size = layer.w2_weight.shape[1]
+        intermediate_size_per_partition = layer.intermediate_size_per_partition
+
+        self.ab_strides1 = torch.full(
+            (num_experts,), hidden_size, device=device, dtype=torch.int64
+        )
+        self.c_strides1 = torch.full(
+            (num_experts,),
+            2 * intermediate_size_per_partition,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.ab_strides2 = torch.full(
+            (num_experts,),
+            intermediate_size_per_partition,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.c_strides2 = torch.full(
+            (num_experts,), hidden_size, device=device, dtype=torch.int64
+        )
+        self.workspace = torch.empty(90000, device=device, dtype=torch.uint8)
+        self.a_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.b_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.out_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.a_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.b_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.expert_offsets = torch.empty(
+            num_experts + 1, device=device, dtype=torch.int32
+        )
+        self.problem_sizes1 = torch.empty(
+            num_experts, 3, device=device, dtype=torch.int32
+        )
+        self.problem_sizes2 = torch.empty(
+            num_experts, 3, device=device, dtype=torch.int32
+        )
+
+        self._cutlass_buffers_ready = True
+
    def apply_with_router_logits(
        self,
        layer: torch.nn.Module,
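
Two things change in the Cutlass FP8 path above: the decision moves from an eagerly computed constructor flag to _should_use_cutlass_fused_experts(), now keyed on the MoE runner backend with SGLANG_CUTLASS_MOE kept as a temporary override, and buffer allocation becomes lazy and idempotent via _cutlass_buffers_ready. The core of that guard pattern, reduced to a runnable sketch (class and buffer names illustrative):

    import torch

    class LazyCutlassBuffers:
        def ensure_initialized(self, num_experts: int, device: str = "cpu") -> None:
            # Idempotent: the first call allocates, later calls return at once,
            # so it is safe to invoke from both weight creation and apply().
            if getattr(self, "_ready", False):
                return
            self.expert_offsets = torch.empty(
                num_experts + 1, device=device, dtype=torch.int32
            )
            self._ready = True
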