sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -1,8 +1,15 @@
+ import logging
+ from dataclasses import dataclass
+
  from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
- from sglang.srt.utils import DeepEPMode
+ from sglang.srt.managers.expert_distribution import (
+ get_global_expert_distribution_recorder,
+ )
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.utils import DeepEPMode, load_json_config

  try:
- from deep_ep import Buffer
+ from deep_ep import Buffer, Config

  from sglang.srt.layers.quantization.fp8_kernel import (
  sglang_per_token_group_quant_fp8,
@@ -12,7 +19,7 @@ try:
  except ImportError:
  use_deepep = False

- from enum import IntEnum, auto
+ from enum import Enum, IntEnum, auto
  from typing import Optional, Tuple, Union

  import torch
@@ -25,6 +32,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardMode

+ logger = logging.getLogger(__name__)
+

  class DeepEPDispatchMode(IntEnum):
  NORMAL = auto()
@@ -32,7 +41,6 @@ class DeepEPDispatchMode(IntEnum):


  class DeepEPBuffer:
-
  _buffer = None
  _dispatch_mode: Optional[DeepEPDispatchMode] = None
  _hidden_size: Optional[int] = None
@@ -60,8 +68,10 @@ class DeepEPBuffer:
  if deepep_mode.enable_normal():
  hidden_bytes = hidden_size * param_bytes
  for config in (
- Buffer.get_dispatch_config(group.size()),
- Buffer.get_combine_config(group.size()),
+ DeepEPConfig.get_instance().normal_dispatch_config
+ or Buffer.get_dispatch_config(group.size()),
+ DeepEPConfig.get_instance().normal_combine_config
+ or Buffer.get_combine_config(group.size()),
  ):
  num_nvl_bytes = max(
  config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
@@ -88,7 +98,12 @@ class DeepEPBuffer:
  num_nvl_bytes,
  num_rdma_bytes,
  low_latency_mode=deepep_mode.enable_low_latency(),
- num_qps_per_rank=(max(num_experts // group.size(), Buffer.num_sms // 2)),
+ num_qps_per_rank=(
+ max(
+ num_experts // group.size(),
+ DeepEPConfig.get_instance().num_sms // 2,
+ )
+ ),
  )
  return cls._buffer

@@ -113,6 +128,35 @@ class DeepEPBuffer:
  cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY


+ class DeepEPConfig:
+ _instance = None
+
+ def __init__(self):
+ config_str = global_server_args_dict["deepep_config"]
+ if config_str:
+ config_parsed = load_json_config(config_str)
+ if torch.distributed.get_rank() == 0:
+ logger.info(f"Use DeepEP Config: {config_parsed}")
+ config_dispatch = config_parsed["normal_dispatch"]
+ config_combine = config_parsed["normal_combine"]
+
+ self.normal_dispatch_config = Config(**config_dispatch)
+ self.normal_combine_config = Config(**config_combine)
+
+ assert config_dispatch["num_sms"] == config_combine["num_sms"]
+ self.num_sms = config_dispatch["num_sms"]
+ else:
+ self.normal_dispatch_config = None
+ self.normal_combine_config = None
+ self.num_sms = Buffer.num_sms
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = DeepEPConfig()
+ return cls._instance
+
+
  class _DeepEPDispatcherImplBase:
  def __init__(
  self,
@@ -295,6 +339,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  async_finish=self.async_finish,
  allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
  expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
+ config=DeepEPConfig.get_instance().normal_dispatch_config,
+ )
+
+ get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+ num_recv_tokens_per_expert_list,
+ num_tokens_per_rank=num_tokens_per_rank,
+ num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+ num_tokens_per_expert=num_tokens_per_expert,
  )

  return (
@@ -394,6 +446,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  async_finish=self.async_finish,
  previous_event=previous_event,
  allocate_on_comm_stream=previous_event is not None,
+ config=DeepEPConfig.get_instance().normal_combine_config,
  )
  return combined_x, event

@@ -459,6 +512,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  ):
  hook() if self.return_recv_hook else event.current_stream_wait()

+ get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
+ masked_m
+ )
+
  reorder_topk_ids = seg_indptr = None

  return (
@@ -571,6 +628,14 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
  )


+ @dataclass
+ class _Stage(Enum):
+ INITIAL = auto()
+ AFTER_DISPATCH_A = auto()
+ AFTER_DISPATCH_B = auto()
+ AFTER_COMBINE_A = auto()
+
+
  class DeepEPDispatcher:
  def __init__(
  self,
@@ -609,6 +674,8 @@ class DeepEPDispatcher:
  **common_kwargs,
  )

+ self._stage = _Stage.INITIAL
+
  def dispatch(self, *args, **kwargs) -> Tuple:
  self.dispatch_a(*args, **kwargs)
  ret = self.dispatch_b()
@@ -621,6 +688,7 @@
  topk_weights: torch.Tensor,
  forward_mode: ForwardMode = None,
  ):
+ self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
  inner_state = self._get_impl(forward_mode).dispatch_a(
  hidden_states=hidden_states,
  topk_idx=topk_idx,
@@ -629,6 +697,7 @@
  self._dispatch_intermediate_state = forward_mode, inner_state

  def dispatch_b(self):
+ self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
  forward_mode, inner_state = self._dispatch_intermediate_state
  del self._dispatch_intermediate_state
  return self._get_impl(forward_mode).dispatch_b(*inner_state)
@@ -645,6 +714,7 @@
  topk_weights: torch.Tensor,
  forward_mode: ForwardMode,
  ):
+ self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
  inner_state = self._get_impl(forward_mode).combine_a(
  hidden_states=hidden_states,
  topk_idx=topk_idx,
@@ -653,6 +723,7 @@
  self._combine_intermediate_state = forward_mode, inner_state

  def combine_b(self):
+ self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
  forward_mode, inner_state = self._combine_intermediate_state
  del self._combine_intermediate_state
  return self._get_impl(forward_mode).combine_b(*inner_state)
@@ -665,3 +736,7 @@
  return self._low_latency_dispatcher
  else:
  raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+
+ def _update_stage(self, old_stage, new_stage):
+ assert self._stage == old_stage
+ self._stage = new_stage
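A note on the new DeepEPConfig singleton above: it reads an optional JSON payload from global_server_args_dict["deepep_config"] via load_json_config and builds deep_ep.Config objects for normal-mode dispatch and combine. Going only by the keys consumed in __init__, a payload would look roughly like the sketch below; the exact per-section fields depend on the installed deep_ep.Config constructor, which this hunk does not show, so treat anything beyond "num_sms" as a placeholder.

import json

# Hypothetical payload for the "deepep_config" server argument.
# Only the keys that DeepEPConfig.__init__ actually reads are shown:
# "normal_dispatch", "normal_combine", and "num_sms" (which must match
# between the two sections). Real deep_ep.Config objects will likely
# require additional keyword arguments not visible in this diff.
deepep_config = {
    "normal_dispatch": {"num_sms": 24},
    "normal_combine": {"num_sms": 24},
}
print(json.dumps(deepep_config, indent=2))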
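The _Stage enum and the _update_stage calls added to DeepEPDispatcher enforce a strict order for the two-phase API: dispatch_a, then dispatch_b, then combine_a, then combine_b, after which the dispatcher is back at INITIAL. A minimal standalone sketch of that state machine (toy class, not the sglang implementation):

from enum import Enum, auto


class Stage(Enum):
    INITIAL = auto()
    AFTER_DISPATCH_A = auto()
    AFTER_DISPATCH_B = auto()
    AFTER_COMBINE_A = auto()


class StagedDispatcher:
    """Toy dispatcher mirroring the dispatch_a/b, combine_a/b ordering."""

    def __init__(self):
        self._stage = Stage.INITIAL

    def _update_stage(self, old, new):
        # Same guard as DeepEPDispatcher._update_stage: reject out-of-order calls.
        assert self._stage == old, f"expected {old}, got {self._stage}"
        self._stage = new

    def dispatch_a(self):
        self._update_stage(Stage.INITIAL, Stage.AFTER_DISPATCH_A)

    def dispatch_b(self):
        self._update_stage(Stage.AFTER_DISPATCH_A, Stage.AFTER_DISPATCH_B)

    def combine_a(self):
        self._update_stage(Stage.AFTER_DISPATCH_B, Stage.AFTER_COMBINE_A)

    def combine_b(self):
        self._update_stage(Stage.AFTER_COMBINE_A, Stage.INITIAL)


d = StagedDispatcher()
d.dispatch_a(); d.dispatch_b(); d.combine_a(); d.combine_b()  # valid cycle
# Calling d.combine_a() here would trip the assertion: the cycle must restart
# with dispatch_a.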
sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -186,6 +186,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

  if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
  assert not no_combine, "unsupported"
+ if apply_router_weight_on_input:
+ assert (
+ topk_weights.dim() == 2
+ ), "`topk_weights` should be in shape (num_tokens, topk)"
+ _, topk = topk_weights.shape
+ assert (
+ topk == 1
+ ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+ x = x * topk_weights.to(x.dtype)
+ topk_weights = torch.ones_like(
+ topk_weights, dtype=torch.float32
+ ) # topk_weights must be FP32 (float32)
+
  return ck_moe_2stages(
  x,
  layer.w13_weight,
@@ -270,6 +283,7 @@ class FusedMoE(torch.nn.Module):
  top_k: int,
  hidden_size: int,
  intermediate_size: int,
+ layer_id: Optional[int] = None,
  params_dtype: Optional[torch.dtype] = None,
  reduce_results: bool = False,
  renormalize: bool = True,
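For the apply_router_weight_on_input branch added to the ROCm AITER path: the router weight (topk must be 1) is folded into the activations before the expert computation, and topk_weights is then replaced with FP32 ones so the later combine step does not apply the weight a second time. A toy-shaped sketch of just that transformation, outside the kernel:

import torch

num_tokens, hidden = 4, 8
x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16)
topk_weights = torch.rand(num_tokens, 1)  # the new asserts require shape (num_tokens, 1)

# Fold the router weight into the input activations...
x = x * topk_weights.to(x.dtype)
# ...and neutralize the weights so the downstream combine is an unweighted sum.
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)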
sglang/srt/layers/moe/topk.py
@@ -18,7 +18,14 @@ from typing import Callable, Optional
  import torch
  import torch.nn.functional as F

- from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+ from sglang.srt.managers.expert_distribution import (
+ ExpertDistributionRecorder,
+ get_global_expert_distribution_recorder,
+ )
+ from sglang.srt.managers.expert_location_dispatch import (
+ ExpertLocationDispatchInfo,
+ topk_ids_logical_to_physical,
+ )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

@@ -32,9 +39,6 @@ if _is_cuda or _is_hip:
  from sgl_kernel import topk_softmax


- expert_distribution_recorder = ExpertDistributionRecorder()
-
-
  def fused_topk_native(
  hidden_states: torch.Tensor,
  gating_output: torch.Tensor,
@@ -61,6 +65,7 @@ def fused_topk(
  gating_output: torch.Tensor,
  topk: int,
  renormalize: bool,
+ expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
  ):
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -84,7 +89,7 @@

  if renormalize:
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
+ topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
  return topk_weights, topk_ids


@@ -99,6 +104,8 @@
  topk_group: int = 0,
  n_share_experts_fusion: int = 0,
  routed_scaling_factor: Optional[float] = None,
+ num_token_non_padded: Optional[torch.Tensor] = None,
+ expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
  ):
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -138,7 +145,10 @@
  )
  topk_weights = topk_weights / topk_weights_sum

- return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+ topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+ topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+ _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+ return topk_weights, topk_ids


  def biased_grouped_topk_impl(
@@ -151,6 +161,8 @@
  topk_group: int = 0,
  n_share_experts_fusion: int = 0,
  routed_scaling_factor: Optional[float] = None,
+ num_token_non_padded: Optional[torch.Tensor] = None,
+ expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
  ):
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -197,13 +209,26 @@
  )
  topk_weights = topk_weights / topk_weights_sum

- return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+ topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+ topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+ _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+ return topk_weights, topk_ids


  def is_power_of_two(n):
  return n > 0 and math.log2(n).is_integer()


+ def _mask_topk_ids_padded_region(
+ topk_ids: torch.Tensor,
+ num_token_non_padded: Optional[torch.Tensor] = None,
+ ):
+ if num_token_non_padded is None:
+ return
+ indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
+ topk_ids[indices >= num_token_non_padded, :] = -1
+
+
  def biased_grouped_topk(
  hidden_states: torch.Tensor,
  gating_output: torch.Tensor,
@@ -215,6 +240,8 @@
  compiled: bool = True,
  n_share_experts_fusion: int = 0,
  routed_scaling_factor: Optional[float] = None,
+ num_token_non_padded: Optional[torch.Tensor] = None,
+ expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
  ):
  assert (
  routed_scaling_factor is not None
@@ -226,7 +253,7 @@
  <= 32 # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
  and is_power_of_two(correction_bias.shape[0])
  ):
- return moe_fused_gate(
+ topk_weights, topk_ids = moe_fused_gate(
  gating_output,
  correction_bias,
  num_expert_group,
@@ -235,6 +262,15 @@
  n_share_experts_fusion,
  routed_scaling_factor,
  )
+ # TODO merge into kernel for this branch
+ topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+ # TODO will fuse this into kernel, thus use slow manual operation now
+ if num_token_non_padded is None:
+ return topk_weights, topk_ids
+ torch.compile(
+ _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
+ )(topk_ids, num_token_non_padded)
+ return topk_weights, topk_ids
  else:
  biased_grouped_topk_fn = (
  torch.compile(
@@ -253,6 +289,8 @@
  topk_group,
  n_share_experts_fusion=n_share_experts_fusion,
  routed_scaling_factor=routed_scaling_factor,
+ num_token_non_padded=num_token_non_padded,
+ expert_location_dispatch_info=expert_location_dispatch_info,
  )


@@ -268,6 +306,8 @@
  correction_bias: Optional[torch.Tensor] = None,
  torch_native: bool = False,
  routed_scaling_factor: Optional[float] = None,
+ num_token_non_padded: Optional[torch.Tensor] = None,
+ expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
  ):
  n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
  # DeepSeek V2/V3/R1 series models use grouped_top_k
@@ -284,6 +324,8 @@
  topk_group=topk_group,
  n_share_experts_fusion=n_share_experts_fusion,
  routed_scaling_factor=routed_scaling_factor,
+ num_token_non_padded=num_token_non_padded,
+ expert_location_dispatch_info=expert_location_dispatch_info,
  )
  else:
  topk_weights, topk_ids = biased_grouped_topk(
@@ -296,8 +338,14 @@
  topk_group=topk_group,
  n_share_experts_fusion=n_share_experts_fusion,
  routed_scaling_factor=routed_scaling_factor,
+ num_token_non_padded=num_token_non_padded,
+ expert_location_dispatch_info=expert_location_dispatch_info,
  )
  elif torch_native and custom_routing_function is None:
+ assert (
+ num_token_non_padded is None
+ ), "num_token_non_padded is not yet supported in fused_topk_native"
+ assert expert_location_dispatch_info is None
  topk_weights, topk_ids = fused_topk_native(
  hidden_states=hidden_states,
  gating_output=router_logits,
@@ -305,13 +353,22 @@
  renormalize=renormalize,
  )
  elif custom_routing_function is None:
+ assert (
+ num_token_non_padded is None
+ ), "num_token_non_padded is not yet supported in fused_topk"
+ # Qwen3MOE uses fused_topk
  topk_weights, topk_ids = fused_topk(
  hidden_states=hidden_states,
  gating_output=router_logits,
  topk=top_k,
  renormalize=renormalize,
+ expert_location_dispatch_info=expert_location_dispatch_info,
  )
  else:
+ assert (
+ num_token_non_padded is None
+ ), "num_token_non_padded is not yet supported in custom_routing_function"
+ assert expert_location_dispatch_info is None
  topk_weights, topk_ids = custom_routing_function(
  hidden_states=hidden_states,
  gating_output=router_logits,
@@ -319,6 +376,6 @@
  renormalize=renormalize,
  )

- expert_distribution_recorder.record_new_token(topk_ids)
+ get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

  return topk_weights, topk_ids
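All of the num_token_non_padded plumbing above funnels into _mask_topk_ids_padded_region, which writes -1 into the routing slots of trailing padding tokens (padding that presumably comes from fixed-size CUDA-graph batches) so they are never dispatched to an expert. A tiny CPU illustration of the masking:

import torch

topk_ids = torch.tensor([[3, 7], [1, 4], [0, 2], [5, 6]])  # 4 token slots, topk = 2
num_token_non_padded = torch.tensor(2)  # only the first 2 tokens are real

indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
topk_ids[indices >= num_token_non_padded, :] = -1
print(topk_ids)
# tensor([[ 3,  7],
#         [ 1,  4],
#         [-1, -1],
#         [-1, -1]])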
sglang/srt/layers/multimodal.py (new file)
@@ -0,0 +1,70 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Logits processing."""
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def hash_kernel(
+ input_ptr,
+ output_ptr,
+ n_elements,
+ BLOCK_SIZE: tl.constexpr,
+ PRIME: tl.constexpr,
+ XCONST: tl.constexpr,
+ ):
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+
+ data = tl.load(input_ptr + offsets, mask=mask, other=0)
+ mixed = data ^ (offsets + XCONST)
+ hash_val = mixed * PRIME
+ hash_val = hash_val ^ (hash_val >> 16)
+ hash_val = hash_val * (PRIME ^ XCONST)
+ hash_val = hash_val ^ (hash_val >> 13)
+
+ tl.store(output_ptr + offsets, hash_val, mask=mask)
+
+
+ PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
+ PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1
+
+
+ def gpu_tensor_hash(tensor: torch.Tensor) -> int:
+ assert tensor.is_cuda
+ tensor = tensor.contiguous().view(torch.int32)
+ n = tensor.numel()
+ BLOCK_SIZE = 1024
+ grid = (triton.cdiv(n, BLOCK_SIZE),)
+
+ intermediate_hashes = torch.empty(n, dtype=torch.int32, device=tensor.device)
+
+ hash_kernel[grid](
+ tensor,
+ intermediate_hashes,
+ n,
+ BLOCK_SIZE=BLOCK_SIZE,
+ PRIME=PRIME_1,
+ XCONST=PRIME_2,
+ )
+
+ # TODO: threads can't be synced on triton kernel
+ final_hash = intermediate_hashes.sum().item()
+
+ return final_hash
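On usage of the new gpu_tensor_hash: it hashes a CUDA tensor entirely on the device (one Triton pass plus a sum reduction), which makes it cheap to key cached multimodal embeddings by content; the new sglang/srt/mem_cache/multimodal_cache.py in this release is the likely consumer, though that pairing is inferred from the file list rather than shown in this hunk. A minimal call sketch:

import torch

from sglang.srt.layers.multimodal import gpu_tensor_hash

# Works for CUDA tensors whose data reinterprets cleanly as int32 (float32
# here); the function views the raw bytes as int32 before hashing.
image_features = torch.randn(1, 1024, 4096, device="cuda", dtype=torch.float32)
key = gpu_tensor_hash(image_features)
print(key)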
sglang/srt/layers/quantization/__init__.py
@@ -25,7 +25,6 @@ try:
  from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
  from vllm.model_executor.layers.quantization.gptq_marlin import (
  GPTQMarlinLinearMethod,
- GPTQMarlinMoEMethod,
  )
  from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
  GPTQMarlin24Config,
@@ -58,12 +57,17 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
  CompressedTensorsConfig,
  )
  from sglang.srt.layers.quantization.fp8 import Fp8Config
- from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+ from sglang.srt.layers.quantization.gptq import (
+ GPTQConfig,
+ GPTQMarlinConfig,
+ GPTQMarlinMoEMethod,
+ )
  from sglang.srt.layers.quantization.modelopt_quant import (
  ModelOptFp4Config,
  ModelOptFp8Config,
  )
  from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+ from sglang.srt.layers.quantization.qoq import QoQConfig
  from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
  from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

@@ -77,6 +81,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
  "w8a8_fp8": W8A8Fp8Config,
  "moe_wna16": MoeWNA16Config,
  "compressed-tensors": CompressedTensorsConfig,
+ "qoq": QoQConfig,
  }

  # VLLM-dependent quantization methods
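The practical effect of the registry change is that "qoq" is now an accepted quantization method name, resolved from the same dict as the existing fp8 / w8a8 / moe_wna16 entries. A lookup sketch (importing the dict directly is just for illustration; in normal use the name arrives via the server's quantization argument):

from sglang.srt.layers.quantization import BASE_QUANTIZATION_METHODS

# "qoq" now maps to the new QoQConfig from sglang/srt/layers/quantization/qoq.py.
config_cls = BASE_QUANTIZATION_METHODS["qoq"]
print(config_cls.__name__)  # QoQConfig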
sglang/srt/layers/quantization/deep_gemm.py
@@ -11,8 +11,10 @@ from tqdm.contrib.concurrent import thread_map
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda

+ logger = logging.getLogger(__name__)
  _ENABLE_JIT_DEEPGEMM = False
- if is_cuda():
+
+ try:
  import deep_gemm
  from deep_gemm import get_num_sms
  from deep_gemm.jit.compiler import get_nvcc_compiler
@@ -24,14 +26,14 @@ if is_cuda():
  if sm_version == 90:
  if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
  _ENABLE_JIT_DEEPGEMM = True
+ except ImportError:
+ logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")


  def get_enable_jit_deepgemm():
  return _ENABLE_JIT_DEEPGEMM


- logger = logging.getLogger(__name__)
-
  _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
  _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
  "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
sglang/srt/layers/quantization/fp8.py
@@ -52,6 +52,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
  apply_w8a8_block_fp8_linear,
  cutlass_fp8_supported,
  input_to_float8,
+ is_sm100_supported,
  normalize_e4m3fn_to_e4m3fnuz,
  )
  from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -470,6 +471,7 @@ class Fp8MoEMethod:
  def __init__(self, quant_config):
  self.quant_config = quant_config
  self.block_quant = self.quant_config.weight_block_size is not None
+ self.cutlass_fp8_supported = cutlass_fp8_supported()

  def create_weights(
  self,
@@ -568,6 +570,63 @@ class Fp8MoEMethod:
  layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
  layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
  assert self.quant_config.activation_scheme == "dynamic"
+ if (
+ get_bool_env_var("CUTLASS_MOE")
+ and self.cutlass_fp8_supported
+ and is_sm100_supported()
+ ):
+ self.ab_strides1 = torch.full(
+ (num_experts,),
+ hidden_size,
+ device=w13_weight.device,
+ dtype=torch.int64,
+ )
+ self.c_strides1 = torch.full(
+ (num_experts,),
+ 2 * intermediate_size,
+ device=w13_weight.device,
+ dtype=torch.int64,
+ )
+ self.ab_strides2 = torch.full(
+ (num_experts,),
+ intermediate_size,
+ device=w2_weight.device,
+ dtype=torch.int64,
+ )
+ self.c_strides2 = torch.full(
+ (num_experts,),
+ hidden_size,
+ device=w2_weight.device,
+ dtype=torch.int64,
+ )
+ self.workspace = torch.empty(
+ 90000, device=w13_weight.device, dtype=torch.uint8
+ )
+ self.a_ptr = torch.empty(
+ num_experts, device=w13_weight.device, dtype=torch.int64
+ )
+ self.b_ptr = torch.empty(
+ num_experts, device=w13_weight.device, dtype=torch.int64
+ )
+ self.out_ptr = torch.empty(
+ num_experts, device=w13_weight.device, dtype=torch.int64
+ )
+ self.a_scales_ptr = torch.empty(
+ num_experts, device=w13_weight.device, dtype=torch.int64
+ )
+ self.b_scales_ptr = torch.empty(
+ num_experts, device=w13_weight.device, dtype=torch.int64
+ )
+ self.expert_offsets = torch.empty(
+ num_experts + 1, device=w13_weight.device, dtype=torch.int32
+ )
+ self.problem_sizes1 = torch.empty(
+ num_experts, 3, device=w13_weight.device, dtype=torch.int32
+ )
+ self.problem_sizes2 = torch.empty(
+ num_experts, 3, device=w13_weight.device, dtype=torch.int32
+ )
+
  else:
  # Allocate 2 scales for w1 and w3 respectively.
  # They will be combined to a single scale after weight loading.
@@ -913,6 +972,37 @@ class Fp8MoEMethod:
  if ret is not None:
  return ret

+ if (
+ get_bool_env_var("CUTLASS_MOE")
+ and self.cutlass_fp8_supported
+ and self.block_quant
+ and is_sm100_supported()
+ ):
+ from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
+
+ return cutlass_fused_experts(
+ x,
+ layer.w13_weight.transpose(1, 2),
+ layer.w2_weight.transpose(1, 2),
+ layer.w13_weight_scale_inv.transpose(1, 2),
+ layer.w2_weight_scale_inv.transpose(1, 2),
+ topk_weights,
+ topk_ids,
+ self.ab_strides1,
+ self.c_strides1,
+ self.ab_strides2,
+ self.c_strides2,
+ self.workspace,
+ self.a_ptr,
+ self.b_ptr,
+ self.out_ptr,
+ self.a_scales_ptr,
+ self.b_scales_ptr,
+ self.expert_offsets,
+ self.problem_sizes1,
+ self.problem_sizes2,
+ use_fp8_blockscale=True,
+ )
  # Expert fusion with FP8 quantization
  return fused_experts(
  x,
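The new CUTLASS MoE path in Fp8MoEMethod is opt-in: it only runs when the CUTLASS_MOE environment variable is set truthy, cutlass_fp8_supported() reports support, the config uses block quantization, and is_sm100_supported() confirms an SM100-class GPU with CUDA >= 12.8. A condensed standalone version of that gate (the helper name and the plain os.environ check are mine; the individual conditions are the ones used above):

import os

import torch


def use_cutlass_moe_path(block_quant: bool, cutlass_fp8_ok: bool) -> bool:
    # Mirrors the condition guarding cutlass_fused_experts in Fp8MoEMethod.apply.
    env_on = os.environ.get("CUTLASS_MOE", "").lower() in ("1", "true")
    sm100 = torch.cuda.is_available() and (
        torch.cuda.get_device_capability()[0] == 10
        and (torch.version.cuda or "") >= "12.8"
    )
    return env_on and cutlass_fp8_ok and block_quant and sm100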
sglang/srt/layers/quantization/fp8_utils.py
@@ -80,6 +80,12 @@ def cutlass_fp8_supported():
  return False


+ def is_sm100_supported(device=None) -> bool:
+ return (torch.cuda.get_device_capability(device)[0] == 10) and (
+ torch.version.cuda >= "12.8"
+ )
+
+
  def normalize_e4m3fn_to_e4m3fnuz(
  weight: torch.Tensor,
  weight_scale: torch.Tensor,