sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from __future__ import annotations
+
+from sglang.srt.layers.attention.nsa import index_buf_accessor
+from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 """
@@ -27,7 +31,7 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -38,6 +42,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
@@ -47,6 +54,10 @@ if _is_npu:
     import torch_npu
 
 
+def get_tensor_size_bytes(t: torch.Tensor):
+    return np.prod(t.shape) * t.dtype.itemsize
+
+
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
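Note on the helper introduced above: `get_tensor_size_bytes` just multiplies the element count by the element size. A minimal check of that behavior (illustrative only, not part of the diff):

```python
# Illustrative only: what get_tensor_size_bytes computes for a small tensor.
import numpy as np
import torch


def get_tensor_size_bytes(t: torch.Tensor):
    return np.prod(t.shape) * t.dtype.itemsize


t = torch.zeros(2, 3, dtype=torch.float16)
assert get_tensor_size_bytes(t) == 2 * 3 * 2  # 6 elements x 2 bytes each
```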
@@ -97,6 +108,211 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
+class MambaPool:
+    def __init__(
+        self,
+        size: int,
+        conv_dtype: torch.dtype,
+        ssm_dtype: torch.dtype,
+        num_mamba_layers: int,
+        conv_state_shape: Tuple[int, int],
+        temporal_state_shape: Tuple[int, int],
+        device: str,
+        speculative_num_draft_tokens: Optional[int] = None,
+    ):
+        conv_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + conv_state_shape,
+            dtype=conv_dtype,
+            device=device,
+        )
+        temporal_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + temporal_state_shape,
+            dtype=ssm_dtype,
+            device=device,
+        )
+        if speculative_num_draft_tokens is not None:
+            # Cache intermediate SSM states per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
+            intermediate_ssm_state_cache = torch.zeros(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    temporal_state_shape[0],
+                    temporal_state_shape[1],
+                    temporal_state_shape[2],
+                ),
+                dtype=ssm_dtype,
+                device="cuda",
+            )
+            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
+            intermediate_conv_window_cache = torch.zeros(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    conv_state_shape[0],
+                    conv_state_shape[1],
+                ),
+                dtype=conv_dtype,
+                device="cuda",
+            )
+            self.mamba_cache = (
+                conv_state,
+                temporal_state,
+                intermediate_ssm_state_cache,
+                intermediate_conv_window_cache,
+            )
+            logger.info(
+                f"Mamba Cache is allocated. "
+                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+                f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
+                f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
+            )
+        else:
+            self.mamba_cache = (conv_state, temporal_state)
+            logger.info(
+                f"Mamba Cache is allocated. "
+                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+            )
+        self.size = size
+        self.free_slots = list(range(size))
+        self.mem_usage = self.get_mamba_size() / GB
+
+    def get_mamba_params_all_layers(self):
+        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
+
+    def get_mamba_params(self, layer_id: int):
+        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
+
+    def get_mamba_size(self):
+        return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    def alloc(self, need_size: int) -> Optional[List[int]]:
+        if need_size > len(self.free_slots):
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        return select_index
+
+    def free(self, free_index: Union[int, List[int]]):
+        if isinstance(free_index, (int,)):
+            self.free_slots.append(free_index)
+        else:
+            self.free_slots.extend(free_index)
+        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
+
+    def clear(self):
+        self.free_slots = list(range(self.size))
+
+
+class HybridReqToTokenPool(ReqToTokenPool):
+    """A memory pool that maps a request to its token locations."""
+
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+        conv_dtype: torch.dtype,
+        ssm_dtype: torch.dtype,
+        mamba_layers: List[int],
+        conv_state_shape: Tuple[int, int],
+        temporal_state_shape: Tuple[int, int],
+        speculative_num_draft_tokens: int,
+    ):
+        super().__init__(
+            size=size,
+            max_context_len=max_context_len,
+            device=device,
+            enable_memory_saver=enable_memory_saver,
+        )
+
+        self.mamba_pool = MambaPool(
+            size,
+            conv_dtype,
+            ssm_dtype,
+            len(mamba_layers),
+            conv_state_shape,
+            temporal_state_shape,
+            device,
+            speculative_num_draft_tokens,
+        )
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
+
+        self.device = device
+        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
+            size, dtype=torch.int32, device=self.device
+        )
+
+        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
+        self.mamba_index_to_rid_mapping: Dict[int, str] = {}
+
+    # For chunk prefill req, we do not need to allocate mamba cache,
+    # We could use allocated mamba cache instead.
+    def alloc(
+        self, need_size: int, reqs: Optional[List["Req"]] = None
+    ) -> Optional[List[int]]:
+        select_index = super().alloc(need_size)
+        if select_index == None:
+            return None
+
+        mamba_index = []
+        for req in reqs:
+            rid = req.rid
+            if rid in self.rid_to_mamba_index_mapping:
+                mid = self.rid_to_mamba_index_mapping[rid]
+            elif (mid := self.mamba_pool.alloc(1)) is not None:
+                mid = mid[0]
+                self.rid_to_mamba_index_mapping[rid] = mid
+                self.mamba_index_to_rid_mapping[mid] = rid
+            mamba_index.append(mid)
+        assert len(select_index) == len(
+            mamba_index
+        ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size."
+        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
+            mamba_index, dtype=torch.int32, device=self.device
+        )
+        return select_index
+
+    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
+        return self.req_index_to_mamba_index_mapping[req_indices]
+
+    def get_mamba_params(self, layer_id: int):
+        assert layer_id in self.mamba_map
+        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
+
+    def get_mamba_params_all_layers(self):
+        return self.mamba_pool.get_mamba_params_all_layers()
+
+    # For chunk prefill, we can not free mamba cache, we need use it in the future
+    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
+        super().free(free_index)
+        if free_mamba_cache:
+            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
+            mamba_index_list = mamba_index.tolist()
+            if isinstance(mamba_index_list, int):
+                mamba_index_list = [mamba_index_list]
+            self.mamba_pool.free(mamba_index_list)
+            for mid in mamba_index_list:
+                rid = self.mamba_index_to_rid_mapping[mid]
+                self.mamba_index_to_rid_mapping.pop(mid)
+                self.rid_to_mamba_index_mapping.pop(rid)
+
+    def clear(self):
+        super().clear()
+        self.mamba_pool.clear()
+
+
 class KVCache(abc.ABC):
     @abc.abstractmethod
     def __init__(
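Note on the MambaPool / HybridReqToTokenPool addition above: MambaPool hands out state-slot indices from a plain free list, and HybridReqToTokenPool keeps a request-id to slot mapping so chunked-prefill requests reuse the slot they already own. A minimal, self-contained sketch of that free-list contract (illustrative only; `FreeListPool` is a made-up name, and the real class additionally zeroes the conv/SSM state tensors of freed slots):

```python
from typing import List, Optional, Union


class FreeListPool:
    """Toy stand-in for the slot-allocation behavior of MambaPool."""

    def __init__(self, size: int):
        self.size = size
        self.free_slots = list(range(size))

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        # Like MambaPool.alloc: return None when the request cannot be satisfied.
        if need_size > len(self.free_slots):
            return None
        selected = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return selected

    def free(self, free_index: Union[int, List[int]]) -> None:
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)


pool = FreeListPool(4)
slots = pool.alloc(2)            # e.g. [0, 1]
assert pool.available_size() == 2
pool.free(slots)                 # both slots become reusable
assert pool.available_size() == 4
```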
@@ -130,6 +346,29 @@ class KVCache(abc.ABC):
         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
 
+        # default state for optional layer-wise transfer control
+        self.layer_transfer_counter = None
+
+    def _finalize_allocation_log(self, num_tokens: int):
+        """Common logging and mem_usage computation for KV cache allocation.
+        Supports both tuple (K, V) size returns and single KV size returns.
+        """
+        kv_size_bytes = self.get_kv_size_bytes()
+        if isinstance(kv_size_bytes, tuple):
+            k_size, v_size = kv_size_bytes
+            k_size_GB = k_size / GB
+            v_size_GB = v_size / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
+            )
+            self.mem_usage = k_size_GB + v_size_GB
+        else:
+            kv_size_GB = kv_size_bytes / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
+            )
+            self.mem_usage = kv_size_GB
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
@@ -152,7 +391,7 @@
     ) -> None:
         raise NotImplementedError()
 
-    def register_layer_transfer_counter(self, layer_transfer_counter):
+    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
         self.layer_transfer_counter = layer_transfer_counter
 
     def get_cpu_copy(self, indices):
@@ -205,15 +444,9 @@ class MHATokenToKVPool(KVCache):
 
         self._create_buffers()
 
-        self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
-
-        k_size, v_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
-        )
-        self.mem_usage = (k_size + v_size) / GB
+        self._finalize_allocation_log(size)
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -269,10 +502,10 @@
         assert hasattr(self, "v_buffer")
         k_size_bytes = 0
         for k_cache in self.k_buffer:
-            k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+            k_size_bytes += get_tensor_size_bytes(k_cache)
         v_size_bytes = 0
         for v_cache in self.v_buffer:
-            v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+            v_size_bytes += get_tensor_size_bytes(v_cache)
         return k_size_bytes, v_size_bytes
 
     # for disagg
@@ -352,7 +585,6 @@
         # same applies to get_value_buffer and get_kv_buffer
         if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
-
         return self._get_key_buffer(layer_id)
 
     def _get_value_buffer(self, layer_id: int):
@@ -420,41 +652,31 @@
         )
 
 
-class SWAKVPool(KVCache):
-    """KV cache with separate pools for full and SWA attention layers."""
+class HybridLinearKVPool(KVCache):
+    """KV cache with separate pools for full and linear attention layers."""
 
     def __init__(
         self,
         size: int,
-        size_swa: int,
         dtype: torch.dtype,
+        page_size: int,
         head_num: int,
         head_dim: int,
-        swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
     ):
         self.size = size
-        self.size_swa = size_swa
         self.dtype = dtype
         self.device = device
-        self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
-        self.page_size = 1
+        self.page_size = page_size
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
-
-
-
-
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
-            layer_num=self.swa_layer_nums,
-            device=device,
-            enable_memory_saver=False,
-        )
+        if _is_npu:
+            TokenToKVPoolClass = AscendTokenToKVPool
+        else:
+            TokenToKVPoolClass = MHATokenToKVPool
         self.full_kv_pool = TokenToKVPoolClass(
             size=size,
             page_size=self.page_size,
@@ -465,6 +687,93 @@ class SWAKVPool(KVCache):
             device=device,
             enable_memory_saver=False,
         )
+        self.full_attention_layer_id_mapping = {
+            id: i for i, id in enumerate(full_attention_layer_ids)
+        }
+        k_size, v_size = self.get_kv_size_bytes()
+        self.mem_usage = (k_size + v_size) / GB
+
+    def get_kv_size_bytes(self):
+        return self.full_kv_pool.get_kv_size_bytes()
+
+    def get_contiguous_buf_infos(self):
+        return self.full_kv_pool.get_contiguous_buf_infos()
+
+    def _transfer_full_attention_id(self, layer_id: int):
+        if layer_id not in self.full_attention_layer_id_mapping:
+            raise ValueError(
+                f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
+            )
+        return self.full_attention_layer_id_mapping[layer_id]
+
+    def get_key_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_key_buffer(layer_id)
+
+    def get_value_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_value_buffer(layer_id)
+
+    def get_kv_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_kv_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+    ):
+        layer_id = self._transfer_full_attention_id(layer.layer_id)
+        self.full_kv_pool.set_kv_buffer(
+            None,
+            loc,
+            cache_k,
+            cache_v,
+            k_scale,
+            v_scale,
+            layer_id_override=layer_id,
+        )
+
+    def get_v_head_dim(self):
+        return self.full_kv_pool.get_value_buffer(0).shape[-1]
+
+
+class SWAKVPool(KVCache):
+    """KV cache with separate pools for full and SWA attention layers."""
+
+    def __init__(
+        self,
+        size: int,
+        size_swa: int,
+        swa_attention_layer_ids: List[int],
+        full_attention_layer_ids: List[int],
+        enable_kvcache_transpose: bool,
+        token_to_kv_pool_class: KVCache = MHATokenToKVPool,
+        **kwargs,
+    ):
+        self.size = size
+        self.size_swa = size_swa
+        self.swa_layer_nums = len(swa_attention_layer_ids)
+        self.full_layer_nums = len(full_attention_layer_ids)
+        kwargs["page_size"] = 1
+        kwargs["enable_memory_saver"] = False
+        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
+        assert not enable_kvcache_transpose
+
+        self.swa_kv_pool = token_to_kv_pool_class(
+            size=size_swa,
+            layer_num=self.swa_layer_nums,
+            **kwargs,
+        )
+        self.full_kv_pool = token_to_kv_pool_class(
+            size=size,
+            layer_num=self.full_layer_nums,
+            **kwargs,
+        )
         self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
         for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
             self.layers_mapping[global_layer_id] = (full_attn_layer_id, False)
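Note on the HybridLinearKVPool / SWAKVPool split above: the hybrid pool only stores KV for the full-attention layers, so global layer ids are remapped to the dense indices of the underlying `full_kv_pool` before every buffer access (`_transfer_full_attention_id`). A minimal sketch of that remapping idea (illustrative only; the layer-id list is hypothetical):

```python
# Hypothetical layer layout: only these global layers use full attention.
full_attention_layer_ids = [3, 7, 11]
full_attention_layer_id_mapping = {
    layer_id: i for i, layer_id in enumerate(full_attention_layer_ids)
}


def to_pool_layer_id(global_layer_id: int) -> int:
    # Mirrors _transfer_full_attention_id: reject layers without a KV entry.
    if global_layer_id not in full_attention_layer_id_mapping:
        raise ValueError(f"layer {global_layer_id} has no full-attention KV cache")
    return full_attention_layer_id_mapping[global_layer_id]


assert to_pool_layer_id(7) == 1  # global layer 7 -> index 1 in full_kv_pool
```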
@@ -613,8 +922,12 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         cache_v: torch.Tensor,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        layer_id_override: Optional[int] = None,
     ):
-        layer_id = layer.layer_id
+        if layer_id_override is not None:
+            layer_id = layer_id_override
+        else:
+            layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
                 cache_k.div_(k_scale)
@@ -719,6 +1032,8 @@ class MLATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        use_nsa: bool = False,
+        override_kv_cache_dim: Optional[int] = None,
     ):
         super().__init__(
             size,
@@ -733,6 +1048,14 @@ class MLATokenToKVPool(KVCache):
 
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
+        self.use_nsa = use_nsa
+        self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
+        # TODO do not hardcode
+        self.kv_cache_dim = (
+            656
+            if self.use_nsa and self.nsa_kv_cache_store_fp8
+            else (kv_lora_rank + qk_rope_head_dim)
+        )
 
         # for disagg with nvlink
         self.enable_custom_mem_pool = get_bool_env_var(
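Note on the hardcoded `kv_cache_dim = 656` above (flagged "TODO do not hardcode" in the diff itself): one plausible reading, assuming a DeepSeek-style NSA FP8 cache layout, is 512 FP8 bytes for the compressed KV (`kv_lora_rank`), one float32 scale per 128-wide quantization block, and the 64 rope dims kept in 16-bit. This breakdown is an inference, not something the diff states:

```python
# Assumed layout; byte budget per cached token.
kv_lora_rank = 512
qk_rope_head_dim = 64
quant_block_size = 128

nope_fp8_bytes = kv_lora_rank                        # 1 byte per FP8 element
scale_bytes = kv_lora_rank // quant_block_size * 4   # one float32 scale per block
rope_bytes = qk_rope_head_dim * 2                    # rope part kept in bf16/fp16

assert nope_fp8_bytes + scale_bytes + rope_bytes == 656
```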
@@ -756,7 +1079,7 @@
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.zeros(
-                (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                (size + page_size, 1, self.kv_cache_dim),
                 dtype=self.store_dtype,
                 device=device,
             )
@@ -768,19 +1091,13 @@
             dtype=torch.uint64,
             device=self.device,
         )
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
         kv_size_bytes = 0
         for kv_cache in self.kv_buffer:
-            kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
+            kv_size_bytes += get_tensor_size_bytes(kv_cache)
         return kv_size_bytes
 
     # for disagg
@@ -825,6 +1142,7 @@
         cache_v: torch.Tensor,
     ):
         layer_id = layer.layer_id
+        assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
@@ -842,16 +1160,28 @@
         cache_k_rope: torch.Tensor,
     ):
         layer_id = layer.layer_id
-        if cache_k_nope.dtype != self.dtype:
-            cache_k_nope = cache_k_nope.to(self.dtype)
-            cache_k_rope = cache_k_rope.to(self.dtype)
-        if self.store_dtype != self.dtype:
-            cache_k_nope = cache_k_nope.view(self.store_dtype)
-            cache_k_rope = cache_k_rope.view(self.store_dtype)
 
-
-
-
+        if self.use_nsa and self.nsa_kv_cache_store_fp8:
+            # original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here
+            # TODO no need to cat
+            cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
+            cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
+            cache_k = cache_k.view(self.store_dtype)
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
+        else:
+            if cache_k_nope.dtype != self.dtype:
+                cache_k_nope = cache_k_nope.to(self.dtype)
+                cache_k_rope = cache_k_rope.to(self.dtype)
+            if self.store_dtype != self.dtype:
+                cache_k_nope = cache_k_nope.view(self.store_dtype)
+                cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+            set_mla_kv_buffer_triton(
+                self.kv_buffer[layer_id - self.start_layer],
+                loc,
+                cache_k_nope,
+                cache_k_rope,
+            )
 
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
@@ -881,6 +1211,103 @@
         torch.cuda.synchronize()
 
 
+class NSATokenToKVPool(MLATokenToKVPool):
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        kv_lora_rank: int,
+        dtype: torch.dtype,
+        qk_rope_head_dim: int,
+        layer_num: int,
+        device: str,
+        index_head_dim: int,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            kv_lora_rank,
+            qk_rope_head_dim,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+            use_nsa=True,
+        )
+        # self.index_k_dtype = torch.float8_e4m3fn
+        # self.index_k_scale_dtype = torch.float32
+        self.index_head_dim = index_head_dim
+        # num head == 1 and head dim == 128 for index_k in NSA
+        assert index_head_dim == 128
+
+        self.quant_block_size = 128
+
+        assert self.page_size == 64
+        self.index_k_with_scale_buffer = [
+            torch.zeros(
+                # Layout:
+                # ref: test_attention.py :: kv_cache_cast_to_fp8
+                # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
+                # data: for page i,
+                # * buf[i, :page_size * head_dim] for fp8 data
+                # * buf[i, page_size * head_dim:].view(float32) for scale
+                (
+                    (size + page_size + 1) // self.page_size,
+                    self.page_size
+                    * (index_head_dim + index_head_dim // self.quant_block_size * 4),
+                ),
+                dtype=torch.uint8,
+                device=device,
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self.index_k_with_scale_buffer[layer_id - self.start_layer]
+
+    def get_index_k_continuous(
+        self,
+        layer_id: int,
+        seq_len: int,
+        page_indices: torch.Tensor,
+    ):
+        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+        return index_buf_accessor.GetK.execute(
+            self, buf, seq_len=seq_len, page_indices=page_indices
+        )
+
+    def get_index_k_scale_continuous(
+        self,
+        layer_id: int,
+        seq_len: int,
+        page_indices: torch.Tensor,
+    ):
+        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+        return index_buf_accessor.GetS.execute(
+            self, buf, seq_len=seq_len, page_indices=page_indices
+        )
+
+    # TODO rename later (currently use diff name to avoid confusion)
+    def set_index_k_and_scale_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        index_k: torch.Tensor,
+        index_k_scale: torch.Tensor,
+    ) -> None:
+        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+        index_buf_accessor.SetKAndS.execute(
+            pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
+        )
+
+
 class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
     def __init__(
         self,
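Note on `index_k_with_scale_buffer` in NSATokenToKVPool above: with the constants asserted in the class (page_size 64, index_head_dim 128, quant_block_size 128), each page packs the FP8 index keys for 64 tokens followed by one float32 scale per token, as the layout comment describes. The per-page byte budget works out as follows (arithmetic check only, not part of the diff):

```python
page_size = 64
index_head_dim = 128
quant_block_size = 128

fp8_bytes_per_token = index_head_dim                            # 1 byte per FP8 element
scale_bytes_per_token = index_head_dim // quant_block_size * 4  # one float32 scale
bytes_per_page = page_size * (fp8_bytes_per_token + scale_bytes_per_token)

assert bytes_per_page == 8448  # 64 * (128 + 4)
```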
@@ -889,6 +1316,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         dtype: torch.dtype,
         kv_lora_rank: int,
         qk_rope_head_dim: int,
+        index_head_dim: Optional[int],
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
@@ -908,6 +1336,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
 
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
+        self.index_head_dim = index_head_dim
 
         self.custom_mem_pool = None
 
@@ -935,23 +1364,33 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             dtype=self.store_dtype,
             device=self.device,
         )
+        if self.index_head_dim is not None:
+            self.index_k_buffer = torch.zeros(
+                (
+                    layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    1,
+                    self.index_head_dim,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )
 
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "k_buffer")
         assert hasattr(self, "v_buffer")
         kv_size_bytes = 0
         for k_cache in self.k_buffer:
-            kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+            kv_size_bytes += get_tensor_size_bytes(k_cache)
         for v_cache in self.v_buffer:
-            kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+            kv_size_bytes += get_tensor_size_bytes(v_cache)
+        if self.index_head_dim is not None:
+            assert hasattr(self, "index_k_buffer")
+            for index_k_cache in self.index_k_buffer:
+                kv_size_bytes += get_tensor_size_bytes(index_k_cache)
         return kv_size_bytes
 
     def get_kv_buffer(self, layer_id: int):
@@ -978,6 +1417,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.v_buffer[layer_id - self.start_layer]
 
+    def get_index_k_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.index_k_buffer[layer_id - self.start_layer]
+
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
@@ -990,6 +1437,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
             self.v_buffer[i][0].nbytes for i in range(self.layer_num)
         ]
+        if self.index_head_dim is not None:
+            kv_data_ptrs += [
+                self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
+            ]
+            kv_data_lens += [
+                self.index_k_buffer[i].nbytes for i in range(self.layer_num)
+            ]
+            kv_item_lens += [
+                self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
+            ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     def set_kv_buffer(
@@ -1026,6 +1483,26 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             cache_v.view(-1, 1, self.qk_rope_head_dim),
         )
 
+    def set_index_k_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        index_k: torch.Tensor,
+    ):
+        if index_k.dtype != self.dtype:
+            index_k = index_k.to(self.dtype)
+
+        if self.store_dtype != self.dtype:
+            index_k = index_k.view(self.store_dtype)
+
+        torch_npu.npu_scatter_nd_update_(
+            self.index_k_buffer[layer_id - self.start_layer].view(
+                -1, 1, self.index_head_dim
+            ),
+            loc.view(-1, 1),
+            index_k.view(-1, 1, self.index_head_dim),
+        )
+
 
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(