sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
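
sglang/srt/mem_cache/memory_pool.py (detailed hunks below; the file is inferred from the classes and imports the hunks touch — ReqToTokenPool, KVCache, MHATokenToKVPool, MLATokenToKVPool — matching the +651 -80 entry above)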
@@ -13,6 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from sglang.srt.configs.mamba_utils import Mamba2CacheParams
+from sglang.srt.layers.attention.nsa import index_buf_accessor
+from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 """
@@ -27,7 +34,7 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -38,6 +45,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
@@ -47,6 +57,10 @@ if _is_npu:
     import torch_npu
 
 
+def get_tensor_size_bytes(t: torch.Tensor):
+    return np.prod(t.shape) * t.dtype.itemsize
+
+
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
@@ -97,6 +111,225 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
+class MambaPool:
+    @dataclass(frozen=True, kw_only=True)
+    class State:
+        conv: torch.Tensor
+        temporal: torch.Tensor
+
+        def at_layer_idx(self, layer: int):
+            return type(self)(**{k: v[layer] for k, v in vars(self).items()})
+
+        def mem_usage_bytes(self):
+            return sum(get_tensor_size_bytes(t) for t in vars(self).values())
+
+    @dataclass(frozen=True, kw_only=True)
+    class SpeculativeState(State):
+        intermediate_ssm: torch.Tensor
+        intermediate_conv_window: torch.Tensor
+
+    def __init__(
+        self,
+        *,
+        size: int,
+        cache_params: "Mamba2CacheParams",
+        device: str,
+        speculative_num_draft_tokens: Optional[int] = None,
+    ):
+        conv_state_shape = cache_params.shape.conv
+        temporal_state_shape = cache_params.shape.temporal
+        conv_dtype = cache_params.dtype.conv
+        ssm_dtype = cache_params.dtype.temporal
+        num_mamba_layers = len(cache_params.layers)
+
+        # assume conv_state = (dim, state_len)
+        assert conv_state_shape[0] > conv_state_shape[1]
+        conv_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + conv_state_shape,
+            dtype=conv_dtype,
+            device=device,
+        )
+        temporal_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + temporal_state_shape,
+            dtype=ssm_dtype,
+            device=device,
+        )
+        if speculative_num_draft_tokens is not None:
+            # Cache intermediate SSM states per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
+            intermediate_ssm_state_cache = torch.zeros(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    temporal_state_shape[0],
+                    temporal_state_shape[1],
+                    temporal_state_shape[2],
+                ),
+                dtype=ssm_dtype,
+                device="cuda",
+            )
+            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
+            intermediate_conv_window_cache = torch.zeros(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    conv_state_shape[0],
+                    conv_state_shape[1],
+                ),
+                dtype=conv_dtype,
+                device="cuda",
+            )
+            self.mamba_cache = self.SpeculativeState(
+                conv=conv_state,
+                temporal=temporal_state,
+                intermediate_ssm=intermediate_ssm_state_cache,
+                intermediate_conv_window=intermediate_conv_window_cache,
+            )
+            logger.info(
+                f"Mamba Cache is allocated. "
+                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+                f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
+                f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
+            )
+        else:
+            self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
+            logger.info(
+                f"Mamba Cache is allocated. "
+                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+            )
+        self.size = size
+        self.free_slots = list(range(size))
+        self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
+
+    def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
+        assert isinstance(self.mamba_cache, self.SpeculativeState)
+        return self.mamba_cache
+
+    def mamba2_layer_cache(self, layer_id: int):
+        return self.mamba_cache.at_layer_idx(layer_id)
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    def alloc(self, need_size: int) -> Optional[List[int]]:
+        if need_size > len(self.free_slots):
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        return select_index
+
+    def free(self, free_index: Union[int, List[int]]):
+        if isinstance(free_index, (int,)):
+            self.free_slots.append(free_index)
+        else:
+            self.free_slots.extend(free_index)
+        self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
+            :, free_index
+        ] = 0
+
+    def clear(self):
+        self.free_slots = list(range(self.size))
+
+
+class HybridReqToTokenPool(ReqToTokenPool):
+    """A memory pool that maps a request to its token locations."""
+
+    def __init__(
+        self,
+        *,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+        cache_params: "Mamba2CacheParams",
+        speculative_num_draft_tokens: int = None,
+    ):
+        super().__init__(
+            size=size,
+            max_context_len=max_context_len,
+            device=device,
+            enable_memory_saver=enable_memory_saver,
+        )
+
+        self.mamba_pool = MambaPool(
+            size=size,
+            cache_params=cache_params,
+            device=device,
+            speculative_num_draft_tokens=speculative_num_draft_tokens,
+        )
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
+
+        self.device = device
+        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
+            size, dtype=torch.int32, device=self.device
+        )
+
+        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
+        self.mamba_index_to_rid_mapping: Dict[int, str] = {}
+
+    # For chunk prefill req, we do not need to allocate mamba cache,
+    # We could use allocated mamba cache instead.
+    def alloc(
+        self, need_size: int, reqs: Optional[List["Req"]] = None
+    ) -> Optional[List[int]]:
+        select_index = super().alloc(need_size)
+        if select_index == None:
+            return None
+
+        mamba_index = []
+        for req in reqs:
+            rid = req.rid
+            if rid in self.rid_to_mamba_index_mapping:
+                mid = self.rid_to_mamba_index_mapping[rid]
+            elif (mid := self.mamba_pool.alloc(1)) is not None:
+                mid = mid[0]
+                self.rid_to_mamba_index_mapping[rid] = mid
+                self.mamba_index_to_rid_mapping[mid] = rid
+            mamba_index.append(mid)
+        assert len(select_index) == len(
+            mamba_index
+        ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size."
+        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
+            mamba_index, dtype=torch.int32, device=self.device
+        )
+        return select_index
+
+    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
+        return self.req_index_to_mamba_index_mapping[req_indices]
+
+    def mamba2_layer_cache(self, layer_id: int):
+        assert layer_id in self.mamba_map
+        return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
+
+    def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
+        return self.mamba_pool.get_speculative_mamba2_params_all_layers()
+
+    # For chunk prefill, we can not free mamba cache, we need use it in the future
+    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
+        super().free(free_index)
+        if free_mamba_cache:
+            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
+            mamba_index_list = mamba_index.tolist()
+            if isinstance(mamba_index_list, int):
+                mamba_index_list = [mamba_index_list]
+            self.mamba_pool.free(mamba_index_list)
+            for mid in mamba_index_list:
+                rid = self.mamba_index_to_rid_mapping[mid]
+                self.mamba_index_to_rid_mapping.pop(mid)
+                self.rid_to_mamba_index_mapping.pop(rid)
+
+    def clear(self):
+        super().clear()
+        self.mamba_pool.clear()
+
+
 class KVCache(abc.ABC):
     @abc.abstractmethod
     def __init__(
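
For readers skimming the hunk above: MambaPool keeps one conv/temporal state slot per request (plus one spare, hence size + 1) and hands out slot indices from a plain free list. The standalone sketch below restates that free-list pattern with a hypothetical FreeSlotPool class; it is illustrative only and not part of the sglang diff.

from typing import List, Optional, Union

class FreeSlotPool:
    """Minimal free-list allocator mirroring MambaPool's slot bookkeeping."""

    def __init__(self, size: int):
        self.size = size
        self.free_slots = list(range(size))

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        # Return None when the request cannot be satisfied, like MambaPool.alloc.
        if need_size > len(self.free_slots):
            return None
        selected, self.free_slots = self.free_slots[:need_size], self.free_slots[need_size:]
        return selected

    def free(self, free_index: Union[int, List[int]]) -> None:
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

pool = FreeSlotPool(size=4)
slots = pool.alloc(2)   # e.g. [0, 1]
pool.free(slots)        # both slots become reusable again
assert pool.available_size() == 4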
@@ -130,6 +363,29 @@ class KVCache(abc.ABC):
         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
 
+        # default state for optional layer-wise transfer control
+        self.layer_transfer_counter = None
+
+    def _finalize_allocation_log(self, num_tokens: int):
+        """Common logging and mem_usage computation for KV cache allocation.
+        Supports both tuple (K, V) size returns and single KV size returns.
+        """
+        kv_size_bytes = self.get_kv_size_bytes()
+        if isinstance(kv_size_bytes, tuple):
+            k_size, v_size = kv_size_bytes
+            k_size_GB = k_size / GB
+            v_size_GB = v_size / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
+            )
+            self.mem_usage = k_size_GB + v_size_GB
+        else:
+            kv_size_GB = kv_size_bytes / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
+            )
+            self.mem_usage = kv_size_GB
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
@@ -152,7 +408,7 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
-    def register_layer_transfer_counter(self, layer_transfer_counter):
+    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
         self.layer_transfer_counter = layer_transfer_counter
 
     def get_cpu_copy(self, indices):
@@ -176,6 +432,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_kv_cache_copy: bool = False,
     ):
         super().__init__(
             size,
|
 
         self._create_buffers()
 
-        self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
 
-
-
-
+        if enable_kv_cache_copy:
+            self._init_kv_copy_and_warmup()
+        else:
+            self._kv_copy_config = None
+
+        self._finalize_allocation_log(size)
+
+    def _init_kv_copy_and_warmup(self):
+        # Heuristics for KV copy tiling
+        _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
+        _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
+        _KV_COPY_TILE_SIZE_LARGE = 512
+        _KV_COPY_TILE_SIZE_MEDIUM = 256
+        _KV_COPY_TILE_SIZE_SMALL = 128
+        _KV_COPY_NUM_WARPS_LARGE_TILE = 8
+        _KV_COPY_NUM_WARPS_SMALL_TILE = 4
+
+        stride_bytes = int(self.data_strides[0].item())
+        if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE
+        elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM
+        else:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL
+
+        self._kv_copy_config = {
+            "bytes_per_tile": bytes_per_tile,
+            "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
+            "num_warps": (
+                _KV_COPY_NUM_WARPS_SMALL_TILE
+                if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM
+                else _KV_COPY_NUM_WARPS_LARGE_TILE
+            ),
+        }
+
+        dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
+        grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
+            self.data_ptrs,
+            self.data_strides,
+            dummy_loc,
+            dummy_loc,
+            1,
+            1,
+            BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
+            num_warps=self._kv_copy_config["num_warps"],
+            num_stages=2,
         )
-        self.mem_usage = (k_size + v_size) / GB
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -269,10 +569,10 @@ class MHATokenToKVPool(KVCache):
         assert hasattr(self, "v_buffer")
         k_size_bytes = 0
         for k_cache in self.k_buffer:
-            k_size_bytes +=
+            k_size_bytes += get_tensor_size_bytes(k_cache)
         v_size_bytes = 0
         for v_cache in self.v_buffer:
-            v_size_bytes +=
+            v_size_bytes += get_tensor_size_bytes(v_cache)
         return k_size_bytes, v_size_bytes
 
     # for disagg
@@ -352,7 +652,6 @@ class MHATokenToKVPool(KVCache):
         # same applies to get_value_buffer and get_kv_buffer
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
-
         return self._get_key_buffer(layer_id)
 
     def _get_value_buffer(self, layer_id: int):
@@ -410,60 +709,156 @@ class MHATokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
 
     def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
-
+        N = tgt_loc.numel()
+        if N == 0:
+            return
+
+        assert (
+            self._kv_copy_config is not None
+        ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"
+
+        cfg = self._kv_copy_config
+        N_upper = next_power_of_2(N)
+        grid = (self.data_ptrs.numel(), cfg["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
             self.data_ptrs,
             self.data_strides,
             tgt_loc,
             src_loc,
-
-
+            N,
+            N_upper,
+            BYTES_PER_TILE=cfg["bytes_per_tile"],
+            num_warps=cfg["num_warps"],
+            num_stages=2,
         )
 
 
-class
-    """KV cache with separate pools for full and
+class HybridLinearKVPool(KVCache):
+    """KV cache with separate pools for full and linear attention layers."""
 
     def __init__(
         self,
         size: int,
-        size_swa: int,
         dtype: torch.dtype,
+        page_size: int,
         head_num: int,
         head_dim: int,
-        swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
     ):
         self.size = size
-        self.size_swa = size_swa
         self.dtype = dtype
         self.device = device
-        self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
-        self.page_size =
+        self.page_size = page_size
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
-
-
-
+        if _is_npu:
+            TokenToKVPoolClass = AscendTokenToKVPool
+        else:
+            TokenToKVPoolClass = MHATokenToKVPool
+        self.full_kv_pool = TokenToKVPoolClass(
+            size=size,
             page_size=self.page_size,
             dtype=dtype,
             head_num=head_num,
             head_dim=head_dim,
-            layer_num=self.
+            layer_num=self.full_layer_nums,
             device=device,
             enable_memory_saver=False,
         )
-        self.
+        self.full_attention_layer_id_mapping = {
+            id: i for i, id in enumerate(full_attention_layer_ids)
+        }
+        k_size, v_size = self.get_kv_size_bytes()
+        self.mem_usage = (k_size + v_size) / GB
+
+    def get_kv_size_bytes(self):
+        return self.full_kv_pool.get_kv_size_bytes()
+
+    def get_contiguous_buf_infos(self):
+        return self.full_kv_pool.get_contiguous_buf_infos()
+
+    def _transfer_full_attention_id(self, layer_id: int):
+        if layer_id not in self.full_attention_layer_id_mapping:
+            raise ValueError(
+                f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
+            )
+        return self.full_attention_layer_id_mapping[layer_id]
+
+    def get_key_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_key_buffer(layer_id)
+
+    def get_value_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_value_buffer(layer_id)
+
+    def get_kv_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_kv_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+    ):
+        layer_id = self._transfer_full_attention_id(layer.layer_id)
+        self.full_kv_pool.set_kv_buffer(
+            None,
+            loc,
+            cache_k,
+            cache_v,
+            k_scale,
+            v_scale,
+            layer_id_override=layer_id,
+        )
+
+    def get_v_head_dim(self):
+        return self.full_kv_pool.get_value_buffer(0).shape[-1]
+
+
+class SWAKVPool(KVCache):
+    """KV cache with separate pools for full and SWA attention layers."""
+
+    def __init__(
+        self,
+        size: int,
+        size_swa: int,
+        dtype: torch.dtype,
+        swa_attention_layer_ids: List[int],
+        full_attention_layer_ids: List[int],
+        enable_kvcache_transpose: bool,
+        token_to_kv_pool_class: KVCache = MHATokenToKVPool,
+        **kwargs,
+    ):
+        self.size = size
+        self.size_swa = size_swa
+        self.dtype = dtype
+        self.swa_layer_nums = len(swa_attention_layer_ids)
+        self.full_layer_nums = len(full_attention_layer_ids)
+        kwargs["page_size"] = 1
+        kwargs["enable_memory_saver"] = False
+        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
+        assert not enable_kvcache_transpose
+
+        self.swa_kv_pool = token_to_kv_pool_class(
+            size=size_swa,
+            dtype=dtype,
+            layer_num=self.swa_layer_nums,
+            **kwargs,
+        )
+        self.full_kv_pool = token_to_kv_pool_class(
             size=size,
-            page_size=self.page_size,
             dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
             layer_num=self.full_layer_nums,
-
-            enable_memory_saver=False,
+            **kwargs,
         )
         self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
         for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
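
HybridLinearKVPool above only materializes KV buffers for the full-attention layers and remaps global layer ids to pool-local indices via _transfer_full_attention_id. Below is a minimal sketch of that remapping idea with hypothetical helper names; it is not taken from the diff.

from typing import Dict, List

def build_layer_index(full_attention_layer_ids: List[int]) -> Dict[int, int]:
    # Global layer id -> position inside the (smaller) full-attention KV pool.
    return {layer_id: i for i, layer_id in enumerate(full_attention_layer_ids)}

def to_pool_index(mapping: Dict[int, int], layer_id: int) -> int:
    if layer_id not in mapping:
        raise ValueError(f"{layer_id=} is not a full attention layer: {sorted(mapping)}")
    return mapping[layer_id]

# In an 8-layer hybrid model where only layers 3 and 7 use full attention,
# the pool holds 2 layers and global layer 7 maps to pool index 1.
mapping = build_layer_index([3, 7])
assert to_pool_index(mapping, 7) == 1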
@@ -613,8 +1008,12 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         cache_v: torch.Tensor,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        layer_id_override: Optional[int] = None,
     ):
-
+        if layer_id_override is not None:
+            layer_id = layer_id_override
+        else:
+            layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
                 cache_k.div_(k_scale)
@@ -719,6 +1118,8 @@ class MLATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        use_nsa: bool = False,
+        override_kv_cache_dim: Optional[int] = None,
     ):
         super().__init__(
             size,
@@ -733,6 +1134,14 @@ class MLATokenToKVPool(KVCache):
 
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
+        self.use_nsa = use_nsa
+        self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
+        # TODO do not hardcode
+        self.kv_cache_dim = (
+            656
+            if self.use_nsa and self.nsa_kv_cache_store_fp8
+            else (kv_lora_rank + qk_rope_head_dim)
+        )
 
         # for disagg with nvlink
         self.enable_custom_mem_pool = get_bool_env_var(
@@ -756,7 +1165,7 @@ class MLATokenToKVPool(KVCache):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.zeros(
-                (size + page_size, 1,
+                (size + page_size, 1, self.kv_cache_dim),
                 dtype=self.store_dtype,
                 device=device,
             )
@@ -768,19 +1177,15 @@ class MLATokenToKVPool(KVCache):
                 dtype=torch.uint64,
                 device=self.device,
             )
-
-
-
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        if not use_nsa:
+            # NSA will allocate indexer KV cache later and then log the total size
+            self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
         kv_size_bytes = 0
         for kv_cache in self.kv_buffer:
-            kv_size_bytes +=
+            kv_size_bytes += get_tensor_size_bytes(kv_cache)
         return kv_size_bytes
 
     # for disagg
@@ -825,6 +1230,7 @@ class MLATokenToKVPool(KVCache):
         cache_v: torch.Tensor,
     ):
         layer_id = layer.layer_id
+        assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
@@ -842,16 +1248,28 @@ class MLATokenToKVPool(KVCache):
         cache_k_rope: torch.Tensor,
     ):
         layer_id = layer.layer_id
-        if cache_k_nope.dtype != self.dtype:
-            cache_k_nope = cache_k_nope.to(self.dtype)
-            cache_k_rope = cache_k_rope.to(self.dtype)
-        if self.store_dtype != self.dtype:
-            cache_k_nope = cache_k_nope.view(self.store_dtype)
-            cache_k_rope = cache_k_rope.view(self.store_dtype)
 
-
-
-
+        if self.use_nsa and self.nsa_kv_cache_store_fp8:
+            # original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here
+            # TODO no need to cat
+            cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
+            cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
+            cache_k = cache_k.view(self.store_dtype)
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
+        else:
+            if cache_k_nope.dtype != self.dtype:
+                cache_k_nope = cache_k_nope.to(self.dtype)
+                cache_k_rope = cache_k_rope.to(self.dtype)
+            if self.store_dtype != self.dtype:
+                cache_k_nope = cache_k_nope.view(self.store_dtype)
+                cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+            set_mla_kv_buffer_triton(
+                self.kv_buffer[layer_id - self.start_layer],
+                loc,
+                cache_k_nope,
+                cache_k_rope,
+            )
 
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
@@ -881,6 +1299,111 @@ class MLATokenToKVPool(KVCache):
         torch.cuda.synchronize()
 
 
+class NSATokenToKVPool(MLATokenToKVPool):
+    quant_block_size = 128
+    index_k_with_scale_buffer_dtype = torch.uint8
+
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        kv_lora_rank: int,
+        dtype: torch.dtype,
+        qk_rope_head_dim: int,
+        layer_num: int,
+        device: str,
+        index_head_dim: int,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            kv_lora_rank,
+            qk_rope_head_dim,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+            use_nsa=True,
+        )
+        # self.index_k_dtype = torch.float8_e4m3fn
+        # self.index_k_scale_dtype = torch.float32
+        self.index_head_dim = index_head_dim
+        # num head == 1 and head dim == 128 for index_k in NSA
+        assert index_head_dim == 128
+
+        assert self.page_size == 64
+        self.index_k_with_scale_buffer = [
+            torch.zeros(
+                # Layout:
+                # ref: test_attention.py :: kv_cache_cast_to_fp8
+                # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
+                # data: for page i,
+                # * buf[i, :page_size * head_dim] for fp8 data
+                # * buf[i, page_size * head_dim:].view(float32) for scale
+                (
+                    (size + page_size + 1) // self.page_size,
+                    self.page_size
+                    * (index_head_dim + index_head_dim // self.quant_block_size * 4),
+                ),
+                dtype=self.index_k_with_scale_buffer_dtype,
+                device=device,
+            )
+            for _ in range(layer_num)
+        ]
+        self._finalize_allocation_log(size)
+
+    def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self.index_k_with_scale_buffer[layer_id - self.start_layer]
+
+    def get_index_k_continuous(
+        self,
+        layer_id: int,
+        seq_len: int,
+        page_indices: torch.Tensor,
+    ):
+        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+        return index_buf_accessor.GetK.execute(
+            self, buf, seq_len=seq_len, page_indices=page_indices
+        )
+
+    def get_index_k_scale_continuous(
+        self,
+        layer_id: int,
+        seq_len: int,
+        page_indices: torch.Tensor,
+    ):
+        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+        return index_buf_accessor.GetS.execute(
+            self, buf, seq_len=seq_len, page_indices=page_indices
+        )
+
+    # TODO rename later (currently use diff name to avoid confusion)
+    def set_index_k_and_scale_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        index_k: torch.Tensor,
+        index_k_scale: torch.Tensor,
+    ) -> None:
+        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+        index_buf_accessor.SetKAndS.execute(
+            pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
+        )
+
+    def get_kv_size_bytes(self):
+        kv_size_bytes = super().get_kv_size_bytes()
+        for index_k_cache in self.index_k_with_scale_buffer:
+            kv_size_bytes += get_tensor_size_bytes(index_k_cache)
+        return kv_size_bytes
+
+
 class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
     def __init__(
         self,
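
The layout comment in NSATokenToKVPool above implies a fixed per-page byte budget for the fp8 index keys plus their fp32 scales. A small worked computation of that budget follows; the numbers come directly from the asserted page_size=64, index_head_dim=128, and quant_block_size=128, while the helper name itself is ours for illustration.

def index_page_bytes(page_size: int = 64, head_dim: int = 128, quant_block: int = 128) -> int:
    # fp8 payload: page_size * head_dim bytes; scales: one fp32 per quant block per token.
    payload = page_size * head_dim                      # 64 * 128 = 8192 bytes
    scales = page_size * (head_dim // quant_block) * 4  # 64 * 1 * 4 = 256 bytes
    return payload + scales

assert index_page_bytes() == 64 * (128 + 128 // 128 * 4) == 8448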
@@ -889,6 +1412,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         dtype: torch.dtype,
         kv_lora_rank: int,
         qk_rope_head_dim: int,
+        index_head_dim: Optional[int],
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
@@ -908,6 +1432,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
 
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
+        self.index_head_dim = index_head_dim
 
         self.custom_mem_pool = None
 
@@ -935,23 +1460,33 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             dtype=self.store_dtype,
             device=self.device,
         )
+        if self.index_head_dim is not None:
+            self.index_k_buffer = torch.zeros(
+                (
+                    layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    1,
+                    self.index_head_dim,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )
 
-        self.
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "k_buffer")
         assert hasattr(self, "v_buffer")
         kv_size_bytes = 0
         for k_cache in self.k_buffer:
-            kv_size_bytes +=
+            kv_size_bytes += get_tensor_size_bytes(k_cache)
         for v_cache in self.v_buffer:
-            kv_size_bytes +=
+            kv_size_bytes += get_tensor_size_bytes(v_cache)
+        if self.index_head_dim is not None:
+            assert hasattr(self, "index_k_buffer")
+            for index_k_cache in self.index_k_buffer:
+                kv_size_bytes += get_tensor_size_bytes(index_k_cache)
         return kv_size_bytes
 
     def get_kv_buffer(self, layer_id: int):
@@ -978,6 +1513,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.v_buffer[layer_id - self.start_layer]
 
+    def get_index_k_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.index_k_buffer[layer_id - self.start_layer]
+
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
@@ -990,6 +1533,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
             self.v_buffer[i][0].nbytes for i in range(self.layer_num)
         ]
+        if self.index_head_dim is not None:
+            kv_data_ptrs += [
+                self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
+            ]
+            kv_data_lens += [
+                self.index_k_buffer[i].nbytes for i in range(self.layer_num)
+            ]
+            kv_item_lens += [
+                self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
+            ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     def set_kv_buffer(
@@ -1026,6 +1579,26 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             cache_v.view(-1, 1, self.qk_rope_head_dim),
         )
 
+    def set_index_k_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        index_k: torch.Tensor,
+    ):
+        if index_k.dtype != self.dtype:
+            index_k = index_k.to(self.dtype)
+
+        if self.store_dtype != self.dtype:
+            index_k = index_k.view(self.store_dtype)
+
+        torch_npu.npu_scatter_nd_update_(
+            self.index_k_buffer[layer_id - self.start_layer].view(
+                -1, 1, self.index_head_dim
+            ),
+            loc.view(-1, 1),
+            index_k.view(-1, 1, self.index_head_dim),
+        )
+
 
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
@@ -1107,38 +1680,36 @@ class DoubleSparseTokenToKVPool(KVCache):
 
 
 @triton.jit
-def
+def copy_all_layer_kv_cache_tiled(
     data_ptrs,
     strides,
     tgt_loc_ptr,
     src_loc_ptr,
     num_locs,
     num_locs_upper: tl.constexpr,
+    BYTES_PER_TILE: tl.constexpr,
 ):
-
-
+    """2D tiled kernel. Safe for in-place copy."""
     bid = tl.program_id(0)
+    tid = tl.program_id(1)
+
     stride = tl.load(strides + bid)
+    base_ptr = tl.load(data_ptrs + bid)
+    base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))
 
-
-
+    byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
+    mask_byte = byte_off < stride
+    tl.multiple_of(byte_off, 16)
 
-
-
-    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    loc_idx = tl.arange(0, num_locs_upper)
+    mask_loc = loc_idx < num_locs
 
-
-
+    src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
+    tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)
 
-
-
-
-
-
-
-    )
-    tl.store(
-        data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
-        value,
-        mask=mask,
-    )
+    src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
+    tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]
+
+    mask = mask_loc[:, None] & mask_byte[None, :]
+    vals = tl.load(src_ptr, mask=mask)
+    tl.store(tgt_ptr, vals, mask=mask)