sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -21,7 +21,7 @@ import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -131,13 +131,11 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cpu,
     is_cuda,
-    is_flashinfer_available,
     is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
     is_nvidia_cublas_cu12_version_ge_12_9,
-    is_sm100_supported,
     log_info_on_rank0,
     make_layers,
     use_intel_amx_backend,
@@ -197,8 +195,6 @@ elif _is_npu:
 else:
     pass
 
-_is_flashinfer_available = is_flashinfer_available()
-_is_sm100_supported = is_cuda() and is_sm100_supported()
 _is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
 
 logger = logging.getLogger(__name__)
@@ -228,6 +224,17 @@ def add_forward_absorb_core_attention_backend(backend_name):
     logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
 
 
+def is_nsa_indexer_wk_and_weights_proj_fused(config, quant_config):
+    """
+    NSA Indexer wk and weights_proj can be fused in FP4 model because they are both in BF16
+    """
+    return (
+        is_deepseek_nsa(config)
+        and quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+    )
+
+
 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
     MHA = auto()
@@ -283,6 +290,7 @@ def handle_attention_ascend(attn, forward_batch):
         forward_batch.forward_mode.is_extend()
         and not forward_batch.forward_mode.is_target_verify()
         and not forward_batch.forward_mode.is_draft_extend()
+        and not forward_batch.forward_mode.is_draft_extend_v2()
     ):
         if hasattr(attn, "indexer"):
             return AttnForwardMethod.NPU_MLA_SPARSE
@@ -519,6 +527,9 @@ class MoEGate(nn.Module):
                 True,  # is_vnni
             )
 
+        if get_global_server_args().enable_deterministic_inference:
+            return F.linear(hidden_states, self.weight, None)
+
         # NOTE: For some unknown reason, router_gemm seems degrade accept length.
         if (
             _is_cuda
@@ -1064,6 +1075,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         layer_id: int = None,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
+        skip_rope: bool = False,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -1144,6 +1156,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                 quant_config=quant_config,
                 layer_id=layer_id,
                 alt_stream=alt_stream,
+                fuse_wk_and_weights_proj=is_nsa_indexer_wk_and_weights_proj_fused(
+                    config, quant_config
+                ),
             )
 
         self.kv_b_proj = ColumnParallelLinear(
@@ -1168,23 +1183,26 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
 
-        self.rotary_emb = get_rope_wrapper(
-            qk_rope_head_dim,
-            rotary_dim=qk_rope_head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=False,
-            device=get_global_server_args().device,
-        )
+        if not skip_rope:
+            self.rotary_emb = get_rope_wrapper(
+                qk_rope_head_dim,
+                rotary_dim=qk_rope_head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+                rope_scaling=rope_scaling,
+                is_neox_style=False,
+                device=get_global_server_args().device,
+            )
 
-        if rope_scaling:
-            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
-            scaling_factor = rope_scaling["factor"]
-            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
-            self.scaling = self.scaling * mscale * mscale
+            if rope_scaling:
+                mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+                scaling_factor = rope_scaling["factor"]
+                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+                self.scaling = self.scaling * mscale * mscale
+            else:
+                self.rotary_emb.forward = self.rotary_emb.forward_native
         else:
-            self.rotary_emb.forward = self.rotary_emb.forward_native
+            self.rotary_emb = None
 
         self.attn_mqa = RadixAttention(
             self.num_local_heads,
@@ -1260,7 +1278,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
             and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
             and _is_cuda
-            and _device_sm >= 90
+            and 90 <= _device_sm < 120
         )
 
         self.qkv_proj_with_rope_is_int8 = (
@@ -1473,7 +1491,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a)
         k_pe = latent_cache[:, :, self.kv_lora_rank :]
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        if self.rotary_emb is not None:
+            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
 
         self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
@@ -1632,8 +1651,10 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         q_nope_out = q_nope_out.transpose(0, 1)
 
-        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
-            not _use_aiter or not _is_gfx95_supported or self.use_nsa
+        if (
+            self.rotary_emb is not None
+            and (not self._fuse_rope_for_trtllm_mla(forward_batch))
+            and (not _use_aiter or not _is_gfx95_supported or self.use_nsa)
         ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
@@ -2828,6 +2849,7 @@ class DeepseekV2Model(nn.Module):
                 self.embed_tokens.embedding_dim,
             )
         )
+        self.layers_to_capture = []
 
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -2884,9 +2906,11 @@ class DeepseekV2Model(nn.Module):
             normal_end_layer = self.first_k_dense_replace
         elif self.first_k_dense_replace < normal_start_layer:
             normal_end_layer = normal_start_layer = 0
-
+        aux_hidden_states = []
         for i in range(normal_start_layer, normal_end_layer):
             with get_global_expert_distribution_recorder().with_current_layer(i):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(hidden_states + residual)
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions,
@@ -2924,7 +2948,9 @@ class DeepseekV2Model(nn.Module):
                 hidden_states = self.norm(hidden_states)
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+        return hidden_states, aux_hidden_states
 
 
 class DeepseekV2ForCausalLM(nn.Module):
