sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -21,7 +21,7 @@ import concurrent.futures
  import logging
  import os
  from enum import IntEnum, auto
- from typing import Any, Dict, Iterable, Optional, Tuple, Union
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
  import torch
  import torch.nn.functional as F
@@ -57,6 +57,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
  is_mla_preprocess_enabled,
  )
  from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
+ from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton
  from sglang.srt.layers.communicator import (
  LayerCommunicator,
  LayerScatterModes,
@@ -130,13 +131,11 @@ from sglang.srt.utils import (
  get_int_env_var,
  is_cpu,
  is_cuda,
- is_flashinfer_available,
  is_gfx95_supported,
  is_hip,
  is_non_idle_and_non_empty,
  is_npu,
  is_nvidia_cublas_cu12_version_ge_12_9,
- is_sm100_supported,
  log_info_on_rank0,
  make_layers,
  use_intel_amx_backend,
@@ -196,8 +195,6 @@ elif _is_npu:
  else:
  pass
 
- _is_flashinfer_available = is_flashinfer_available()
- _is_sm100_supported = is_cuda() and is_sm100_supported()
  _is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
 
  logger = logging.getLogger(__name__)
@@ -227,6 +224,17 @@ def add_forward_absorb_core_attention_backend(backend_name):
  logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
 
 
+ def is_nsa_indexer_wk_and_weights_proj_fused(config, quant_config):
+ """
+ NSA Indexer wk and weights_proj can be fused in FP4 model because they are both in BF16
+ """
+ return (
+ is_deepseek_nsa(config)
+ and quant_config is not None
+ and quant_config.get_name() == "modelopt_fp4"
+ )
+
+
  class AttnForwardMethod(IntEnum):
  # Use multi-head attention
  MHA = auto()
@@ -241,6 +249,10 @@ class AttnForwardMethod(IntEnum):
  # This method can avoid OOM when prefix lengths are long.
  MHA_CHUNKED_KV = auto()
 
+ # Use multi-head attention, execute the MHA for prefix and extended kv in one shot
+ # when the sequence lengths are below the threshold.
+ MHA_ONE_SHOT = auto()
+
 
  # Use MLA but with fused RoPE
  MLA_FUSED_ROPE = auto()
@@ -278,6 +290,7 @@ def handle_attention_ascend(attn, forward_batch):
  forward_batch.forward_mode.is_extend()
  and not forward_batch.forward_mode.is_target_verify()
  and not forward_batch.forward_mode.is_draft_extend()
+ and not forward_batch.forward_mode.is_draft_extend_v2()
  ):
  if hasattr(attn, "indexer"):
  return AttnForwardMethod.NPU_MLA_SPARSE
@@ -306,6 +319,14 @@ def _is_extend_without_speculative(forward_batch):
  )
 
 
+ def _support_mha_one_shot(attn: DeepseekV2AttentionMLA, forward_batch, backend_name):
+ attn_supported = backend_name in ["fa3", "flashinfer", "flashmla"]
+ sum_seq_lens = (
+ sum(forward_batch.seq_lens_cpu) if forward_batch.seq_lens_cpu is not None else 0
+ )
+ return attn_supported and sum_seq_lens <= forward_batch.get_max_chunk_capacity()
+
+
  def _handle_attention_backend(
  attn: DeepseekV2AttentionMLA, forward_batch, backend_name
  ):
@@ -325,6 +346,8 @@ def _handle_attention_backend(
  or sum_extend_prefix_lens == 0
  )
  ):
+ if _support_mha_one_shot(attn, forward_batch, backend_name):
+ return AttnForwardMethod.MHA_ONE_SHOT
  return AttnForwardMethod.MHA_CHUNKED_KV
  else:
  return _dispatch_mla_subtype(attn, forward_batch)
@@ -335,7 +358,11 @@ def handle_attention_flashinfer(attn, forward_batch):
 
 
  def handle_attention_fa3(attn, forward_batch):
- return _handle_attention_backend(attn, forward_batch, "fa3")
+ # when deterministic inference is enabled, use MLA
+ if get_global_server_args().enable_deterministic_inference:
+ return _dispatch_mla_subtype(attn, forward_batch)
+ else:
+ return _handle_attention_backend(attn, forward_batch, "fa3")
 
 
  def handle_attention_flashmla(attn, forward_batch):
@@ -379,6 +406,10 @@ def handle_attention_nsa(attn, forward_batch):
 
 
  def handle_attention_triton(attn, forward_batch):
