sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +378 -160
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +10 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +136 -25
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +63 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +83 -80
- sglang/srt/entrypoints/grpc_server.py +430 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +195 -102
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +58 -6
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +33 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +20 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +10 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +24 -10
- sglang/srt/layers/attention/flashinfer_backend.py +258 -22
- sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
- sglang/srt/layers/attention/utils.py +89 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +12 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +64 -19
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +152 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
- sglang/srt/layers/moe/ep_moe/layer.py +154 -625
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
- sglang/srt/layers/moe/moe_runner/runner.py +6 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
- sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +7 -6
- sglang/srt/layers/moe/utils.py +20 -5
- sglang/srt/layers/quantization/__init__.py +5 -58
- sglang/srt/layers/quantization/awq.py +183 -9
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +27 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +152 -81
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +35 -68
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +23 -48
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +87 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +62 -9
- sglang/srt/layers/rotary_embedding.py +686 -17
- sglang/srt/layers/sampler.py +47 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +69 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +420 -514
- sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +375 -95
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +11 -2
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +517 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +71 -25
- sglang/srt/model_executor/model_runner.py +362 -270
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +418 -140
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +327 -382
- sglang/srt/models/glm4_moe_nextn.py +6 -16
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +32 -199
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +22 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3.py +34 -4
- sglang/srt/models/qwen3_moe.py +19 -37
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +7 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +2 -6
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +28 -2
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +846 -163
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +36 -31
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +272 -82
- sglang/srt/utils/hf_transformers_utils.py +44 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +463 -107
- sglang/test/test_deterministic_utils.py +74 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/models/vila.py +0 -306
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
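
The per-file hunks reproduced below correspond to `sglang/srt/layers/attention/nsa_backend.py` (+404/-90 in the listing above; the file grows from 887 to 1201 lines). The headline changes: the `_NSA_IMPL_T` kernel names `flashmla_prefill`/`flashmla_decode` are renamed to `flashmla_sparse`/`flashmla_kv`, the backend gains speculative-decoding support (new target-verify and draft-extend paths in the metadata setup, CUDA-graph capture, and replay routines), and a `NativeSparseAttnMultiStepBackend` wrapper is added for multi-step draft decoding. Deleted lines whose text the diff viewer did not preserve appear as bare `-` markers or truncated stems.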
```diff
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import sys
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
 
@@ -30,22 +29,23 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInput
 
+
 _is_hip = is_hip()
 
 if _is_hip:
     try:
-        from aiter import (
+        from aiter import (  # noqa: F401
             flash_attn_varlen_func,
             mha_batch_prefill_func,
             paged_attention_ragged,
         )
-        from aiter.mla import mla_decode_fwd, mla_prefill_fwd
+        from aiter.mla import mla_decode_fwd, mla_prefill_fwd  # noqa: F401
     except ImportError:
         print(
             "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
         )
 else:
-    from sgl_kernel.flash_attn import
+    from sgl_kernel.flash_attn import flash_attn_with_kvcache
 
 
 @dataclass(frozen=True)
@@ -140,16 +140,21 @@ def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
     )
 
 
-_NSA_IMPL_T: TypeAlias = Literal[
-    "flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
-]
+_NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]
 
 NSA_PREFILL_IMPL: _NSA_IMPL_T
 NSA_DECODE_IMPL: _NSA_IMPL_T
 
 
 class NativeSparseAttnBackend(AttentionBackend):
-    def __init__(
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        speculative_step_id=0,
+        topk=0,
+        speculative_num_steps=0,
+    ):
         super().__init__()
         self.forward_metadata: NSAMetadata
         self.device = model_runner.device
@@ -174,8 +179,8 @@ class NativeSparseAttnBackend(AttentionBackend):
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
 
         global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
-        NSA_PREFILL_IMPL = model_runner.server_args.
-        NSA_DECODE_IMPL = model_runner.server_args.
+        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
+        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
 
         self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
 
@@ -186,6 +191,14 @@ class NativeSparseAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
 
+        # Speculative decoding
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
+        self.speculative_num_steps = speculative_num_steps
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+        self.speculative_step_id = speculative_step_id
+
     def get_device_int32_arange(self, l: int) -> torch.Tensor:
         if l > len(self._arange_buf):
             next_pow_of_2 = 1 << (l - 1).bit_length()
@@ -209,13 +222,15 @@ class NativeSparseAttnBackend(AttentionBackend):
        batch_size = forward_batch.batch_size
        device = forward_batch.seq_lens.device
 
-
-
-
-
+        if forward_batch.forward_mode.is_target_verify():
+            draft_token_num = self.speculative_num_draft_tokens
+        else:
+            draft_token_num = 0
+
+        cache_seqlens_int32 = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
         cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
         assert forward_batch.seq_lens_cpu is not None
-        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
+        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)
         page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, :max_seqlen_k
         ]
@@ -225,6 +240,41 @@ class NativeSparseAttnBackend(AttentionBackend):
             max_seqlen_q = 1
             cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
             seqlens_expanded = cache_seqlens_int32
+        elif forward_batch.forward_mode.is_target_verify():
+            max_seqlen_q = self.speculative_num_draft_tokens
+            nsa_max_seqlen_q = self.speculative_num_draft_tokens
+            cu_seqlens_q = torch.arange(
+                0,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                1,
+                dtype=torch.int32,
+                device=device,
+            )
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * batch_size
+            forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu
+
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in forward_batch.seq_lens_cpu.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            page_table = torch.repeat_interleave(
+                page_table, repeats=self.speculative_num_draft_tokens, dim=0
+            )
         elif forward_batch.forward_mode.is_extend():
             assert (
                 forward_batch.extend_seq_lens_cpu is not None
@@ -233,7 +283,11 @@ class NativeSparseAttnBackend(AttentionBackend):
             ), "All of them must not be None"
             extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
             assert forward_batch.extend_seq_lens is not None
-
+
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 max_seqlen_q = max(extend_seq_lens_cpu)
                 cu_seqlens_q = compute_cu_seqlens(
                     forward_batch.extend_seq_lens.to(torch.int32)
@@ -278,9 +332,9 @@ class NativeSparseAttnBackend(AttentionBackend):
             flashmla_metadata=(
                 self._compute_flashmla_metadata(
                     cache_seqlens=nsa_cache_seqlens_int32,
-                    seq_len_q=1,
+                    seq_len_q=1,
                 )
-                if NSA_DECODE_IMPL == "
+                if NSA_DECODE_IMPL == "flashmla_kv"
                 else None
             ),
             nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
@@ -289,6 +343,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             nsa_seqlens_expanded=seqlens_expanded,
             nsa_extend_seq_lens_list=extend_seq_lens_cpu,
             real_page_table=self._transform_table_1_to_real(page_table),
+            nsa_max_seqlen_q=1,
         )
 
         self.forward_metadata = metadata
@@ -303,7 +358,9 @@ class NativeSparseAttnBackend(AttentionBackend):
         to avoid memory allocations.
         """
         self.decode_cuda_graph_metadata: Dict = {
-            "cache_seqlens": torch.
+            "cache_seqlens": torch.ones(
+                max_num_tokens, dtype=torch.int32, device=self.device
+            ),
             "cu_seqlens_q": torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=self.device
             ),
@@ -312,7 +369,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             ),
             # fake page_table for sparse_prefill
             "page_table": torch.zeros(
-
+                max_num_tokens,
                 self.max_context_len,
                 dtype=torch.int32,
                 device=self.device,
@@ -320,11 +377,11 @@ class NativeSparseAttnBackend(AttentionBackend):
             "flashmla_metadata": (
                 self._compute_flashmla_metadata(
                     cache_seqlens=torch.ones(
-
+                        max_num_tokens, dtype=torch.int32, device=self.device
                     ),
-                    seq_len_q=1,
+                    seq_len_q=1,
                 )
-                if NSA_DECODE_IMPL == "
+                if NSA_DECODE_IMPL == "flashmla_kv"
                 else None
             ),
         }
@@ -340,50 +397,166 @@ class NativeSparseAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInput],
     ):
         """Initialize forward metadata for capturing CUDA graph."""
-
-
-
-
+        if forward_mode.is_decode_or_idle():
+            # Normal Decode
+            # Get sequence information
+            cache_seqlens_int32 = seq_lens.to(torch.int32)
+            cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+
+            # Use max context length for seq_len_k
+            page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+            max_seqlen_q = 1
+            max_seqlen_k = page_table_1.shape[1]
 
-
-
-        cache_seqlens_int32 = seq_lens.to(torch.int32)
-        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+            # Precompute page table
+            # Precompute cumulative sequence lengths
 
-
-
-
+            # NOTE(dark): this is always arange, since we are decoding
+            cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
+            nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+                cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
+            )
 
-
-
+            seqlens_expanded = cache_seqlens_int32
+            nsa_extend_seq_lens_list = [1] * num_tokens
+            if NSA_DECODE_IMPL == "flashmla_kv":
+                flashmla_metadata = self.decode_cuda_graph_metadata[
+                    "flashmla_metadata"
+                ].slice(slice(0, num_tokens + 1))
+                flashmla_metadata.copy_(
+                    self._compute_flashmla_metadata(
+                        cache_seqlens=nsa_cache_seqlens_int32,
+                        seq_len_q=1,
+                    )
+                )
+            else:
+                flashmla_metadata = None
+        elif forward_mode.is_target_verify():
+            cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
+                torch.int32
+            )
+            cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+            max_seqlen_q = 1
+            page_table_1 = self.decode_cuda_graph_metadata["page_table"][
+                : bs * self.speculative_num_draft_tokens, :
+            ]
+            max_seqlen_k = page_table_1.shape[1]
+
+            cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                1,
+                dtype=torch.int32,
+                device=self.device,
+            )
 
-
-        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
-        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
-            cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
-        )
-        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
-        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
-        real_page_table = self._transform_table_1_to_real(page_table_1)
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
 
-
-
-
-        ]
-
-
-
-
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in seq_lens.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+                seqlens_expanded, nsa_index_topk=self.nsa_index_topk
+            )
+            nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens
+
+            if NSA_DECODE_IMPL == "flashmla_kv":
+                flashmla_metadata = self.decode_cuda_graph_metadata[
+                    "flashmla_metadata"
+                ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
+
+                flashmla_metadata.copy_(
+                    self._compute_flashmla_metadata(
+                        cache_seqlens=nsa_cache_seqlens_int32,
+                        seq_len_q=1,
+                    )
                 )
+            else:
+                flashmla_metadata = None
+        elif forward_mode.is_draft_extend():
+            cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to(
+                torch.int32
             )
-
-
+            cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+            page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+            max_seqlen_k = page_table_1.shape[1]
+
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
+            extend_seq_lens = torch.full(
+                (bs,),
+                self.speculative_num_draft_tokens,
+                device=self.device,
+                dtype=torch.int32,
+            )
+
+            max_seqlen_q = max(extend_seq_lens_cpu)
+            cu_seqlens_q = compute_cu_seqlens(extend_seq_lens.to(torch.int32))
+
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in seq_lens.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+                seqlens_expanded, nsa_index_topk=self.nsa_index_topk
+            )
+            nsa_extend_seq_lens_list = [1] * bs
+
+            if NSA_DECODE_IMPL == "flashmla_kv":
+                flashmla_metadata = self.decode_cuda_graph_metadata[
+                    "flashmla_metadata"
+                ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
+                # As the DeepGemm is not support for q_len = 3/4 in Indexer and every token has independent topk_indices,
+                # we made the Q shape [bs * speculative_num_draft_tokens, 1, head_nums, dim].
+                # So seq_len_q is 1 for flashmla_metadata in target_verify and draft_extend mode.
+                flashmla_metadata.copy_(
+                    self._compute_flashmla_metadata(
+                        cache_seqlens=nsa_cache_seqlens_int32,
+                        seq_len_q=1,
+                    )
+                )
+            else:
+                flashmla_metadata = None
+
+        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
+        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
+        real_page_table = self._transform_table_1_to_real(page_table_1)
 
         metadata = NSAMetadata(
             page_size=self.real_page_size,
             cache_seqlens_int32=cache_seqlens_int32,
-            max_seq_len_q=
-            max_seq_len_k=
+            max_seq_len_q=max_seqlen_q,
+            max_seq_len_k=max_seqlen_k,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k,
             page_table_1=page_table_1,
@@ -391,9 +564,9 @@ class NativeSparseAttnBackend(AttentionBackend):
             nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
             nsa_cu_seqlens_q=nsa_cu_seqlens_q,
             nsa_cu_seqlens_k=nsa_cu_seqlens_k,
-            nsa_seqlens_expanded=
+            nsa_seqlens_expanded=seqlens_expanded,
             real_page_table=real_page_table,
-            nsa_extend_seq_lens_list=
+            nsa_extend_seq_lens_list=nsa_extend_seq_lens_list,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -412,33 +585,119 @@ class NativeSparseAttnBackend(AttentionBackend):
     ):
         """Initialize forward metadata for replaying CUDA graph."""
         assert seq_lens_cpu is not None
-
-        assert (
-            spec_info is None
-        ), "Speculative decoding is not supported for NSA backend now"
+
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
 
         # Normal Decode
         metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
-
+        if forward_mode.is_decode_or_idle():
+            # Normal Decode
+            max_len = int(seq_lens_cpu.max().item())
+
+            cache_seqlens = seq_lens.to(torch.int32)
+            metadata.cache_seqlens_int32.copy_(cache_seqlens)
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[req_pool_indices, :max_len]
+            metadata.page_table_1[:, :max_len].copy_(page_indices)
+            nsa_cache_seqlens = compute_nsa_seqlens(
+                cache_seqlens, nsa_index_topk=self.nsa_index_topk
+            )
+            metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
+            seqlens_expanded = cache_seqlens
+        elif forward_mode.is_target_verify():
+            max_seqlen_k = int(
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
 
-
-
-
-
-
-
-
+            cache_seqlens = (seq_lens + self.speculative_num_draft_tokens).to(
+                torch.int32
+            )
+            metadata.cache_seqlens_int32.copy_(cache_seqlens)
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
+            page_indices = torch.repeat_interleave(
+                page_indices, repeats=self.speculative_num_draft_tokens, dim=0
+            )
+            metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
+            extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs
+
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in seq_lens_cpu.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            metadata.nsa_seqlens_expanded.copy_(seqlens_expanded)
+            nsa_cache_seqlens = compute_nsa_seqlens(
+                seqlens_expanded, self.nsa_index_topk
+            )
+            metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
+        elif forward_mode.is_draft_extend():
+            max_seqlen_k = int(seq_lens_cpu.max().item())
+            cache_seqlens = seq_lens.to(torch.int32)
+            metadata.cache_seqlens_int32.copy_(cache_seqlens)
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
+            metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices)
+            extend_seq_lens_cpu = spec_info.accept_length[:bs].tolist()
+
+            seqlens_int32_cpu = [
+                self.speculative_num_draft_tokens + kv_len
+                for kv_len in seq_lens_cpu.tolist()
+            ]
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    )
+                    for qo_len, kv_len in zip(
+                        extend_seq_lens_cpu,
+                        seqlens_int32_cpu,
+                        strict=True,
+                    )
+                ]
+            )
+            metadata.nsa_seqlens_expanded[: seqlens_expanded.size(0)].copy_(
+                seqlens_expanded
+            )
+            nsa_cache_seqlens = compute_nsa_seqlens(
+                seqlens_expanded, self.nsa_index_topk
+            )
+            metadata.nsa_cache_seqlens_int32[: seqlens_expanded.size(0)].copy_(
+                nsa_cache_seqlens
+            )
+        seqlens_expanded_size = seqlens_expanded.size(0)
         assert (
             metadata.nsa_cache_seqlens_int32 is not None
             and metadata.nsa_cu_seqlens_k is not None
             and self.nsa_index_topk is not None
         )
-
-        metadata.
-        metadata.nsa_cu_seqlens_k[1:].copy_(
+
+        metadata.nsa_cu_seqlens_k[1 : 1 + seqlens_expanded_size].copy_(
            torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
        )
        # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
@@ -451,11 +710,14 @@ class NativeSparseAttnBackend(AttentionBackend):
         else:
             assert metadata.real_page_table is metadata.page_table_1
 
-        if NSA_DECODE_IMPL == "
-            metadata.flashmla_metadata.
+        if NSA_DECODE_IMPL == "flashmla_kv":
+            flashmla_metadata = metadata.flashmla_metadata.slice(
+                slice(0, seqlens_expanded_size + 1)
+            )
+            flashmla_metadata.copy_(
                 self._compute_flashmla_metadata(
                     cache_seqlens=nsa_cache_seqlens,
-                    seq_len_q=1,
+                    seq_len_q=1,
                 )
             )
@@ -474,10 +736,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         topk_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-            not forward_batch.forward_mode.is_target_verify()
-            and not forward_batch.forward_mode.is_draft_extend()
-        ), "NSA backend doesn't support speculative decoding"
+
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -542,20 +801,20 @@ class NativeSparseAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_PREFILL_IMPL == "
+        elif NSA_PREFILL_IMPL == "flashmla_sparse":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self.
+            return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 page_table_1=page_table_1,
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_PREFILL_IMPL == "
+        elif NSA_PREFILL_IMPL == "flashmla_kv":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self.
+            return self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 sm_scale=layer.scaling,
@@ -636,20 +895,20 @@ class NativeSparseAttnBackend(AttentionBackend):
             page_size=1,
         )
 
-        if NSA_DECODE_IMPL == "
+        if NSA_DECODE_IMPL == "flashmla_sparse":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self.
+            return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 page_table_1=page_table_1,
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
            )
-        elif NSA_DECODE_IMPL == "
+        elif NSA_DECODE_IMPL == "flashmla_kv":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self.
+            return self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 sm_scale=layer.scaling,
@@ -737,7 +996,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         return o  # type: ignore
 
-    def
+    def _forward_flashmla_sparse(
         self,
         q_all: torch.Tensor,
         kv_cache: torch.Tensor,
@@ -756,7 +1015,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         return o
 
-    def
+    def _forward_flashmla_kv(
         self,
         q_all: torch.Tensor,
         kv_cache: torch.Tensor,
@@ -885,3 +1144,58 @@ class NativeSparseAttnBackend(AttentionBackend):
             flashmla_metadata=flashmla_metadata,
             num_splits=num_splits,
         )
+
+
+class NativeSparseAttnMultiStepBackend:
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        self.model_runner = model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                NativeSparseAttnBackend(
+                    model_runner,
+                    speculative_step_id=i,
+                    topk=self.topk,
+                    speculative_num_steps=self.speculative_num_steps,
+                )
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
```