@@ -2978,6 +3004,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 if isinstance(layer.mlp, DeepseekV2MoE)
             }
         )
+        self.capture_aux_hidden_states = False
 
     @property
     def routed_experts_weights_of_layer(self):
@@ -3002,7 +3029,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
-        elif self.quant_config.get_name() == "w4afp8":
+        elif self.quant_config and self.quant_config.get_name() == "w4afp8":
             disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
 
         if disable_reason is not None:
@@ -3031,10 +3058,13 @@ class DeepseekV2ForCausalLM(nn.Module):
         hidden_states = self.model(
             input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
 
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -3293,8 +3323,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             experts = layer.mlp.experts
             if isinstance(experts, DeepEPMoE):
                 for w in [
-                    experts.w13_weight_fp8,
-                    experts.w2_weight_fp8,
+                    (experts.w13_weight, experts.w13_weight_scale_inv),
+                    (experts.w2_weight, experts.w2_weight_scale_inv),
                 ]:
                     requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
             else:
@@ -3342,10 +3372,26 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
 
             experts = layer.mlp.experts
+            w13_weight_fp8 = (
+                experts.w13_weight,
+                (
+                    experts.w13_weight_scale_inv
+                    if hasattr(experts, "w13_weight_scale_inv")
+                    else experts.w13_weight_scale
+                ),
+            )
+            w2_weight_fp8 = (
+                experts.w2_weight,
+                (
+                    experts.w2_weight_scale_inv
+                    if hasattr(experts, "w2_weight_scale_inv")
+                    else experts.w2_weight_scale
+                ),
+            )
             if isinstance(experts, DeepEPMoE):
                 for w in [
-                    experts.w13_weight_fp8,
-                    experts.w2_weight_fp8,
+                    w13_weight_fp8,
+                    w2_weight_fp8,
                 ]:
                     transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
 
@@ -3398,6 +3444,10 @@ class DeepseekV2ForCausalLM(nn.Module):
            self.config.q_lora_rank is not None
        )
        cached_a_proj = {} if fuse_qkv_a_proj else None
+        fuse_wk_and_weights_proj = is_nsa_indexer_wk_and_weights_proj_fused(
+            self.config, self.quant_config
+        )
+        cached_wk_and_weights_proj = {} if fuse_wk_and_weights_proj else None
 
        if is_nextn:
            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
@@ -3569,6 +3619,53 @@ class DeepseekV2ForCausalLM(nn.Module):
                                 )
                                 cached_a_proj.pop(q_a_proj_name)
                                 cached_a_proj.pop(kv_a_proj_name)
+                        elif fuse_wk_and_weights_proj and (
+                            "wk" in name or "weights_proj" in name
+                        ):
+                            cached_wk_and_weights_proj[name] = loaded_weight
+                            wk_name = (
+                                name
+                                if "wk" in name
+                                else name.replace("weights_proj", "wk")
+                            )
+                            weights_proj_name = (
+                                name
+                                if "weights_proj" in name
+                                else name.replace("wk", "weights_proj")
+                            )
+
+                            # When both wk and weights_proj has been cached, load the fused weight to parameter
+                            if (
+                                wk_name in cached_wk_and_weights_proj
+                                and weights_proj_name in cached_wk_and_weights_proj
+                            ):
+                                wk_weight = cached_wk_and_weights_proj[wk_name]
+                                weights_proj_weight = cached_wk_and_weights_proj[
+                                    weights_proj_name
+                                ]
+                                # todo dequantize wk for fp8
+                                assert wk_weight.dtype == weights_proj_weight.dtype
+                                fused_weight = torch.cat(
+                                    [wk_weight, weights_proj_weight], dim=0
+                                )
+                                param_name = (
+                                    name.replace("wk", "fused_wk_and_weights_proj")
+                                    if "wk" in name
+                                    else name.replace(
+                                        "weights_proj",
+                                        "fused_wk_and_weights_proj",
+                                    )
+                                )
+                                param = params_dict[param_name]
+
+                                weight_loader = getattr(
+                                    param, "weight_loader", default_weight_loader
+                                )
+                                futures.append(
+                                    executor.submit(weight_loader, param, fused_weight)
+                                )
+                                cached_wk_and_weights_proj.pop(wk_name)
+                                cached_wk_and_weights_proj.pop(weights_proj_name)
                         else:
                             if (
                                 "k_scale" in name or "v_scale" in name
@@ -3664,8 +3761,12 @@ class DeepseekV2ForCausalLM(nn.Module):
         del self.lm_head.weight
         self.model.embed_tokens.weight = embed
         self.lm_head.weight = head
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
+        if not _is_npu:
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+        else:
+            torch.npu.empty_cache()
+            torch.npu.synchronize()
 
     @classmethod
     def get_model_config_for_expert_location(cls, config):
@@ -3675,6 +3776,20 @@ class DeepseekV2ForCausalLM(nn.Module):
             num_groups=config.n_group,
         )
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we plus 1 here because in sglang, for the ith layer, it takes the output
+            # of the (i-1)th layer as aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 AttentionBackendRegistry.register("ascend", handle_attention_ascend)
 AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)