+ # when deterministic inference is enabled, use MLA
+ if get_global_server_args().enable_deterministic_inference:
+ return _dispatch_mla_subtype(attn, forward_batch)
+
  if (
  _is_extend_without_speculative(forward_batch)
  and sum(forward_batch.extend_prefix_lens_cpu) == 0
@@ -496,6 +527,9 @@ class MoEGate(nn.Module):
  True, # is_vnni
  )
 
+ if get_global_server_args().enable_deterministic_inference:
+ return F.linear(hidden_states, self.weight, None)
+
  # NOTE: For some unknown reason, router_gemm seems degrade accept length.
  if (
  _is_cuda
@@ -982,16 +1016,14 @@ class DeepseekV2MoE(nn.Module):
  )
 
  def op_experts(self, state):
- state.hidden_states_experts_output = self.experts.run_moe_core(
+ state.combine_input = self.experts.run_moe_core(
  dispatch_output=state.dispatch_output,
  )
 
  def op_combine_a(self, state):
  if self.ep_size > 1:
  self.experts.dispatcher.combine_a(
- hidden_states=state.pop("hidden_states_experts_output"),
- topk_ids=state.dispatch_output.topk_ids,
- topk_weights=state.dispatch_output.topk_weights,
+ combine_input=state.pop("combine_input"),
  tbo_subbatch_index=state.get("tbo_subbatch_index"),
  )
  state.pop("dispatch_output")
@@ -1043,6 +1075,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  layer_id: int = None,
  prefix: str = "",
  alt_stream: Optional[torch.cuda.Stream] = None,
+ skip_rope: bool = False,
  ) -> None:
  super().__init__()
  self.layer_id = layer_id
@@ -1062,6 +1095,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.scaling = self.qk_head_dim**-0.5
  self.rope_theta = rope_theta
  self.max_position_embeddings = max_position_embeddings
+ self.kv_cache_dtype = get_global_server_args().kv_cache_dtype
 
  # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
  if rope_scaling:
@@ -1122,6 +1156,9 @@ class DeepseekV2AttentionMLA(nn.Module):
  quant_config=quant_config,
  layer_id=layer_id,
  alt_stream=alt_stream,
+ fuse_wk_and_weights_proj=is_nsa_indexer_wk_and_weights_proj_fused(
+ config, quant_config
+ ),
  )
 
  self.kv_b_proj = ColumnParallelLinear(
@@ -1146,23 +1183,26 @@ class DeepseekV2AttentionMLA(nn.Module):
  )
  self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
 
- self.rotary_emb = get_rope_wrapper(
- qk_rope_head_dim,
- rotary_dim=qk_rope_head_dim,
- max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
- is_neox_style=False,
- device=get_global_server_args().device,
- )
+ if not skip_rope:
+ self.rotary_emb = get_rope_wrapper(
+ qk_rope_head_dim,
+ rotary_dim=qk_rope_head_dim,
+ max_position=max_position_embeddings,
+ base=rope_theta,
+ rope_scaling=rope_scaling,
+ is_neox_style=False,
+ device=get_global_server_args().device,
+ )
 
- if rope_scaling:
- mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
- scaling_factor = rope_scaling["factor"]
- mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
- self.scaling = self.scaling * mscale * mscale
+ if rope_scaling:
+ mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+ scaling_factor = rope_scaling["factor"]
+ mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+ self.scaling = self.scaling * mscale * mscale
+ else:
+ self.rotary_emb.forward = self.rotary_emb.forward_native
  else:
