sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff covers the contents of publicly released package versions available from the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
 - sglang/bench_serving.py +56 -12
 - sglang/launch_server.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
 - sglang/srt/compilation/backend.py +1 -1
 - sglang/srt/configs/model_config.py +5 -5
 - sglang/srt/distributed/parallel_state.py +0 -7
 - sglang/srt/entrypoints/engine.py +18 -15
 - sglang/srt/entrypoints/grpc_server.py +0 -1
 - sglang/srt/entrypoints/http_server.py +75 -94
 - sglang/srt/environ.py +16 -2
 - sglang/srt/eplb/expert_distribution.py +30 -0
 - sglang/srt/function_call/function_call_parser.py +2 -0
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/layers/activation.py +6 -0
 - sglang/srt/layers/attention/flashattention_backend.py +12 -2
 - sglang/srt/layers/attention/flashinfer_backend.py +10 -1
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
 - sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
 - sglang/srt/layers/attention/utils.py +78 -0
 - sglang/srt/layers/communicator.py +1 -0
 - sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
 - sglang/srt/layers/layernorm.py +19 -4
 - sglang/srt/layers/logits_processor.py +5 -0
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
 - sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
 - sglang/srt/layers/moe/ep_moe/layer.py +79 -272
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
 - sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
 - sglang/srt/layers/moe/topk.py +4 -4
 - sglang/srt/layers/moe/utils.py +3 -4
 - sglang/srt/layers/quantization/__init__.py +3 -5
 - sglang/srt/layers/quantization/awq.py +0 -3
 - sglang/srt/layers/quantization/base_config.py +7 -0
 - sglang/srt/layers/quantization/fp8.py +68 -63
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/mxfp4.py +30 -38
 - sglang/srt/layers/quantization/unquant.py +23 -45
 - sglang/srt/layers/quantization/w4afp8.py +38 -2
 - sglang/srt/layers/radix_attention.py +5 -2
 - sglang/srt/layers/rotary_embedding.py +13 -1
 - sglang/srt/layers/sampler.py +12 -1
 - sglang/srt/managers/io_struct.py +3 -0
 - sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
 - sglang/srt/managers/scheduler.py +21 -15
 - sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
 - sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
 - sglang/srt/managers/tokenizer_manager.py +11 -19
 - sglang/srt/mem_cache/hicache_storage.py +7 -1
 - sglang/srt/mem_cache/memory_pool.py +82 -0
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/model_executor/forward_batch_info.py +44 -3
 - sglang/srt/model_executor/model_runner.py +1 -149
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
 - sglang/srt/models/deepseek_v2.py +147 -44
 - sglang/srt/models/glm4_moe.py +322 -354
 - sglang/srt/models/glm4_moe_nextn.py +4 -14
 - sglang/srt/models/glm4v_moe.py +29 -196
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/qwen2.py +22 -1
 - sglang/srt/models/qwen3.py +34 -4
 - sglang/srt/models/qwen3_moe.py +2 -4
 - sglang/srt/multimodal/processors/base_processor.py +1 -0
 - sglang/srt/multimodal/processors/glm4v.py +1 -1
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
 - sglang/srt/parser/reasoning_parser.py +28 -1
 - sglang/srt/server_args.py +365 -186
 - sglang/srt/single_batch_overlap.py +2 -7
 - sglang/srt/utils/common.py +87 -42
 - sglang/srt/utils/hf_transformers_utils.py +7 -3
 - sglang/test/test_deterministic.py +235 -12
 - sglang/test/test_deterministic_utils.py +2 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
 - sglang/srt/models/vila.py +0 -306
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
 
    
sglang/srt/models/deepseek_v2.py  CHANGED

@@ -57,6 +57,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
     is_mla_preprocess_enabled,
 )
 from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
+from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -241,6 +242,10 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()

+    # Use multi-head attention, execute the MHA for prefix and extended kv in one shot
+    # when the sequence lengths are below the threshold.
+    MHA_ONE_SHOT = auto()
+
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()

@@ -306,6 +311,14 @@ def _is_extend_without_speculative(forward_batch):
     )


+def _support_mha_one_shot(attn: DeepseekV2AttentionMLA, forward_batch, backend_name):
+    attn_supported = backend_name in ["fa3", "flashinfer", "flashmla"]
+    sum_seq_lens = (
+        sum(forward_batch.seq_lens_cpu) if forward_batch.seq_lens_cpu is not None else 0
+    )
+    return attn_supported and sum_seq_lens <= forward_batch.get_max_chunk_capacity()
+
+
 def _handle_attention_backend(
     attn: DeepseekV2AttentionMLA, forward_batch, backend_name
 ):
@@ -325,6 +338,8 @@ def _handle_attention_backend(
             or sum_extend_prefix_lens == 0
         )
     ):
+        if _support_mha_one_shot(attn, forward_batch, backend_name):
+            return AttnForwardMethod.MHA_ONE_SHOT
         return AttnForwardMethod.MHA_CHUNKED_KV
     else:
         return _dispatch_mla_subtype(attn, forward_batch)
@@ -335,7 +350,11 @@ def handle_attention_flashinfer(attn, forward_batch):


 def handle_attention_fa3(attn, forward_batch):
-    return _handle_attention_backend(attn, forward_batch, "fa3")
+    # when deterministic inference is enabled, use MLA
+    if get_global_server_args().enable_deterministic_inference:
+        return _dispatch_mla_subtype(attn, forward_batch)
+    else:
+        return _handle_attention_backend(attn, forward_batch, "fa3")


 def handle_attention_flashmla(attn, forward_batch):
@@ -379,6 +398,10 @@ def handle_attention_nsa(attn, forward_batch):


 def handle_attention_triton(attn, forward_batch):
+    # when deterministic inference is enabled, use MLA
+    if get_global_server_args().enable_deterministic_inference:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
     if (
         _is_extend_without_speculative(forward_batch)
         and sum(forward_batch.extend_prefix_lens_cpu) == 0
@@ -982,16 +1005,14 @@ class DeepseekV2MoE(nn.Module):
                 )

     def op_experts(self, state):
-        state.
+        state.combine_input = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )

     def op_combine_a(self, state):
         if self.ep_size > 1:
             self.experts.dispatcher.combine_a(
-
-                topk_ids=state.dispatch_output.topk_ids,
-                topk_weights=state.dispatch_output.topk_weights,
+                combine_input=state.pop("combine_input"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")
@@ -1062,6 +1083,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
+        self.kv_cache_dtype = get_global_server_args().kv_cache_dtype

         # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
         if rope_scaling:
@@ -1359,6 +1381,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.forward_normal_chunked_kv_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            inner_state = self.forward_normal_one_shot_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         elif attn_forward_method == AttnForwardMethod.MLA:
             if not self.is_mla_preprocess_enabled:
                 inner_state = self.forward_absorb_prepare(
@@ -1410,6 +1436,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
             return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            return self.forward_normal_one_shot_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA:
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
@@ -1444,41 +1472,24 @@ class DeepseekV2AttentionMLA(nn.Module):
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a)
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope = kv[..., : self.qk_nope_head_dim]
-        v = kv[..., self.qk_nope_head_dim :]
         k_pe = latent_cache[:, :, self.kv_lora_rank :]
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)

-        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
         if (
-            _is_cuda
-            and (self.num_local_heads == 128)
-            and (self.qk_nope_head_dim == 128)
-            and (self.qk_rope_head_dim == 64)
+            forward_batch.mha_one_shot
+            and sum(forward_batch.extend_prefix_lens_cpu) != 0
         ):
-            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
-        else:
-            k[..., : self.qk_nope_head_dim] = k_nope
-            k[..., self.qk_nope_head_dim :] = k_pe
-
-        if not _is_npu:
-            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-            latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
-            # Save latent cache
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-            )
-        else:
-            # To reduce a time-costing split operation
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            kv_a, k_pe = self._get_mla_kv_buffer(
+                forward_batch.fetch_mha_one_shot_kv_indices(), q.dtype, forward_batch
             )
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]

+        k = self._concat_and_cast_mha_k(k_nope, k_pe, forward_batch)
         return q, k, v, forward_batch

     def forward_normal_core(self, q, k, v, forward_batch):
@@ -2288,20 +2299,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         for i in range(forward_batch.num_prefix_chunks):
             forward_batch.set_prefix_chunk_idx(i)

+            kv_indices = forward_batch.prefix_chunk_kv_indices[i]
             # Fetch latent cache from memory pool with precomputed chunked kv indices
-            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
-                self.attn_mha.layer_id
-            )
-            latent_cache = (
-                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
-                .contiguous()
-                .to(q.dtype)
+            kv_a_normed, k_pe = self._get_mla_kv_buffer(
+                kv_indices, q.dtype, forward_batch
             )
-
-            kv_a_normed, k_pe = latent_cache.split(
-                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
-            )
-            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
             kv = self.kv_b_proj(kv_a_normed)[0]
             kv = kv.view(
                 -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
@@ -2376,6 +2378,107 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output

+    def forward_normal_one_shot_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        forward_batch.mha_one_shot = True
+        return self.forward_normal_prepare(
+            positions, hidden_states, forward_batch, zero_allocator
+        )
+
+    def forward_normal_one_shot_core(self, q, k, v, forward_batch):
+        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+        # Only initialize the info once
+        if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+            forward_batch.num_prefix_chunks = 0
+            if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+                forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+        forward_batch.mha_return_lse = False
+        # Do mha for extended part without prefix
+        forward_batch.set_attn_attend_prefix_cache(False)
+        return self.forward_normal_core(q, k, v, forward_batch)
+
+    def _set_mla_kv_buffer(
+        self,
+        latent_cache: torch.Tensor,
+        kv_a: torch.Tensor,
+        k_pe: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        elif _is_npu:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        else:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+
+    def _get_mla_kv_buffer(
+        self,
+        kv_indices: torch.Tensor,
+        dst_dtype: torch.dtype,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            kv_a, k_pe = forward_batch.token_to_kv_pool.get_mla_kv_buffer(
+                self.attn_mha, kv_indices, dst_dtype
+            )
+            kv_a = kv_a.squeeze(1)
+        else:
+            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+                self.attn_mha.layer_id
+            )
+            latent_cache = latent_cache_buf[kv_indices].contiguous().to(dst_dtype)
+
+            kv_a, k_pe = latent_cache.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_a = kv_a.squeeze(1).contiguous()
+        return kv_a, k_pe
+
+    def _concat_and_cast_mha_k(self, k_nope, k_pe, forward_batch):
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        k_shape = (k_nope.shape[0], self.num_local_heads, self.qk_head_dim)
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            k = k_nope.new_empty(*k_shape)
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        elif _is_cuda:
+            # fa3 mha support fp8 inputs
+            if (
+                self.current_attention_backend == "fa3"
+                and self.kv_cache_dtype != "auto"
+            ):
+                attn_dtype = forward_batch.token_to_kv_pool.dtype
+            else:
+                attn_dtype = k_nope.dtype
+            k = k_nope.new_empty(*k_shape, dtype=attn_dtype)
+            concat_and_cast_mha_k_triton(k, k_nope, k_pe)
+        else:
+            k = k_nope.new_empty(*k_shape)
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
+        return k
+
     @staticmethod
     def _get_q_b_proj_quant_config(quant_config):
         if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
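
The dispatch rule behind the new MHA_ONE_SHOT path can be read off the hunks above. The sketch below is a minimal, standalone illustration only: the helper name `choose_mha_prefill_method` and its plain arguments are ours, whereas the real code works on a `ForwardBatch` via `_support_mha_one_shot` and `_handle_attention_backend`.

    from enum import IntEnum, auto

    class AttnForwardMethod(IntEnum):
        MHA_CHUNKED_KV = auto()   # prefix KV processed chunk by chunk (existing path)
        MHA_ONE_SHOT = auto()     # prefix + extended KV attended in a single MHA call (new path)

    def choose_mha_prefill_method(backend_name, seq_lens_cpu, max_chunk_capacity):
        # Same rule as _support_mha_one_shot: the backend must be one of
        # fa3 / flashinfer / flashmla, and the whole batch must fit in one chunk.
        attn_supported = backend_name in ["fa3", "flashinfer", "flashmla"]
        sum_seq_lens = sum(seq_lens_cpu) if seq_lens_cpu is not None else 0
        if attn_supported and sum_seq_lens <= max_chunk_capacity:
            return AttnForwardMethod.MHA_ONE_SHOT
        return AttnForwardMethod.MHA_CHUNKED_KV

    # A 3k-token prefill batch on fa3 with a 128k chunk capacity takes the one-shot path;
    # the same backend at 200k total tokens falls back to chunked prefix processing.
    assert choose_mha_prefill_method("fa3", [1024, 2048], 128 * 1024) == AttnForwardMethod.MHA_ONE_SHOT
    assert choose_mha_prefill_method("fa3", [200_000], 128 * 1024) == AttnForwardMethod.MHA_CHUNKED_KV

When the capacity check fails, the dispatcher keeps returning the existing chunked-KV method, so the change is additive rather than a replacement of the chunked prefix path.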