sglang-0.5.4-py3-none-any.whl → sglang-0.5.4.post2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -21,7 +21,7 @@ import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -57,6 +57,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
     is_mla_preprocess_enabled,
 )
 from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
+from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -130,13 +131,11 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cpu,
     is_cuda,
-    is_flashinfer_available,
     is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
     is_nvidia_cublas_cu12_version_ge_12_9,
-    is_sm100_supported,
     log_info_on_rank0,
     make_layers,
     use_intel_amx_backend,
@@ -196,8 +195,6 @@ elif _is_npu:
 else:
     pass

-_is_flashinfer_available = is_flashinfer_available()
-_is_sm100_supported = is_cuda() and is_sm100_supported()
 _is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()

 logger = logging.getLogger(__name__)
@@ -227,6 +224,17 @@ def add_forward_absorb_core_attention_backend(backend_name):
     logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")


+def is_nsa_indexer_wk_and_weights_proj_fused(config, quant_config):
+    """
+    NSA Indexer wk and weights_proj can be fused in FP4 model because they are both in BF16
+    """
+    return (
+        is_deepseek_nsa(config)
+        and quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+    )
+
+
 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
     MHA = auto()
@@ -241,6 +249,10 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()

+    # Use multi-head attention, execute the MHA for prefix and extended kv in one shot
+    # when the sequence lengths are below the threshold.
+    MHA_ONE_SHOT = auto()
+
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()

@@ -278,6 +290,7 @@ def handle_attention_ascend(attn, forward_batch):
         forward_batch.forward_mode.is_extend()
         and not forward_batch.forward_mode.is_target_verify()
         and not forward_batch.forward_mode.is_draft_extend()
+        and not forward_batch.forward_mode.is_draft_extend_v2()
     ):
         if hasattr(attn, "indexer"):
             return AttnForwardMethod.NPU_MLA_SPARSE
@@ -306,6 +319,14 @@ def _is_extend_without_speculative(forward_batch):
     )


+def _support_mha_one_shot(attn: DeepseekV2AttentionMLA, forward_batch, backend_name):
+    attn_supported = backend_name in ["fa3", "flashinfer", "flashmla"]
+    sum_seq_lens = (
+        sum(forward_batch.seq_lens_cpu) if forward_batch.seq_lens_cpu is not None else 0
+    )
+    return attn_supported and sum_seq_lens <= forward_batch.get_max_chunk_capacity()
+
+
 def _handle_attention_backend(
     attn: DeepseekV2AttentionMLA, forward_batch, backend_name
 ):
@@ -325,6 +346,8 @@ def _handle_attention_backend(
             or sum_extend_prefix_lens == 0
         )
     ):
+        if _support_mha_one_shot(attn, forward_batch, backend_name):
+            return AttnForwardMethod.MHA_ONE_SHOT
         return AttnForwardMethod.MHA_CHUNKED_KV
     else:
         return _dispatch_mla_subtype(attn, forward_batch)
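The dispatch rule added in the hunk above is small enough to restate outside the diff. A minimal, standalone sketch for illustration only (the helper name `pick_mha_path`, its arguments, and the example capacity value are assumptions, not sglang APIs):

```python
# Illustrative restatement of the MHA_ONE_SHOT dispatch rule: use one-shot MHA when
# the backend supports it and the total KV length fits in one chunk, else chunked KV.
from typing import Optional, Sequence

ONE_SHOT_BACKENDS = {"fa3", "flashinfer", "flashmla"}


def pick_mha_path(
    backend_name: str,
    seq_lens_cpu: Optional[Sequence[int]],
    max_chunk_capacity: int,
) -> str:
    total_kv = sum(seq_lens_cpu) if seq_lens_cpu is not None else 0
    if backend_name in ONE_SHOT_BACKENDS and total_kv <= max_chunk_capacity:
        return "MHA_ONE_SHOT"
    return "MHA_CHUNKED_KV"


if __name__ == "__main__":
    # Two requests whose combined KV length fits the capacity -> one-shot MHA.
    print(pick_mha_path("fa3", [4096, 8192], max_chunk_capacity=128 * 1024))
    # The triton backend is not in the supported set -> fall back to chunked KV.
    print(pick_mha_path("triton", [4096, 8192], max_chunk_capacity=128 * 1024))
```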
@@ -335,7 +358,11 @@ def handle_attention_flashinfer(attn, forward_batch):


 def handle_attention_fa3(attn, forward_batch):
-    return _handle_attention_backend(attn, forward_batch, "fa3")
+    # when deterministic inference is enabled, use MLA
+    if get_global_server_args().enable_deterministic_inference:
+        return _dispatch_mla_subtype(attn, forward_batch)
+    else:
+        return _handle_attention_backend(attn, forward_batch, "fa3")


 def handle_attention_flashmla(attn, forward_batch):
@@ -379,6 +406,10 @@ def handle_attention_nsa(attn, forward_batch):


 def handle_attention_triton(attn, forward_batch):
+    # when deterministic inference is enabled, use MLA
+    if get_global_server_args().enable_deterministic_inference:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
     if (
         _is_extend_without_speculative(forward_batch)
         and sum(forward_batch.extend_prefix_lens_cpu) == 0
@@ -496,6 +527,9 @@ class MoEGate(nn.Module):
                 True,  # is_vnni
             )

+        if get_global_server_args().enable_deterministic_inference:
+            return F.linear(hidden_states, self.weight, None)
+
         # NOTE: For some unknown reason, router_gemm seems degrade accept length.
         if (
             _is_cuda
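Under deterministic inference the gate skips the fused router GEMM and scores experts with a plain linear projection. A minimal sketch of that fallback; the shapes and seed below are illustrative, not taken from the diff:

```python
# Deterministic-gate fallback: score experts with a plain F.linear instead of a
# fused router GEMM, so the same inputs always produce the same routing logits.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_states = torch.randn(4, 7168)   # [num_tokens, hidden_size]
gate_weight = torch.randn(256, 7168)   # [num_experts, hidden_size]

router_logits = F.linear(hidden_states, gate_weight, None)  # [num_tokens, num_experts]
print(router_logits.shape)
```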
@@ -982,16 +1016,14 @@ class DeepseekV2MoE(nn.Module):
         )

     def op_experts(self, state):
-        state.
+        state.combine_input = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )

     def op_combine_a(self, state):
         if self.ep_size > 1:
             self.experts.dispatcher.combine_a(
-
-                topk_ids=state.dispatch_output.topk_ids,
-                topk_weights=state.dispatch_output.topk_weights,
+                combine_input=state.pop("combine_input"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
         state.pop("dispatch_output")
@@ -1043,6 +1075,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         layer_id: int = None,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
+        skip_rope: bool = False,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -1062,6 +1095,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
+        self.kv_cache_dtype = get_global_server_args().kv_cache_dtype

         # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
         if rope_scaling:
@@ -1122,6 +1156,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                 quant_config=quant_config,
                 layer_id=layer_id,
                 alt_stream=alt_stream,
+                fuse_wk_and_weights_proj=is_nsa_indexer_wk_and_weights_proj_fused(
+                    config, quant_config
+                ),
             )

         self.kv_b_proj = ColumnParallelLinear(
@@ -1146,23 +1183,26 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

-        self.rotary_emb = get_rope_wrapper(
-            qk_rope_head_dim,
-            rotary_dim=qk_rope_head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=False,
-            device=get_global_server_args().device,
-        )
+        if not skip_rope:
+            self.rotary_emb = get_rope_wrapper(
+                qk_rope_head_dim,
+                rotary_dim=qk_rope_head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+                rope_scaling=rope_scaling,
+                is_neox_style=False,
+                device=get_global_server_args().device,
+            )

-        if rope_scaling:
-            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
-            scaling_factor = rope_scaling["factor"]
-            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
-            self.scaling = self.scaling * mscale * mscale
+            if rope_scaling:
+                mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+                scaling_factor = rope_scaling["factor"]
+                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+                self.scaling = self.scaling * mscale * mscale
+            else:
+                self.rotary_emb.forward = self.rotary_emb.forward_native
         else:
-            self.rotary_emb.forward = self.rotary_emb.forward_native
+            self.rotary_emb = None

         self.attn_mqa = RadixAttention(
             self.num_local_heads,
@@ -1238,7 +1278,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
             and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
             and _is_cuda
-            and _device_sm
+            and 90 <= _device_sm < 120
         )

         self.qkv_proj_with_rope_is_int8 = (
@@ -1359,6 +1399,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.forward_normal_chunked_kv_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            inner_state = self.forward_normal_one_shot_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         elif attn_forward_method == AttnForwardMethod.MLA:
             if not self.is_mla_preprocess_enabled:
                 inner_state = self.forward_absorb_prepare(
@@ -1410,6 +1454,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
             return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            return self.forward_normal_one_shot_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA:
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
@@ -1444,41 +1490,25 @@ class DeepseekV2AttentionMLA(nn.Module):
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a)
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope = kv[..., : self.qk_nope_head_dim]
-        v = kv[..., self.qk_nope_head_dim :]
         k_pe = latent_cache[:, :, self.kv_lora_rank :]
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        if self.rotary_emb is not None:
+            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)

-        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
         if (
-            _is_cuda
-            and (
-            and (self.qk_nope_head_dim == 128)
-            and (self.qk_rope_head_dim == 64)
+            forward_batch.mha_one_shot
+            and sum(forward_batch.extend_prefix_lens_cpu) != 0
         ):
-            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
-        else:
-            k[..., : self.qk_nope_head_dim] = k_nope
-            k[..., self.qk_nope_head_dim :] = k_pe
-
-        if not _is_npu:
-            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-            latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
-            # Save latent cache
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-            )
-        else:
-            # To reduce a time-costing split operation
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            kv_a, k_pe = self._get_mla_kv_buffer(
+                forward_batch.fetch_mha_one_shot_kv_indices(), q.dtype, forward_batch
             )
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]

+        k = self._concat_and_cast_mha_k(k_nope, k_pe, forward_batch)
         return q, k, v, forward_batch

     def forward_normal_core(self, q, k, v, forward_batch):
@@ -1621,8 +1651,10 @@ class DeepseekV2AttentionMLA(nn.Module):

         q_nope_out = q_nope_out.transpose(0, 1)

-        if
-
+        if (
+            self.rotary_emb is not None
+            and (not self._fuse_rope_for_trtllm_mla(forward_batch))
+            and (not _use_aiter or not _is_gfx95_supported or self.use_nsa)
         ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

@@ -2288,20 +2320,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         for i in range(forward_batch.num_prefix_chunks):
             forward_batch.set_prefix_chunk_idx(i)

+            kv_indices = forward_batch.prefix_chunk_kv_indices[i]
             # Fetch latent cache from memory pool with precomputed chunked kv indices
-            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
-                self.attn_mha.layer_id
-            )
-            latent_cache = (
-                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
-                .contiguous()
-                .to(q.dtype)
-            )
-
-            kv_a_normed, k_pe = latent_cache.split(
-                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            kv_a_normed, k_pe = self._get_mla_kv_buffer(
+                kv_indices, q.dtype, forward_batch
             )
-            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
             kv = self.kv_b_proj(kv_a_normed)[0]
             kv = kv.view(
                 -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
@@ -2376,6 +2399,107 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output

+    def forward_normal_one_shot_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        forward_batch.mha_one_shot = True
+        return self.forward_normal_prepare(
+            positions, hidden_states, forward_batch, zero_allocator
+        )
+
+    def forward_normal_one_shot_core(self, q, k, v, forward_batch):
+        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+        # Only initialize the info once
+        if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+            forward_batch.num_prefix_chunks = 0
+            if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+                forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+        forward_batch.mha_return_lse = False
+        # Do mha for extended part without prefix
+        forward_batch.set_attn_attend_prefix_cache(False)
+        return self.forward_normal_core(q, k, v, forward_batch)
+
+    def _set_mla_kv_buffer(
+        self,
+        latent_cache: torch.Tensor,
+        kv_a: torch.Tensor,
+        k_pe: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        elif _is_npu:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        else:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+
+    def _get_mla_kv_buffer(
+        self,
+        kv_indices: torch.Tensor,
+        dst_dtype: torch.dtype,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            kv_a, k_pe = forward_batch.token_to_kv_pool.get_mla_kv_buffer(
+                self.attn_mha, kv_indices, dst_dtype
+            )
+            kv_a = kv_a.squeeze(1)
+        else:
+            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+                self.attn_mha.layer_id
+            )
+            latent_cache = latent_cache_buf[kv_indices].contiguous().to(dst_dtype)
+
+            kv_a, k_pe = latent_cache.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_a = kv_a.squeeze(1).contiguous()
+        return kv_a, k_pe
+
+    def _concat_and_cast_mha_k(self, k_nope, k_pe, forward_batch):
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        k_shape = (k_nope.shape[0], self.num_local_heads, self.qk_head_dim)
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            k = k_nope.new_empty(*k_shape)
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        elif _is_cuda:
+            # fa3 mha support fp8 inputs
+            if (
+                self.current_attention_backend == "fa3"
+                and self.kv_cache_dtype != "auto"
+            ):
+                attn_dtype = forward_batch.token_to_kv_pool.dtype
+            else:
+                attn_dtype = k_nope.dtype
+            k = k_nope.new_empty(*k_shape, dtype=attn_dtype)
+            concat_and_cast_mha_k_triton(k, k_nope, k_pe)
+        else:
+            k = k_nope.new_empty(*k_shape)
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
+        return k
+
     @staticmethod
     def _get_q_b_proj_quant_config(quant_config):
         if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
@@ -2725,6 +2849,7 @@ class DeepseekV2Model(nn.Module):
                 self.embed_tokens.embedding_dim,
             )
         )
+        self.layers_to_capture = []

     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -2781,9 +2906,11 @@ class DeepseekV2Model(nn.Module):
             normal_end_layer = self.first_k_dense_replace
         elif self.first_k_dense_replace < normal_start_layer:
             normal_end_layer = normal_start_layer = 0
-
+        aux_hidden_states = []
         for i in range(normal_start_layer, normal_end_layer):
             with get_global_expert_distribution_recorder().with_current_layer(i):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(hidden_states + residual)
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions,
@@ -2821,7 +2948,9 @@ class DeepseekV2Model(nn.Module):
                 hidden_states = self.norm(hidden_states)
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+        return hidden_states, aux_hidden_states


 class DeepseekV2ForCausalLM(nn.Module):
@@ -2875,6 +3004,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 if isinstance(layer.mlp, DeepseekV2MoE)
             }
         )
+        self.capture_aux_hidden_states = False

     @property
     def routed_experts_weights_of_layer(self):
@@ -2899,7 +3029,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
-        elif self.quant_config.get_name() == "w4afp8":
+        elif self.quant_config and self.quant_config.get_name() == "w4afp8":
             disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."

         if disable_reason is not None:
@@ -2928,10 +3058,13 @@ class DeepseekV2ForCausalLM(nn.Module):
         hidden_states = self.model(
            input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states

         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -3190,8 +3323,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             experts = layer.mlp.experts
             if isinstance(experts, DeepEPMoE):
                 for w in [
-                    experts.
-                    experts.
+                    (experts.w13_weight, experts.w13_weight_scale_inv),
+                    (experts.w2_weight, experts.w2_weight_scale_inv),
                 ]:
                     requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
             else:
@@ -3239,10 +3372,26 @@ class DeepseekV2ForCausalLM(nn.Module):
             )

             experts = layer.mlp.experts
+            w13_weight_fp8 = (
+                experts.w13_weight,
+                (
+                    experts.w13_weight_scale_inv
+                    if hasattr(experts, "w13_weight_scale_inv")
+                    else experts.w13_weight_scale
+                ),
+            )
+            w2_weight_fp8 = (
+                experts.w2_weight,
+                (
+                    experts.w2_weight_scale_inv
+                    if hasattr(experts, "w2_weight_scale_inv")
+                    else experts.w2_weight_scale
+                ),
+            )
             if isinstance(experts, DeepEPMoE):
                 for w in [
-
-
+                    w13_weight_fp8,
+                    w2_weight_fp8,
                 ]:
                     transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])

@@ -3295,6 +3444,10 @@ class DeepseekV2ForCausalLM(nn.Module):
             self.config.q_lora_rank is not None
         )
         cached_a_proj = {} if fuse_qkv_a_proj else None
+        fuse_wk_and_weights_proj = is_nsa_indexer_wk_and_weights_proj_fused(
+            self.config, self.quant_config
+        )
+        cached_wk_and_weights_proj = {} if fuse_wk_and_weights_proj else None

         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
@@ -3466,6 +3619,53 @@ class DeepseekV2ForCausalLM(nn.Module):
                            )
                        cached_a_proj.pop(q_a_proj_name)
                        cached_a_proj.pop(kv_a_proj_name)
+                elif fuse_wk_and_weights_proj and (
+                    "wk" in name or "weights_proj" in name
+                ):
+                    cached_wk_and_weights_proj[name] = loaded_weight
+                    wk_name = (
+                        name
+                        if "wk" in name
+                        else name.replace("weights_proj", "wk")
+                    )
+                    weights_proj_name = (
+                        name
+                        if "weights_proj" in name
+                        else name.replace("wk", "weights_proj")
+                    )
+
+                    # When both wk and weights_proj has been cached, load the fused weight to parameter
+                    if (
+                        wk_name in cached_wk_and_weights_proj
+                        and weights_proj_name in cached_wk_and_weights_proj
+                    ):
+                        wk_weight = cached_wk_and_weights_proj[wk_name]
+                        weights_proj_weight = cached_wk_and_weights_proj[
+                            weights_proj_name
+                        ]
+                        # todo dequantize wk for fp8
+                        assert wk_weight.dtype == weights_proj_weight.dtype
+                        fused_weight = torch.cat(
+                            [wk_weight, weights_proj_weight], dim=0
+                        )
+                        param_name = (
+                            name.replace("wk", "fused_wk_and_weights_proj")
+                            if "wk" in name
+                            else name.replace(
+                                "weights_proj",
+                                "fused_wk_and_weights_proj",
+                            )
+                        )
+                        param = params_dict[param_name]
+
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        futures.append(
+                            executor.submit(weight_loader, param, fused_weight)
+                        )
+                        cached_wk_and_weights_proj.pop(wk_name)
+                        cached_wk_and_weights_proj.pop(weights_proj_name)
                else:
                    if (
                        "k_scale" in name or "v_scale" in name
@@ -3561,8 +3761,12 @@ class DeepseekV2ForCausalLM(nn.Module):
         del self.lm_head.weight
         self.model.embed_tokens.weight = embed
         self.lm_head.weight = head
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
+        if not _is_npu:
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+        else:
+            torch.npu.empty_cache()
+            torch.npu.synchronize()

     @classmethod
     def get_model_config_for_expert_location(cls, config):
@@ -3572,6 +3776,20 @@ class DeepseekV2ForCausalLM(nn.Module):
             num_groups=config.n_group,
         )

+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we plus 1 here because in sglang, for the ith layer, it takes the output
+            # of the (i-1)th layer as aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+

 AttentionBackendRegistry.register("ascend", handle_attention_ascend)
 AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
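The new `set_eagle3_layers_to_capture` hook picks which decoder layers feed auxiliary hidden states to an EAGLE3 draft model. A standalone restatement of that selection rule for illustration (the helper name `default_eagle3_capture_layers` and the example layer counts are assumptions, not part of the diff):

```python
# Restates the capture-layer choice shown in the last hunk: by default one early,
# one middle, and one late layer; explicit layer ids are shifted by one because
# layer i consumes the output of layer i - 1 as its aux hidden state.
from typing import List, Optional


def default_eagle3_capture_layers(
    num_hidden_layers: int, layer_ids: Optional[List[int]] = None
) -> List[int]:
    if layer_ids is None:
        return [2, num_hidden_layers // 2, num_hidden_layers - 3]
    return [layer_id + 1 for layer_id in layer_ids]


if __name__ == "__main__":
    print(default_eagle3_capture_layers(61))           # 61-layer model -> [2, 30, 58]
    print(default_eagle3_capture_layers(61, [1, 30]))  # explicit ids   -> [2, 31]
```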