- self.rotary_emb.forward = self.rotary_emb.forward_native
+ self.rotary_emb = None
 
  self.attn_mqa = RadixAttention(
  self.num_local_heads,
@@ -1238,7 +1278,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
  and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
  and _is_cuda
- and _device_sm >= 90
+ and 90 <= _device_sm < 120
  )
 
  self.qkv_proj_with_rope_is_int8 = (
@@ -1359,6 +1399,10 @@ class DeepseekV2AttentionMLA(nn.Module):
  inner_state = self.forward_normal_chunked_kv_prepare(
  positions, hidden_states, forward_batch, zero_allocator
  )
+ elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+ inner_state = self.forward_normal_one_shot_prepare(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
  elif attn_forward_method == AttnForwardMethod.MLA:
  if not self.is_mla_preprocess_enabled:
  inner_state = self.forward_absorb_prepare(
@@ -1410,6 +1454,8 @@ class DeepseekV2AttentionMLA(nn.Module):
  return self.forward_normal_core(*inner_state)
  elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
  return self.forward_normal_chunked_kv_core(*inner_state)
+ elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+ return self.forward_normal_one_shot_core(*inner_state)
  elif attn_forward_method == AttnForwardMethod.MLA:
  return self.forward_absorb_core(*inner_state)
  elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
@@ -1444,41 +1490,25 @@ class DeepseekV2AttentionMLA(nn.Module):
  kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
  latent_cache = latent_cache.unsqueeze(1)
  kv_a = self.kv_a_layernorm(kv_a)
- kv = self.kv_b_proj(kv_a)[0]
- kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
- k_nope = kv[..., : self.qk_nope_head_dim]
- v = kv[..., self.qk_nope_head_dim :]
  k_pe = latent_cache[:, :, self.kv_lora_rank :]
- q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+ if self.rotary_emb is not None:
+ q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
  q[..., self.qk_nope_head_dim :] = q_pe
- k = torch.empty_like(q)
 
- # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+ self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
  if (
- _is_cuda
- and (self.num_local_heads == 128)
- and (self.qk_nope_head_dim == 128)
- and (self.qk_rope_head_dim == 64)
+ forward_batch.mha_one_shot
+ and sum(forward_batch.extend_prefix_lens_cpu) != 0
  ):
- concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
- else:
- k[..., : self.qk_nope_head_dim] = k_nope
- k[..., self.qk_nope_head_dim :] = k_pe
-
- if not _is_npu:
- latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
- latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
- # Save latent cache
- forward_batch.token_to_kv_pool.set_kv_buffer(
- self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
- )
- else:
- # To reduce a time-costing split operation
- forward_batch.token_to_kv_pool.set_kv_buffer(
- self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+ kv_a, k_pe = self._get_mla_kv_buffer(
+ forward_batch.fetch_mha_one_shot_kv_indices(), q.dtype, forward_batch
  )
+ kv = self.kv_b_proj(kv_a)[0]
+ kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+ k_nope = kv[..., : self.qk_nope_head_dim]
+ v = kv[..., self.qk_nope_head_dim :]
 
+ k = self._concat_and_cast_mha_k(k_nope, k_pe, forward_batch)
  return q, k, v, forward_batch
 
  def forward_normal_core(self, q, k, v, forward_batch):
@@ -1621,8 +1651,10 @@ class DeepseekV2AttentionMLA(nn.Module):
 
  q_nope_out = q_nope_out.transpose(0, 1)
 
- if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
- not _use_aiter or not _is_gfx95_supported or self.use_nsa
+ if (
+ self.rotary_emb is not None
+ and (not self._fuse_rope_for_trtllm_mla(forward_batch))
+ and (not _use_aiter or not _is_gfx95_supported or self.use_nsa)
  ):
  q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
@@ -2288,20 +2320,11 @@ class DeepseekV2AttentionMLA(nn.Module):
  for i in range(forward_batch.num_prefix_chunks):
  forward_batch.set_prefix_chunk_idx(i)
 
+ kv_indices = forward_batch.prefix_chunk_kv_indices[i]
  # Fetch latent cache from memory pool with precomputed chunked kv indices
- latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
- self.attn_mha.layer_id
- )
- latent_cache = (
- latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
- .contiguous()
- .to(q.dtype)
- )
-
- kv_a_normed, k_pe = latent_cache.split(
- [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+ kv_a_normed, k_pe = self._get_mla_kv_buffer(
+ kv_indices, q.dtype, forward_batch
  )
- kv_a_normed = kv_a_normed.squeeze(1).contiguous()
  kv = self.kv_b_proj(kv_a_normed)[0]
  kv = kv.view(
  -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
@@ -2376,6 +2399,107 @@ class DeepseekV2AttentionMLA(nn.Module):
  output, _ = self.o_proj(attn_output)
  return output
 
+ def forward_normal_one_shot_prepare(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ zero_allocator: BumpAllocator,
+ ):
+ forward_batch.mha_one_shot = True
+ return self.forward_normal_prepare(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
+
+ def forward_normal_one_shot_core(self, q, k, v, forward_batch):
+ has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+ # Only initialize the info once
+ if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+ forward_batch.num_prefix_chunks = 0
+ if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+ forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+ forward_batch.mha_return_lse = False
+ # Do mha for extended part without prefix
+ forward_batch.set_attn_attend_prefix_cache(False)
+ return self.forward_normal_core(q, k, v, forward_batch)
+
+ def _set_mla_kv_buffer(
+ self,
+ latent_cache: torch.Tensor,
+ kv_a: torch.Tensor,
+ k_pe: torch.Tensor,
+ forward_batch: ForwardBatch,
+ ):
+ if _is_cuda:
+ # Save latent cache
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+ self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+ )
+ elif _is_npu:
+ # To reduce a time-costing split operation
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+ )
+ else:
+ latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+ latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+ # Save latent cache
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+ )
+
+ def _get_mla_kv_buffer(
+ self,
+ kv_indices: torch.Tensor,
+ dst_dtype: torch.dtype,
+ forward_batch: ForwardBatch,
+ ):
+ if _is_cuda:
+ kv_a, k_pe = forward_batch.token_to_kv_pool.get_mla_kv_buffer(
+ self.attn_mha, kv_indices, dst_dtype
+ )
+ kv_a = kv_a.squeeze(1)
+ else:
+ latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+ self.attn_mha.layer_id
+ )
+ latent_cache = latent_cache_buf[kv_indices].contiguous().to(dst_dtype)
+
+ kv_a, k_pe = latent_cache.split(
+ [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+ )
+ kv_a = kv_a.squeeze(1).contiguous()
+ return kv_a, k_pe
+
+ def _concat_and_cast_mha_k(self, k_nope, k_pe, forward_batch):
+ # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+ k_shape = (k_nope.shape[0], self.num_local_heads, self.qk_head_dim)
+ if (
+ _is_cuda
+ and (self.num_local_heads == 128)
+ and (self.qk_nope_head_dim == 128)
+ and (self.qk_rope_head_dim == 64)
+ ):
+ k = k_nope.new_empty(*k_shape)
+ concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+ elif _is_cuda:
+ # fa3 mha support fp8 inputs
+ if (
+ self.current_attention_backend == "fa3"
+ and self.kv_cache_dtype != "auto"
+ ):
+ attn_dtype = forward_batch.token_to_kv_pool.dtype
+ else:
+ attn_dtype = k_nope.dtype
+ k = k_nope.new_empty(*k_shape, dtype=attn_dtype)
+ concat_and_cast_mha_k_triton(k, k_nope, k_pe)
+ else:
+ k = k_nope.new_empty(*k_shape)
+ k[..., : self.qk_nope_head_dim] = k_nope
+ k[..., self.qk_nope_head_dim :] = k_pe
+ return k
+
  @staticmethod
  def _get_q_b_proj_quant_config(quant_config):
  if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
@@ -2725,6 +2849,7 @@ class DeepseekV2Model(nn.Module):
  self.embed_tokens.embedding_dim,
  )
  )
+ self.layers_to_capture = []
 
  def get_input_embeddings(self) -> torch.Tensor:
  return self.embed_tokens
@@ -2781,9 +2906,11 @@ class DeepseekV2Model(nn.Module):
  normal_end_layer = self.first_k_dense_replace
  elif self.first_k_dense_replace < normal_start_layer:
  normal_end_layer = normal_start_layer = 0
-
+ aux_hidden_states = []
  for i in range(normal_start_layer, normal_end_layer):
  with get_global_expert_distribution_recorder().with_current_layer(i):
+ if i in self.layers_to_capture:
+ aux_hidden_states.append(hidden_states + residual)
  layer = self.layers[i]
  hidden_states, residual = layer(
  positions,
@@ -2821,7 +2948,9 @@ class DeepseekV2Model(nn.Module):
  hidden_states = self.norm(hidden_states)
  else:
  hidden_states, _ = self.norm(hidden_states, residual)
- return hidden_states
+ if len(aux_hidden_states) == 0:
+ return hidden_states
+ return hidden_states, aux_hidden_states
 
 
  class DeepseekV2ForCausalLM(nn.Module):
@@ -2875,6 +3004,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  if isinstance(layer.mlp, DeepseekV2MoE)
  }
  )
+ self.capture_aux_hidden_states = False
 
  @property
  def routed_experts_weights_of_layer(self):
@@ -2899,7 +3029,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
  elif get_moe_expert_parallel_world_size() > 1:
  disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
- elif self.quant_config.get_name() == "w4afp8":
+ elif self.quant_config and self.quant_config.get_name() == "w4afp8":
  disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
 
  if disable_reason is not None:
@@ -2928,10 +3058,13 @@ class DeepseekV2ForCausalLM(nn.Module):
  hidden_states = self.model(
  input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
  )
+ aux_hidden_states = None
+ if self.capture_aux_hidden_states:
+ hidden_states, aux_hidden_states = hidden_states
 
  if self.pp_group.is_last_rank:
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
  )
  else:
  return hidden_states
