sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/utils.py

@@ -1,3 +1,4 @@
+import torch
 import triton
 import triton.language as tl
 
@@ -101,3 +102,80 @@ def create_flashmla_kv_indices_triton(
         data // PAGED_SIZE,
         mask=mask_out,
     )
+
+
+@triton.jit
+def concat_and_cast_mha_k_kernel(
+    k_ptr,
+    k_nope_ptr,
+    k_rope_ptr,
+    head_cnt: tl.constexpr,
+    k_stride0: tl.constexpr,
+    k_stride1: tl.constexpr,
+    nope_stride0: tl.constexpr,
+    nope_stride1: tl.constexpr,
+    rope_stride0: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    head_range = tl.arange(0, head_cnt)
+
+    k_head_ptr = k_ptr + pid_loc * k_stride0 + head_range[:, None] * k_stride1
+
+    nope_offs = tl.arange(0, nope_dim)
+
+    src_nope_ptr = (
+        k_nope_ptr
+        + pid_loc * nope_stride0
+        + head_range[:, None] * nope_stride1
+        + nope_offs[None, :]
+    )
+    dst_nope_ptr = k_head_ptr + nope_offs[None, :]
+
+    src_nope = tl.load(src_nope_ptr)
+    tl.store(dst_nope_ptr, src_nope)
+
+    rope_offs = tl.arange(0, rope_dim)
+    src_rope_ptr = k_rope_ptr + pid_loc * rope_stride0 + rope_offs[None, :]
+    dst_rope_ptr = k_head_ptr + nope_dim + rope_offs[None, :]
+    src_rope = tl.load(src_rope_ptr)
+    tl.store(dst_rope_ptr, src_rope)
+
+
+def concat_and_cast_mha_k_triton(
+    k: torch.Tensor,
+    k_nope: torch.Tensor,
+    k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    assert (
+        len(k.shape) == 3 and len(k_nope.shape) == 3 and len(k_rope.shape) == 3
+    ), f"shape should be 3d, but got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[0] == k_nope.shape[0] and k.shape[0] == k_rope.shape[0]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[1] == k_nope.shape[1] and 1 == k_rope.shape[1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[-1] == k_nope.shape[-1] + k_rope.shape[-1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+
+    nope_dim = k_nope.shape[-1]
+    rope_dim = k_rope.shape[-1]
+    grid = (k.shape[0],)
+
+    concat_and_cast_mha_k_kernel[grid](
+        k,
+        k_nope,
+        k_rope,
+        k.shape[1],
+        k.stride(0),
+        k.stride(1),
+        k_nope.stride(0),
+        k_nope.stride(1),
+        k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )
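
Reviewer note: the new concat_and_cast_mha_k_triton helper fills a preallocated MHA K buffer by writing the no-RoPE slice per head and broadcasting the single RoPE head across all heads, with the cast to the destination dtype happening implicitly on store. A minimal usage sketch under the asserted shape constraints (tensor sizes and dtypes are illustrative assumptions; head and dim counts are kept powers of two since tl.arange requires that; a CUDA device and Triton are assumed):

    import torch

    from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton

    # Hypothetical sizes: 8 tokens, 16 heads, 128 no-RoPE dims + 64 RoPE dims.
    tokens, heads, nope_dim, rope_dim = 8, 16, 128, 64
    k_nope = torch.randn(tokens, heads, nope_dim, device="cuda", dtype=torch.float32)
    k_rope = torch.randn(tokens, 1, rope_dim, device="cuda", dtype=torch.float32)
    # Destination buffer; sources are cast to its dtype when stored.
    k = torch.empty(
        tokens, heads, nope_dim + rope_dim, device="cuda", dtype=torch.bfloat16
    )
    concat_and_cast_mha_k_triton(k, k_nope, k_rope)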

sglang/srt/layers/communicator.py

@@ -15,7 +15,7 @@
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 
 import torch
 
@@ -216,6 +216,28 @@ class LayerCommunicator:
             get_global_server_args().speculative_algorithm
         )
 
+    def prepare_attn_and_capture_last_layer_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
+    ):
+        hidden_states, residual = self.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+        if captured_last_layer_outputs is not None:
+            gathered_last_layer_output = self._communicate_simple_fn(
+                hidden_states=residual,
+                forward_batch=forward_batch,
+                context=self._context,
+            )
+            if gathered_last_layer_output is residual:
+                # Clone to avoid modifying the original residual by Custom RMSNorm inplace operation
+                gathered_last_layer_output = residual.clone()
+            captured_last_layer_outputs.append(gathered_last_layer_output)
+        return hidden_states, residual
+
     def prepare_attn(
         self,
         hidden_states: torch.Tensor,
@@ -337,6 +359,7 @@ class LayerCommunicator:
         static_conditions_met = (
             (not self.is_last_layer)
            and (self._context.tp_size > 1)
+            and not is_dp_attention_enabled()
             and get_global_server_args().enable_flashinfer_allreduce_fusion
             and _is_flashinfer_available
         )

sglang/srt/layers/deep_gemm_wrapper/compile_utils.py

@@ -26,7 +26,7 @@ _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "fal
 
 # Force redirect deep_gemm cache_dir
 os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
-    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
+    "SGLANG_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )
 
 # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
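
Reviewer note: the environment variable that redirects DeepGEMM's JIT cache directory is renamed from SGL_DG_CACHE_DIR to SGLANG_DG_CACHE_DIR; since the redirect runs at module import time, the variable has to be set before this wrapper module is imported. A small sketch (the path is illustrative):

    import os

    # Must be set before importing sglang's deep_gemm wrapper,
    # because DG_JIT_CACHE_DIR is redirected at import time.
    os.environ["SGLANG_DG_CACHE_DIR"] = "/shared/cache/deep_gemm"  # illustrative path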

sglang/srt/layers/layernorm.py

@@ -20,7 +20,12 @@ import torch
 import torch.nn as nn
 from packaging.version import Version
 
+from sglang.srt.batch_invariant_ops import (
+    is_batch_invariant_mode_enabled,
+    rms_norm_batch_invariant,
+)
 from sglang.srt.custom_op import CustomOp
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -73,9 +78,16 @@ class RMSNorm(CustomOp):
         hidden_size: int,
         eps: float = 1e-6,
         var_hidden_size: Optional[int] = None,
+        cast_x_before_out_mul: bool = False,
+        fp32_residual: bool = False,
+        weight_dtype: Optional = None,
+        override_orig_dtype: Optional = None,
     ) -> None:
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.cast_x_before_out_mul = cast_x_before_out_mul
+        self.fp32_residual = fp32_residual
+        self.override_orig_dtype = override_orig_dtype
+        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))
         self.variance_epsilon = eps
         self.hidden_size = hidden_size
         self.variance_size_override = (
@@ -83,8 +95,6 @@ class RMSNorm(CustomOp):
         )
         if _use_aiter:
             self._forward_method = self.forward_aiter
-        if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
-            self._forward_method = self.forward_native
 
     def forward_cuda(
         self,
@@ -93,6 +103,17 @@
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)
+        if is_batch_invariant_mode_enabled():
+            if (
+                residual is not None
+                or get_global_server_args().rl_on_policy_target == "fsdp"
+            ):
+                return self.forward_native(x, residual)
+            return rms_norm_batch_invariant(
+                x,
+                self.weight.data,
+                self.variance_epsilon,
+            )
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
@@ -165,11 +186,14 @@
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        orig_dtype = x.dtype
+        orig_dtype = self.override_orig_dtype or x.dtype
         x = x.to(torch.float32)
         if residual is not None:
             x = x + residual.to(torch.float32)
-            residual = x.to(orig_dtype)
+            if self.fp32_residual:
+                residual = x.clone()
+            else:
+                residual = x.to(orig_dtype)
 
         hidden_size = x.shape[-1]
         if hidden_size != self.hidden_size:
@@ -191,7 +215,12 @@
 
         variance = x_var.pow(2).mean(dim=-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = (x * self.weight).to(orig_dtype)
+
+        if self.cast_x_before_out_mul:
+            x = self.weight * x.to(orig_dtype)
+        else:
+            x = (x * self.weight).to(orig_dtype)
+
         if residual is None:
             return x
         else:
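
Reviewer note: cast_x_before_out_mul only changes where the downcast happens. The default path multiplies the fp32-normalized activations by the weight and then casts; the new flag casts the normalized activations to the original dtype first and multiplies in that dtype, so the two orderings are generally not bit-identical in low precision. A small self-contained sketch (illustrative tensors, not package code) showing the difference in bfloat16:

    import torch

    x = torch.randn(4, 1024, dtype=torch.float32)
    w = torch.randn(1024, dtype=torch.bfloat16)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x_norm = x * torch.rsqrt(variance + 1e-6)

    default = (x_norm * w).to(torch.bfloat16)   # multiply in fp32, then cast
    cast_first = w * x_norm.to(torch.bfloat16)  # cast first, then multiply in bf16
    print((default != cast_first).any().item())  # typically True: rounding differs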

sglang/srt/layers/logits_processor.py

@@ -38,7 +38,6 @@ from sglang.srt.layers.dp_attention import (
     get_dp_device,
     get_dp_dtype,
     get_dp_hidden_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import (
@@ -47,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
+from sglang.srt.utils import is_npu, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
 
@@ -135,10 +134,7 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
-            (
-                forward_batch.forward_mode.is_extend()
-                or forward_batch.forward_mode.is_split_prefill()
-            )
+            forward_batch.forward_mode.is_extend()
             and forward_batch.return_logprob
             and not forward_batch.forward_mode.is_target_verify()
         ):
@@ -252,10 +248,6 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None
 
-        self.debug_tensor_dump_output_folder = (
-            get_global_server_args().debug_tensor_dump_output_folder
-        )
-
     def compute_logprobs_for_multi_item_scoring(
         self,
         input_ids,
@@ -389,8 +381,8 @@
             input_logprob_indices = None
         elif (
             logits_metadata.forward_mode.is_extend()
-            or logits_metadata.forward_mode.is_split_prefill()
-        ) and not logits_metadata.extend_return_logprob:
+            and not logits_metadata.extend_return_logprob
+        ):
             # Prefill without input logprobs.
             if logits_metadata.padded_static_len < 0:
                 last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
@@ -463,14 +455,6 @@
             logits[sample_indices] if sample_indices is not None else logits
         )
 
-        if self.debug_tensor_dump_output_folder:
-            assert (
-                not self.do_tensor_parallel_all_gather
-                or get_local_attention_dp_size() == 1
-            ), "dp attention + sharded lm_head doesn't support full logits"
-            full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
-            dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
-
         hidden_states_to_store: Optional[torch.Tensor] = None
         if logits_metadata.capture_hidden_mode.need_capture():
             if logits_metadata.capture_hidden_mode.is_full():
@@ -593,6 +577,11 @@
                 None, # bias
                 True, # is_vnni
             )
+        elif get_global_server_args().rl_on_policy_target == "fsdp":
+            # Due to tie-weight, we may not be able to change lm_head's weight dtype
+            logits = torch.matmul(
+                hidden_states.bfloat16(), lm_head.weight.T.bfloat16()
+            )
         else:
             logits = torch.matmul(
                 hidden_states.to(lm_head.weight.dtype), lm_head.weight.T

sglang/srt/layers/moe/cutlass_w4a8_moe.py

@@ -11,12 +11,14 @@ from sgl_kernel import (
 )
 
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    deepep_ll_get_cutlass_w4a8_moe_mm_data,
     deepep_permute_triton_kernel,
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
     post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
     run_moe_ep_preproess,
+    silu_and_mul_masked_post_per_tensor_quant_fwd,
 )
 
 
@@ -396,3 +398,139 @@ def cutlass_w4a8_moe_deepep_normal(
     )
 
     return output
+
+
+def cutlass_w4a8_moe_deepep_ll(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    masked_m: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, K]
+    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
+        Shape: [num_experts, N * 2, K // 2]
+        (the weights are passed transposed and int4-packed)
+    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
+        Shape: [num_experts, K, N // 2]
+        (the weights are passed transposed and int4-packed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts, K // 512, N * 8]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts, N // 512, K * 4]
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
+    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
+    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
+    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [1, K]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [1, N]
+    - apply_router_weight_on_input (bool): When true, the topk weights are
+        applied directly on the inputs. This is only applicable when topk is 1.
+
+    Returns:
+    - torch.Tensor: The fp8 output tensor after applying the MoE layer.
+    """
+    assert w1_q.dtype == torch.int8
+    assert w2_q.dtype == torch.int8
+    assert a.shape[2] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
+    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+
+    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
+    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
+    num_experts = w1_q.size(0)
+    m = a.size(1)
+    k = w1_q.size(2) * 2  # w1_q is transposed and packed
+    n = w2_q.size(2) * 2  # w2_q is transposed and packed
+    topk = topk_ids_.size(1)
+
+    device = a.device
+
+    problem_sizes1, problem_sizes2 = deepep_ll_get_cutlass_w4a8_moe_mm_data(
+        masked_m,
+        problem_sizes1,
+        problem_sizes2,
+        num_experts,
+        n,
+        k,
+    )
+
+    gateup_input = torch.empty(a.shape, dtype=torch.float8_e4m3fn, device=device)
+    sgl_per_tensor_quant_fp8(a, gateup_input, a1_scale.float(), True)
+    c1 = torch.empty((num_experts, m, n * 2), device=device, dtype=torch.bfloat16)
+    c2 = torch.empty((num_experts, m, k), device=device, dtype=torch.bfloat16)
+
+    cutlass_w4a8_moe_mm(
+        c1,
+        gateup_input,
+        w1_q,
+        a1_scale.float(),
+        w1_scale,
+        expert_offsets[:-1],
+        problem_sizes1,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        s_strides13,
+        128,
+        topk,
+    )
+
+    intermediate_q = torch.empty(
+        (num_experts, m, n), device=a.device, dtype=torch.float8_e4m3fn
+    )
+    silu_and_mul_masked_post_per_tensor_quant_fwd(
+        c1, intermediate_q, masked_m, a2_scale
+    )
+    cutlass_w4a8_moe_mm(
+        c2,
+        intermediate_q,
+        w2_q,
+        a2_scale.float(),
+        w2_scale,
+        expert_offsets[:-1],
+        problem_sizes2,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides2,
+        128,
+        topk,
+    )
+
+    return c2

sglang/srt/layers/moe/ep_moe/kernels.py

@@ -1014,3 +1014,197 @@ def zero_experts_compute_triton(
     )
 
     return output
+
+
+@triton.jit
+def compute_problem_sizes_w4a8_kernel(
+    masked_m_ptr,
+    problem_sizes1_ptr,
+    problem_sizes2_ptr,
+    n,
+    k,
+    num_experts,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = pid < num_experts
+    final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)
+
+    ps1_idx_0 = pid * 3
+    ps1_idx_1 = ps1_idx_0 + 1
+    ps1_idx_2 = ps1_idx_0 + 2
+
+    ps2_idx_0 = pid * 3
+    ps2_idx_1 = ps2_idx_0 + 1
+    ps2_idx_2 = ps2_idx_0 + 2
+
+    ps1_mask_0 = ps1_idx_0 < num_experts * 3
+    ps1_mask_1 = ps1_idx_1 < num_experts * 3
+    ps1_mask_2 = ps1_idx_2 < num_experts * 3
+    ps2_mask_0 = ps2_idx_0 < num_experts * 3
+    ps2_mask_1 = ps2_idx_1 < num_experts * 3
+    ps2_mask_2 = ps2_idx_2 < num_experts * 3
+
+    tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)
+    tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)
+    tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)
+
+    tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)
+    tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)
+    tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)
+
+
+def compute_problem_sizes_w4a8(
+    masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
+):
+    BLOCK_SIZE = 256
+    grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
+    compute_problem_sizes_w4a8_kernel[grid](
+        masked_m,
+        problem_sizes1,
+        problem_sizes2,
+        n,
+        k,
+        num_experts,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return problem_sizes1, problem_sizes2
+
+
+def deepep_ll_get_cutlass_w4a8_moe_mm_data(
+    masked_m,
+    problem_sizes1,
+    problem_sizes2,
+    num_experts,
+    n,
+    k,
+):
+    problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
+        masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
+    )
+    return (
+        problem_sizes1.to(torch.int32),
+        problem_sizes2.to(torch.int32),
+    )
+
+
+@triton.jit
+def _silu_and_mul_post_per_tensor_quant_kernel(
+    input_ptr,
+    stride_input_expert,
+    stride_input_token,
+    stride_input_dim,
+    output_ptr,
+    stride_output_expert,
+    stride_output_token,
+    stride_output_dim,
+    scale_ptr,
+    masked_m_ptr,
+    inner_dim,
+    fp8_max,
+    fp8_min,
+    BLOCK_N: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    """
+    Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.
+
+    Shape:
+        input: [E, T_padded, 2*D] -> gate: [:,:,D], up: [:,:,D]
+        output: [E, T_padded, D], dtype=float8_e4m3fn
+    """
+    expert_id = tl.program_id(2)
+    block_id_token = tl.program_id(1)
+    block_id_dim = tl.program_id(0)
+
+    num_token_blocks = tl.num_programs(1)
+
+    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
+
+    scale = 1.0 / tl.load(scale_ptr).to(tl.float32)
+
+    stride_input_expert = tl.cast(stride_input_expert, tl.int32)
+    stride_output_expert = tl.cast(stride_output_expert, tl.int32)
+    stride_input_token = tl.cast(stride_input_token, tl.int32)
+    stride_output_token = tl.cast(stride_output_token, tl.int32)
+
+    offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask_d = offset_d < inner_dim
+
+    # base pointers for current expert and dim block
+    input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
+    output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d
+
+    for token_idx in tl.range(
+        block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
+    ):
+        gate_ptr = input_base_offs + token_idx * stride_input_token
+        up_ptr = gate_ptr + inner_dim
+        gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
+        up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)
+
+        # SiLU: x * sigmoid(x)
+        gate = gate / (1 + tl.exp(-gate))
+        gate = gate.to(input_ptr.dtype.element_ty)
+        gate_up = up * gate
+
+        scaled = gate_up * scale
+        output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)
+        out_ptr = output_base_offs + token_idx * stride_output_token
+        tl.store(out_ptr, output_q, mask=mask_d)
+
+
+def silu_and_mul_masked_post_per_tensor_quant_fwd(
+    input: torch.Tensor,
+    output: torch.Tensor,
+    masked_m: torch.Tensor,
+    scale: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Fused SiLU + Mul + Per-Tensor Quantization to FP8.
+
+    Args:
+        input: [expert_num, token_num_padded, 2 * inner_dim]
+        output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
+        masked_m: [expert_num], actual token count for each expert
+        scale: [1] or [expert_num], quantization scale (per-tensor or per-expert)
+
+    Returns:
+        output tensor
+    """
+    assert input.is_contiguous()
+    assert output.is_contiguous()
+    assert output.dtype == torch.float8_e4m3fn
+    assert input.ndim == 3
+    assert input.shape[0] == masked_m.shape[0]
+    assert input.shape[-1] % 2 == 0
+    assert scale.numel() == 1 or scale.shape[0] == input.shape[0]
+
+    expert_num = input.shape[0]
+    # 3584
+    inner_dim = input.shape[-1] // 2
+
+    BLOCK_N = 256
+    BLOCK_M = 64 if expert_num < 4 else 32
+    NUM_STAGES = 3
+    hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)
+
+    grid = (hidden_dim_split_block_num, BLOCK_M, expert_num)
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = finfo.max
+    fp8_min = -fp8_max
+
+    _silu_and_mul_post_per_tensor_quant_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        scale,
+        masked_m,
+        inner_dim,
+        fp8_max,
+        fp8_min,
+        BLOCK_N=BLOCK_N,
+        NUM_STAGE=NUM_STAGES,
+    )
+    return output
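
Reviewer note: a minimal invocation sketch for the new fused activation-plus-quantization helper, following the shapes listed in its docstring. Expert, token, and dim sizes, the masked_m values, and the scale are illustrative assumptions; a CUDA device, Triton, and a recent PyTorch with float8_e4m3fn support are assumed.

    import torch

    from sglang.srt.layers.moe.ep_moe.kernels import (
        silu_and_mul_masked_post_per_tensor_quant_fwd,
    )

    experts, tokens_padded, inner_dim = 4, 64, 256
    # Concatenated [gate, up] activations per expert.
    gate_up = torch.randn(
        experts, tokens_padded, 2 * inner_dim, device="cuda", dtype=torch.bfloat16
    )
    out = torch.empty(
        experts, tokens_padded, inner_dim, device="cuda", dtype=torch.float8_e4m3fn
    )
    masked_m = torch.tensor([64, 32, 0, 7], device="cuda", dtype=torch.int32)  # valid tokens per expert
    scale = torch.tensor([0.05], device="cuda", dtype=torch.float32)  # per-tensor scale
    silu_and_mul_masked_post_per_tensor_quant_fwd(gate_up, out, masked_m, scale)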