sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
Full diff of sglang/srt/mem_cache/memory_pool.py (+435 -219):

@@ -15,9 +15,12 @@ limitations under the License.
 
 from __future__ import annotations
 
+from dataclasses import dataclass
+
+from sglang.srt.configs.mamba_utils import Mamba2CacheParams
 from sglang.srt.layers.attention.nsa import index_buf_accessor
 from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 """
 Memory pool.
@@ -44,6 +47,8 @@ from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.managers.cache_controller import LayerDoneCounter
+    from sglang.srt.managers.schedule_batch import Req
+
 
 logger = logging.getLogger(__name__)
 
@@ -109,92 +114,135 @@ class ReqToTokenPool:
 
 
 class MambaPool:
+    @dataclass(frozen=True, kw_only=True)
+    class State:
+        conv: torch.Tensor
+        temporal: torch.Tensor
+
+        def at_layer_idx(self, layer: int):
+            return type(self)(**{k: v[layer] for k, v in vars(self).items()})
+
+        def mem_usage_bytes(self):
+            return sum(get_tensor_size_bytes(t) for t in vars(self).values())
+
+    @dataclass(frozen=True, kw_only=True)
+    class SpeculativeState(State):
+        intermediate_ssm: torch.Tensor
+        intermediate_conv_window: torch.Tensor
+
     def __init__(
         self,
+        *,
         size: int,
-        conv_dtype: torch.dtype,
-        ssm_dtype: torch.dtype,
-        num_mamba_layers: int,
-        conv_state_shape: Tuple[int, int],
-        temporal_state_shape: Tuple[int, int],
+        cache_params: "Mamba2CacheParams",
         device: str,
         speculative_num_draft_tokens: Optional[int] = None,
     ):
-        conv_state = torch.zeros(
-            size=(num_mamba_layers, size + 1) + conv_state_shape,
-            dtype=conv_dtype,
-            device=device,
-        )
-        temporal_state = torch.zeros(
-            size=(num_mamba_layers, size + 1) + temporal_state_shape,
-            dtype=ssm_dtype,
-            device=device,
+        conv_state_shape = cache_params.shape.conv
+        temporal_state_shape = cache_params.shape.temporal
+        conv_dtype = cache_params.dtype.conv
+        ssm_dtype = cache_params.dtype.temporal
+        num_mamba_layers = len(cache_params.layers)
+
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
         )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.enable_custom_mem_pool
+            else nullcontext()
+        ):
+            # assume conv_state = (dim, state_len)
+            assert conv_state_shape[0] > conv_state_shape[1]
+            conv_state = torch.zeros(
+                size=(num_mamba_layers, size + 1) + conv_state_shape,
+                dtype=conv_dtype,
+                device=device,
+            )
+            temporal_state = torch.zeros(
+                size=(num_mamba_layers, size + 1) + temporal_state_shape,
+                dtype=ssm_dtype,
+                device=device,
+            )
         if speculative_num_draft_tokens is not None:
             # Cache intermediate SSM states per draft token during target verify
             # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
             intermediate_ssm_state_cache = torch.zeros(
                 size=(
                     num_mamba_layers,
                     size + 1,
                     speculative_num_draft_tokens,
                     temporal_state_shape[0],
                     temporal_state_shape[1],
                     temporal_state_shape[2],
                 ),
                 dtype=ssm_dtype,
                 device="cuda",
             )
             # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
             # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
             intermediate_conv_window_cache = torch.zeros(
                 size=(
                     num_mamba_layers,
                     size + 1,
                     speculative_num_draft_tokens,
                     conv_state_shape[0],
                     conv_state_shape[1],
                 ),
                 dtype=conv_dtype,
                 device="cuda",
             )
-            self.mamba_cache = (
-                conv_state,
-                temporal_state,
-                intermediate_ssm_state_cache,
-                intermediate_conv_window_cache,
+            self.mamba_cache = self.SpeculativeState(
+                conv=conv_state,
+                temporal=temporal_state,
+                intermediate_ssm=intermediate_ssm_state_cache,
+                intermediate_conv_window=intermediate_conv_window_cache,
             )
             logger.info(
                 f"Mamba Cache is allocated. "
+                f"max_mamba_cache_size: {size}, "
                 f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
                 f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
                 f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
                 f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
             )
         else:
-            self.mamba_cache = (conv_state, temporal_state)
+            self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
             logger.info(
                 f"Mamba Cache is allocated. "
+                f"max_mamba_cache_size: {size}, "
                 f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
                 f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
             )
         self.size = size
-        self.free_slots = list(range(size))
-        self.mem_usage = self.get_mamba_size() / GB
+        self.device = device
+        self.free_slots = torch.arange(
+            self.size, dtype=torch.int64, device=self.device
+        )
+        self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
+        self.num_mamba_layers = num_mamba_layers
 
-    def get_mamba_params_all_layers(self):
-        return self.mamba_cache
+    def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
+        assert isinstance(self.mamba_cache, self.SpeculativeState)
+        return self.mamba_cache
 
-    def get_mamba_params(self, layer_id: int):
-        return [t[layer_id] for t in self.mamba_cache]
-
-    def get_mamba_size(self):
-        return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
+    def mamba2_layer_cache(self, layer_id: int):
+        return self.mamba_cache.at_layer_idx(layer_id)
 
     def available_size(self):
         return len(self.free_slots)
 
-    def alloc(self, need_size: int) -> Optional[List[int]]:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         if need_size > len(self.free_slots):
             return None
 
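The `State`/`SpeculativeState` dataclasses above replace the positional tuple that `mamba_cache` used to be. Because every field is a tensor stacked with the layer as the leading dimension, per-layer access and size accounting are written once over `vars()`. A self-contained sketch of the pattern (the shapes below are invented for illustration, not sglang defaults):

```python
# Minimal sketch of MambaPool.State.at_layer_idx; shapes are illustrative.
# Indexing the leading (layer) dimension of every field yields a per-layer
# State whose tensors are views, not copies.
from dataclasses import dataclass

import torch


@dataclass(frozen=True, kw_only=True)
class State:
    conv: torch.Tensor      # [num_layers, num_slots, dim, state_len]
    temporal: torch.Tensor  # [num_layers, num_slots, heads, d_state, d_head]

    def at_layer_idx(self, layer: int):
        return type(self)(**{k: v[layer] for k, v in vars(self).items()})

    def mem_usage_bytes(self):
        return sum(t.numel() * t.element_size() for t in vars(self).values())


state = State(
    conv=torch.zeros(4, 8, 16, 3),         # 4 layers, 8 request slots
    temporal=torch.zeros(4, 8, 2, 4, 4),
)
layer0 = state.at_layer_idx(0)
assert layer0.conv.shape == (8, 16, 3)     # layer dim stripped
assert layer0.conv.data_ptr() == state.conv[0].data_ptr()  # view, no copy
```

Subclassing (`SpeculativeState`) then extends both methods to the extra speculative buffers for free, which is why `get_contiguous_buf_infos` in the next hunk can iterate `vars(self.mamba_cache)` without knowing which variant it holds.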
@@ -203,15 +251,46 @@ class MambaPool:
 
         return select_index
 
-    def free(self, free_index: Union[int, List[int]]):
-        if isinstance(free_index, (int,)):
-            free_index = [free_index]
-        self.free_slots.extend(free_index)
-        self.mamba_cache[0][:, free_index] = 0
-        self.mamba_cache[1][:, free_index] = 0
+    def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+        self.free_slots = torch.cat((self.free_slots, free_index))
+        self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
+            :, free_index
+        ] = 0
 
     def clear(self):
-        self.free_slots = list(range(self.size))
+        self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
+
+    def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
+        self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
+        self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
+            :, src_index
+        ]
+        return
+
+    def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]:
+        dst_index = self.alloc(1)
+        if dst_index == None:
+            return None
+        self.copy_from(src_index, dst_index)
+        return dst_index
+
+    def get_contiguous_buf_infos(self):
+        state_tensors = [
+            getattr(self.mamba_cache, field) for field in vars(self.mamba_cache)
+        ]
+        data_ptrs, data_lens, item_lens = [], [], []
+
+        for _, state_tensor in enumerate(state_tensors):
+            data_ptrs += [
+                state_tensor[i].data_ptr() for i in range(self.num_mamba_layers)
+            ]
+            data_lens += [state_tensor[i].nbytes for i in range(self.num_mamba_layers)]
+            item_lens += [
+                state_tensor[i][0].nbytes for i in range(self.num_mamba_layers)
+            ]
+        return data_ptrs, data_lens, item_lens
 
 
 class HybridReqToTokenPool(ReqToTokenPool):
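`copy_from` and `fork_from` are new; they let the mamba radix cache (added in this release as sglang/srt/mem_cache/mamba_radix_cache.py) branch a cached request's state into a fresh slot instead of recomputing it. A hedged usage sketch, assuming a `MambaPool` constructed as in the hunk above:

```python
# Hypothetical calling pattern for fork_from/copy_from; `pool` is assumed
# to be a MambaPool with free slots. Indices are int64 slot-id tensors.
src = pool.alloc(1)            # tensor like tensor([5]), or None when full
# ... run a request against slot `src`, filling conv/temporal state ...

child = pool.fork_from(src)    # alloc(1) + copy_from(src, child)
if child is None:
    raise RuntimeError("mamba cache full; raise --max-mamba-cache-size")

pool.free(src)                 # returns the slot and zeroes its state;
                               # `child` still holds the copied state
```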
@@ -219,16 +298,14 @@ class HybridReqToTokenPool(ReqToTokenPool):
 
     def __init__(
         self,
+        *,
         size: int,
+        mamba_size: int,
         max_context_len: int,
         device: str,
         enable_memory_saver: bool,
-        conv_dtype: torch.dtype,
-        ssm_dtype: torch.dtype,
-        mamba_layers: List[int],
-        conv_state_shape: Tuple[int, int],
-        temporal_state_shape: Tuple[int, int],
-        speculative_num_draft_tokens: int,
+        cache_params: "Mamba2CacheParams",
+        speculative_num_draft_tokens: int = None,
     ):
         super().__init__(
             size=size,
@@ -236,31 +313,37 @@ class HybridReqToTokenPool(ReqToTokenPool):
             device=device,
             enable_memory_saver=enable_memory_saver,
         )
+        self._init_mamba_pool(
+            size=mamba_size,
+            cache_params=cache_params,
+            device=device,
+            speculative_num_draft_tokens=speculative_num_draft_tokens,
+        )
 
+    def _init_mamba_pool(
+        self,
+        size: int,
+        cache_params: "Mamba2CacheParams",
+        device: str,
+        speculative_num_draft_tokens: int = None,
+    ):
         self.mamba_pool = MambaPool(
-            size,
-            conv_dtype,
-            ssm_dtype,
-            len(mamba_layers),
-            conv_state_shape,
-            temporal_state_shape,
-            device,
-            speculative_num_draft_tokens,
+            size=size,
+            cache_params=cache_params,
+            device=device,
+            speculative_num_draft_tokens=speculative_num_draft_tokens,
         )
-        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
 
         self.device = device
         self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
             size, dtype=torch.int32, device=self.device
         )
 
-        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
-        self.mamba_index_to_rid_mapping: Dict[int, str] = {}
-
     # For chunk prefill req, we do not need to allocate mamba cache,
     # We could use allocated mamba cache instead.
     def alloc(
-        self, need_size: int, reqs: Optional[List["Req"]] = None
+        self, need_size: int, reqs: Optional[List[Req]] = None
     ) -> Optional[List[int]]:
         select_index = super().alloc(need_size)
         if select_index == None:
@@ -268,14 +351,14 @@ class HybridReqToTokenPool(ReqToTokenPool):
 
         mamba_index = []
         for req in reqs:
-            rid = req.rid
-            if rid in self.rid_to_mamba_index_mapping:
-                mid = self.rid_to_mamba_index_mapping[rid]
-            else:
-                mid = self.mamba_pool.alloc(1)[0]
-                self.rid_to_mamba_index_mapping[rid] = mid
-                self.mamba_index_to_rid_mapping[mid] = rid
-            mamba_index.append(mid)
+            mid = None
+            if req.mamba_pool_idx is not None:  # for radix cache
+                mid = req.mamba_pool_idx
+            else:
+                mid = self.mamba_pool.alloc(1)[0]
+                req.mamba_pool_idx = mid
+            if mid is not None:
+                mamba_index.append(mid)
         assert len(select_index) == len(
             mamba_index
         ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size."
@@ -287,26 +370,21 @@ class HybridReqToTokenPool(ReqToTokenPool):
     def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
         return self.req_index_to_mamba_index_mapping[req_indices]
 
-    def get_mamba_params(self, layer_id: int):
+    def mamba2_layer_cache(self, layer_id: int):
         assert layer_id in self.mamba_map
-        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
+        return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
 
-    def get_mamba_params_all_layers(self):
-        return self.mamba_pool.get_mamba_params_all_layers()
+    def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
+        return self.mamba_pool.get_speculative_mamba2_params_all_layers()
 
     # For chunk prefill, we can not free mamba cache, we need use it in the future
     def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
+        if isinstance(free_index, (int,)):
+            free_index = [free_index]
         super().free(free_index)
         if free_mamba_cache:
             mamba_index = self.req_index_to_mamba_index_mapping[free_index]
-            mamba_index_list = mamba_index.tolist()
-            if isinstance(mamba_index_list, int):
-                mamba_index_list = [mamba_index_list]
-            self.mamba_pool.free(mamba_index_list)
-            for mid in mamba_index_list:
-                rid = self.mamba_index_to_rid_mapping[mid]
-                self.mamba_index_to_rid_mapping.pop(mid)
-                self.rid_to_mamba_index_mapping.pop(rid)
+            self.mamba_pool.free(mamba_index)
 
     def clear(self):
         super().clear()
@@ -349,6 +427,19 @@ class KVCache(abc.ABC):
         # default state for optional layer-wise transfer control
         self.layer_transfer_counter = None
 
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
     def _finalize_allocation_log(self, num_tokens: int):
         """Common logging and mem_usage computation for KV cache allocation.
         Supports both tuple (K, V) size returns and single KV size returns.
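This hoists the Mooncake NVLink pool setup out of the individual pools (the same block is deleted from `MHATokenToKVPool` and `MLATokenToKVPool` in later hunks), so `maybe_get_custom_mem_pool` can also live once in the base class. The gating pattern in isolation, as a sketch (the helper name is made up; `mooncake.allocator` is only importable where the Mooncake transfer engine is installed):

```python
# Sketch of the env-gated allocation pattern used throughout this file.
from contextlib import nullcontext

import torch


def alloc_layer_buffers(custom_mem_pool, layer_num, shape, dtype, device):
    # Route allocations through the NVLink-backed pool only when configured;
    # nullcontext() keeps the default CUDA caching allocator otherwise.
    ctx = (
        torch.cuda.use_mem_pool(custom_mem_pool)
        if custom_mem_pool is not None
        else nullcontext()
    )
    with ctx:
        return [
            torch.zeros(shape, dtype=dtype, device=device) for _ in range(layer_num)
        ]
```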
@@ -400,6 +491,9 @@ class KVCache(abc.ABC):
     def load_cpu_copy(self, kv_cache_cpu, indices):
         raise NotImplementedError()
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
 
 class MHATokenToKVPool(KVCache):
 
@@ -415,6 +509,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_kv_cache_copy: bool = False,
     ):
         super().__init__(
             size,
@@ -429,25 +524,61 @@ class MHATokenToKVPool(KVCache):
         self.head_num = head_num
         self.head_dim = head_dim
 
-        # for disagg with nvlink
-        self.enable_custom_mem_pool = get_bool_env_var(
-            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
-        )
-        if self.enable_custom_mem_pool:
-            # TODO(shangming): abstract custom allocator class for more backends
-            from mooncake.allocator import NVLinkAllocator
-
-            allocator = NVLinkAllocator.get_allocator(self.device)
-            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
-        else:
-            self.custom_mem_pool = None
-
         self._create_buffers()
 
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
+
+        if enable_kv_cache_copy:
+            self._init_kv_copy_and_warmup()
+        else:
+            self._kv_copy_config = None
+
         self._finalize_allocation_log(size)
 
+    def _init_kv_copy_and_warmup(self):
+        # Heuristics for KV copy tiling
+        _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
+        _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
+        _KV_COPY_TILE_SIZE_LARGE = 512
+        _KV_COPY_TILE_SIZE_MEDIUM = 256
+        _KV_COPY_TILE_SIZE_SMALL = 128
+        _KV_COPY_NUM_WARPS_LARGE_TILE = 8
+        _KV_COPY_NUM_WARPS_SMALL_TILE = 4
+
+        stride_bytes = int(self.data_strides[0].item())
+        if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE
+        elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM
+        else:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL
+
+        self._kv_copy_config = {
+            "bytes_per_tile": bytes_per_tile,
+            "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
+            "num_warps": (
+                _KV_COPY_NUM_WARPS_SMALL_TILE
+                if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM
+                else _KV_COPY_NUM_WARPS_LARGE_TILE
+            ),
+        }
+
+        dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
+        grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
+            self.data_ptrs,
+            self.data_strides,
+            dummy_loc,
+            dummy_loc,
+            1,
+            1,
+            BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
+            num_warps=self._kv_copy_config["num_warps"],
+            num_stages=2,
+        )
+
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
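The thresholds pick a tile size from the per-token row stride of one K/V buffer, and the warmup launch just forces Triton to JIT-compile before the first real copy sits on the critical path. A worked example of the arithmetic (the model config is illustrative):

```python
# Worked example of the tile heuristic above: 8 KV heads, head_dim 128,
# fp16 -> 2048-byte rows in each layer's K (and V) buffer.
stride_bytes = 8 * 128 * 2                       # 2048

if stride_bytes >= 8192:
    bytes_per_tile = 512
elif stride_bytes >= 4096:
    bytes_per_tile = 256
else:
    bytes_per_tile = 128                         # 2048 < 4096 -> small tiles

byte_tiles = (stride_bytes + bytes_per_tile - 1) // bytes_per_tile   # ceil -> 16
num_warps = 4 if bytes_per_tile <= 256 else 8                        # -> 4
# Launch grid becomes (num_kv_buffers, 16): one program per buffer per tile.
```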
@@ -535,9 +666,6 @@ class MHATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
-    def maybe_get_custom_mem_pool(self):
-        return self.custom_mem_pool
-
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
@@ -642,13 +770,28 @@ class MHATokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
 
     def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
-        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+        N = tgt_loc.numel()
+        if N == 0:
+            return
+
+        assert (
+            self._kv_copy_config is not None
+        ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"
+
+        cfg = self._kv_copy_config
+        N_upper = next_power_of_2(N)
+        grid = (self.data_ptrs.numel(), cfg["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
             self.data_ptrs,
             self.data_strides,
             tgt_loc,
             src_loc,
-            tgt_loc.numel(),
-            next_power_of_2(tgt_loc.numel()),
+            N,
+            N_upper,
+            BYTES_PER_TILE=cfg["bytes_per_tile"],
+            num_warps=cfg["num_warps"],
+            num_stages=2,
         )
 
 
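Compared with the old `copy_all_layer_kv_cache` launch, the byte dimension moves from an in-program loop onto the second grid axis. Semantically the copy is still a gather-then-scatter per tile, which this pure-PyTorch reference (the buffer list and index tensors are assumptions) captures for testing:

```python
# Reference behavior of move_kv_cache: rows at src_loc are copied to rows
# at tgt_loc in every layer's K and V buffer. Advanced indexing gathers
# the source rows into a temporary before scattering, so overlapping
# src/tgt locations match the in-place-safe kernel.
import torch


def move_kv_cache_reference(buffers, tgt_loc, src_loc):
    for buf in buffers:  # each buffer: [num_tokens + page_size, head_num, head_dim]
        buf[tgt_loc] = buf[src_loc]
```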
@@ -665,12 +808,18 @@ class HybridLinearKVPool(KVCache):
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
+        mamba_pool: MambaPool,
     ):
         self.size = size
         self.dtype = dtype
         self.device = device
         self.full_layer_nums = len(full_attention_layer_ids)
         self.page_size = page_size
+        # TODO support pp?
+        self.start_layer = 0
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.mamba_pool = mamba_pool
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
         if _is_npu:
@@ -699,6 +848,15 @@ class HybridLinearKVPool(KVCache):
     def get_contiguous_buf_infos(self):
         return self.full_kv_pool.get_contiguous_buf_infos()
 
+    def get_state_buf_infos(self):
+        mamba_data_ptrs, mamba_data_lens, mamba_item_lens = (
+            self.mamba_pool.get_contiguous_buf_infos()
+        )
+        return mamba_data_ptrs, mamba_data_lens, mamba_item_lens
+
+    def maybe_get_custom_mem_pool(self):
+        return self.full_kv_pool.maybe_get_custom_mem_pool()
+
     def _transfer_full_attention_id(self, layer_id: int):
         if layer_id not in self.full_attention_layer_id_mapping:
             raise ValueError(
@@ -749,28 +907,57 @@ class SWAKVPool(KVCache):
         self,
         size: int,
         size_swa: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
         swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
+        device: str,
         token_to_kv_pool_class: KVCache = MHATokenToKVPool,
         **kwargs,
     ):
         self.size = size
         self.size_swa = size_swa
+        self.dtype = dtype
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.device = device
         self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
+        self.start_layer = 0
+        self.page_size = 1
+
         kwargs["page_size"] = 1
         kwargs["enable_memory_saver"] = False
+        kwargs["head_num"] = head_num
+        kwargs["head_dim"] = head_dim
+        kwargs["device"] = device
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
 
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         self.swa_kv_pool = token_to_kv_pool_class(
             size=size_swa,
+            dtype=dtype,
             layer_num=self.swa_layer_nums,
             **kwargs,
         )
         self.full_kv_pool = token_to_kv_pool_class(
             size=size,
+            dtype=dtype,
             layer_num=self.full_layer_nums,
             **kwargs,
         )
|
|
|
783
970
|
|
|
784
971
|
k_size, v_size = self.get_kv_size_bytes()
|
|
785
972
|
self.mem_usage = (k_size + v_size) / GB
|
|
973
|
+
logger.info(
|
|
974
|
+
f"SWAKVPool mem usage: {self.mem_usage} GB, swa size: {self.size_swa}, full size: {self.size}"
|
|
975
|
+
)
|
|
786
976
|
|
|
787
977
|
def get_kv_size_bytes(self):
|
|
788
978
|
k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
|
|
@@ -793,15 +983,19 @@ class SWAKVPool(KVCache):
         full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
             self.full_kv_pool.get_contiguous_buf_infos()
         )
+
+        kv_data_ptrs = full_kv_data_ptrs
+        kv_data_lens = full_kv_data_lens
+        kv_item_lens = full_kv_item_lens
+
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
+    def get_state_buf_infos(self):
         swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
             self.swa_kv_pool.get_contiguous_buf_infos()
         )
 
-        kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
-        kv_data_lens = full_kv_data_lens + swa_kv_data_lens
-        kv_item_lens = full_kv_item_lens + swa_kv_item_lens
-
-        return kv_data_ptrs, kv_data_lens, kv_item_lens
+        return swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens
 
     def get_key_buffer(self, layer_id: int):
         layer_id_pool, is_swa = self.layers_mapping[layer_id]
@@ -1057,19 +1251,6 @@ class MLATokenToKVPool(KVCache):
             else (kv_lora_rank + qk_rope_head_dim)
         )
 
-        # for disagg with nvlink
-        self.enable_custom_mem_pool = get_bool_env_var(
-            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
-        )
-        if self.enable_custom_mem_pool:
-            # TODO(shangming): abstract custom allocator class for more backends
-            from mooncake.allocator import NVLinkAllocator
-
-            allocator = NVLinkAllocator.get_allocator(self.device)
-            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
-        else:
-            self.custom_mem_pool = None
-
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
                 torch.cuda.use_mem_pool(self.custom_mem_pool)
@@ -1091,7 +1272,9 @@ class MLATokenToKVPool(KVCache):
                     dtype=torch.uint64,
                     device=self.device,
                 )
-        self._finalize_allocation_log(size)
+        if not use_nsa:
+            # NSA will allocate indexer KV cache later and then log the total size
+            self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
@@ -1110,9 +1293,6 @@ class MLATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
-    def maybe_get_custom_mem_pool(self):
-        return self.custom_mem_pool
-
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
@@ -1212,6 +1392,9 @@ class MLATokenToKVPool(KVCache):
 
 
 class NSATokenToKVPool(MLATokenToKVPool):
+    quant_block_size = 128
+    index_k_with_scale_buffer_dtype = torch.uint8
+
     def __init__(
         self,
         size: int,
@@ -1245,27 +1428,33 @@ class NSATokenToKVPool(MLATokenToKVPool):
         # num head == 1 and head dim == 128 for index_k in NSA
         assert index_head_dim == 128
 
-        self.quant_block_size = 128
-
         assert self.page_size == 64
-        self.index_k_with_scale_buffer = [
-            torch.zeros(
-                # Layout:
-                # ref: test_attention.py :: kv_cache_cast_to_fp8
-                # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
-                # data: for page i,
-                # * buf[i, :page_size * head_dim] for fp8 data
-                # * buf[i, page_size * head_dim:].view(float32) for scale
-                (
-                    (size + page_size + 1) // self.page_size,
-                    self.page_size
-                    * (index_head_dim + index_head_dim // self.quant_block_size * 4),
-                ),
-                dtype=torch.uint8,
-                device=device,
-            )
-            for _ in range(layer_num)
-        ]
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.custom_mem_pool
+            else nullcontext()
+        ):
+            self.index_k_with_scale_buffer = [
+                torch.zeros(
+                    # Layout:
+                    # ref: test_attention.py :: kv_cache_cast_to_fp8
+                    # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
+                    # data: for page i,
+                    # * buf[i, :page_size * head_dim] for fp8 data
+                    # * buf[i, page_size * head_dim:].view(float32) for scale
+                    (
+                        (size + page_size + 1) // self.page_size,
+                        self.page_size
+                        * (
+                            index_head_dim + index_head_dim // self.quant_block_size * 4
+                        ),
+                    ),
+                    dtype=self.index_k_with_scale_buffer_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]
+        self._finalize_allocation_log(size)
 
     def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
         if self.layer_transfer_counter is not None:
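With the constants asserted above (64-token pages, `index_head_dim == 128`, `quant_block_size == 128`), each page row packs 64·128 = 8192 bytes of fp8 keys followed by 64·4 = 256 bytes of fp32 scales, 8448 bytes in all. A sketch of slicing one page according to the layout comment (buffer sizes are illustrative):

```python
# Slicing one page of index_k_with_scale_buffer per the layout comment.
import torch

page_size, head_dim, quant_block_size = 64, 128, 128
scales_per_token = head_dim // quant_block_size                 # 1
row_bytes = page_size * (head_dim + scales_per_token * 4)       # 8448

buf = torch.zeros(10, row_bytes, dtype=torch.uint8)             # 10 pages

page = buf[3]
fp8_data = page[: page_size * head_dim]                         # 8192 fp8 bytes
scales = page[page_size * head_dim :].view(torch.float32)       # 64 fp32 scales
assert scales.numel() == page_size * scales_per_token
```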
@@ -1307,6 +1496,24 @@ class NSATokenToKVPool(MLATokenToKVPool):
             pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
         )
 
+    def get_state_buf_infos(self):
+        data_ptrs = [
+            self.index_k_with_scale_buffer[i].data_ptr() for i in range(self.layer_num)
+        ]
+        data_lens = [
+            self.index_k_with_scale_buffer[i].nbytes for i in range(self.layer_num)
+        ]
+        item_lens = [
+            self.index_k_with_scale_buffer[i][0].nbytes for i in range(self.layer_num)
+        ]
+        return data_ptrs, data_lens, item_lens
+
+    def get_kv_size_bytes(self):
+        kv_size_bytes = super().get_kv_size_bytes()
+        for index_k_cache in self.index_k_with_scale_buffer:
+            kv_size_bytes += get_tensor_size_bytes(index_k_cache)
+        return kv_size_bytes
+
 
 class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
     def __init__(
@@ -1531,27 +1738,38 @@ class DoubleSparseTokenToKVPool(KVCache):
         )
 
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
-            # [size, head_num, head_dim] for each layer
-            self.k_buffer = [
-                torch.zeros(
-                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
-                )
-                for _ in range(layer_num)
-            ]
-            self.v_buffer = [
-                torch.zeros(
-                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
-                )
-                for _ in range(layer_num)
-            ]
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                # [size, head_num, head_dim] for each layer
+                self.k_buffer = [
+                    torch.zeros(
+                        (size + page_size, head_num, head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (size + page_size, head_num, head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
 
-            # [size, head_num, heavy_channel_num] for each layer
-            self.label_buffer = [
-                torch.zeros(
-                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
-                )
-                for _ in range(layer_num)
-            ]
+                # [size, head_num, heavy_channel_num] for each layer
+                self.label_buffer = [
+                    torch.zeros(
+                        (size + 1, head_num, heavy_channel_num),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
 
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id - self.start_layer]
@@ -1584,38 +1802,36 @@
 
 
 @triton.jit
-def copy_all_layer_kv_cache(
+def copy_all_layer_kv_cache_tiled(
     data_ptrs,
     strides,
     tgt_loc_ptr,
     src_loc_ptr,
     num_locs,
     num_locs_upper: tl.constexpr,
+    BYTES_PER_TILE: tl.constexpr,
 ):
-    BLOCK_SIZE: tl.constexpr = 128
-
+    """2D tiled kernel. Safe for in-place copy."""
    bid = tl.program_id(0)
+    tid = tl.program_id(1)
+
     stride = tl.load(strides + bid)
+    base_ptr = tl.load(data_ptrs + bid)
+    base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))
 
-    data_ptr = tl.load(data_ptrs + bid)
-    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
+    byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
+    mask_byte = byte_off < stride
+    tl.multiple_of(byte_off, 16)
 
-    num_locs_offset = tl.arange(0, num_locs_upper)
-    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
-    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    loc_idx = tl.arange(0, num_locs_upper)
+    mask_loc = loc_idx < num_locs
 
-    # NOTE: the byte loop below must stay inside one program; parallelizing it
-    # naively would break in-place copies where tgt_loc and src_loc overlap
+    src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
+    tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)
 
-    num_loop = tl.cdiv(stride, BLOCK_SIZE)
-    for i in range(num_loop):
-        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
-        value = tl.load(
-            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
-        )
-        tl.store(
-            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
-            value,
-            mask=mask,
-        )
+    src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
+    tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]
+
+    mask = mask_loc[:, None] & mask_byte[None, :]
+    vals = tl.load(src_ptr, mask=mask)
+    tl.store(tgt_ptr, vals, mask=mask)
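The 2D tiling is what the warmup and `move_kv_cache` grids above line up with: program `(bid, tid)` copies one byte window of every requested row in buffer `bid`, loading all sources before storing, so overlapping source and target rows stay correct. A hedged sanity-check sketch against a gather-then-scatter reference (assumes a CUDA device with Triton; constants are illustrative):

```python
# Sanity-check sketch for copy_all_layer_kv_cache_tiled on one uint8 buffer
# with overlapping src/tgt rows.
import torch

from sglang.srt.mem_cache.memory_pool import copy_all_layer_kv_cache_tiled

rows, stride = 16, 2048
buf = torch.randint(0, 256, (rows, stride), dtype=torch.uint8, device="cuda")
ref = buf.clone()

src = torch.tensor([0, 1, 2], dtype=torch.int64, device="cuda")
tgt = torch.tensor([1, 2, 3], dtype=torch.int64, device="cuda")  # overlaps src

data_ptrs = torch.tensor([buf.data_ptr()], dtype=torch.uint64, device="cuda")
strides = torch.tensor([stride], dtype=torch.int64, device="cuda")

bytes_per_tile = 128
grid = (1, stride // bytes_per_tile)  # (num_buffers, byte_tiles)
copy_all_layer_kv_cache_tiled[grid](
    data_ptrs, strides, tgt, src, src.numel(), 4,  # 4 = next pow2 >= 3 locs
    BYTES_PER_TILE=bytes_per_tile, num_warps=4, num_stages=2,
)

ref[tgt] = ref[src]  # gather-then-scatter reference
assert torch.equal(buf, ref)
```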