@@ -3190,8 +3323,8 @@ class DeepseekV2ForCausalLM(nn.Module):
  experts = layer.mlp.experts
  if isinstance(experts, DeepEPMoE):
  for w in [
- experts.w13_weight_fp8,
- experts.w2_weight_fp8,
+ (experts.w13_weight, experts.w13_weight_scale_inv),
+ (experts.w2_weight, experts.w2_weight_scale_inv),
  ]:
  requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
  else:
@@ -3239,10 +3372,26 @@ class DeepseekV2ForCausalLM(nn.Module):
  )
 
  experts = layer.mlp.experts
+ w13_weight_fp8 = (
+ experts.w13_weight,
+ (
+ experts.w13_weight_scale_inv
+ if hasattr(experts, "w13_weight_scale_inv")
+ else experts.w13_weight_scale
+ ),
+ )
+ w2_weight_fp8 = (
+ experts.w2_weight,
+ (
+ experts.w2_weight_scale_inv
+ if hasattr(experts, "w2_weight_scale_inv")
+ else experts.w2_weight_scale
+ ),
+ )
  if isinstance(experts, DeepEPMoE):
  for w in [
- experts.w13_weight_fp8,
- experts.w2_weight_fp8,
+ w13_weight_fp8,
+ w2_weight_fp8,
  ]:
  transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
 
@@ -3295,6 +3444,10 @@ class DeepseekV2ForCausalLM(nn.Module):
  self.config.q_lora_rank is not None
  )
  cached_a_proj = {} if fuse_qkv_a_proj else None
+ fuse_wk_and_weights_proj = is_nsa_indexer_wk_and_weights_proj_fused(
+ self.config, self.quant_config
+ )
+ cached_wk_and_weights_proj = {} if fuse_wk_and_weights_proj else None
 
  if is_nextn:
  nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
@@ -3466,6 +3619,53 @@ class DeepseekV2ForCausalLM(nn.Module):
  )
  cached_a_proj.pop(q_a_proj_name)
  cached_a_proj.pop(kv_a_proj_name)
+ elif fuse_wk_and_weights_proj and (
+ "wk" in name or "weights_proj" in name
+ ):
+ cached_wk_and_weights_proj[name] = loaded_weight
+ wk_name = (
+ name
+ if "wk" in name
+ else name.replace("weights_proj", "wk")
+ )
+ weights_proj_name = (
+ name
+ if "weights_proj" in name
+ else name.replace("wk", "weights_proj")
+ )
+
+ # When both wk and weights_proj has been cached, load the fused weight to parameter
+ if (
+ wk_name in cached_wk_and_weights_proj
+ and weights_proj_name in cached_wk_and_weights_proj
+ ):
+ wk_weight = cached_wk_and_weights_proj[wk_name]
+ weights_proj_weight = cached_wk_and_weights_proj[
+ weights_proj_name
+ ]
+ # todo dequantize wk for fp8
+ assert wk_weight.dtype == weights_proj_weight.dtype
+ fused_weight = torch.cat(
+ [wk_weight, weights_proj_weight], dim=0
+ )
+ param_name = (
+ name.replace("wk", "fused_wk_and_weights_proj")
+ if "wk" in name
+ else name.replace(
+ "weights_proj",
+ "fused_wk_and_weights_proj",
+ )
+ )
+ param = params_dict[param_name]
+
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ futures.append(
+ executor.submit(weight_loader, param, fused_weight)
+ )
+ cached_wk_and_weights_proj.pop(wk_name)
+ cached_wk_and_weights_proj.pop(weights_proj_name)
  else:
  if (
  "k_scale" in name or "v_scale" in name
@@ -3561,8 +3761,12 @@ class DeepseekV2ForCausalLM(nn.Module):
  del self.lm_head.weight
  self.model.embed_tokens.weight = embed
  self.lm_head.weight = head
- torch.cuda.empty_cache()
- torch.cuda.synchronize()
+ if not _is_npu:
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ else:
+ torch.npu.empty_cache()
+ torch.npu.synchronize()
 
  @classmethod
  def get_model_config_for_expert_location(cls, config):
@@ -3572,6 +3776,20 @@ class DeepseekV2ForCausalLM(nn.Module):
  num_groups=config.n_group,
  )
 
+ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+ if not self.pp_group.is_last_rank:
+ return
+
+ if layer_ids is None:
+ self.capture_aux_hidden_states = True
+ num_layers = self.config.num_hidden_layers
+ self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+ else:
+ self.capture_aux_hidden_states = True
+ # we plus 1 here because in sglang, for the ith layer, it takes the output
+ # of the (i-1)th layer as aux hidden state
+ self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
  AttentionBackendRegistry.register("ascend", handle_attention_ascend)
  AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)