sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -10,19 +10,21 @@ from typing import TYPE_CHECKING, Optional, Union
|
|
|
10
10
|
|
|
11
11
|
import torch
|
|
12
12
|
import triton
|
|
13
|
+
import triton.language as tl
|
|
13
14
|
|
|
14
15
|
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
|
15
16
|
FlashInferMLAAttnBackend,
|
|
16
17
|
FlashInferMLAMultiStepDraftBackend,
|
|
17
18
|
)
|
|
18
19
|
from sglang.srt.layers.attention.utils import (
|
|
19
|
-
TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
|
20
20
|
create_flashmla_kv_indices_triton,
|
|
21
|
+
get_num_page_per_block_flashmla,
|
|
21
22
|
)
|
|
22
23
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
|
23
|
-
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
24
24
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
|
25
|
+
from sglang.srt.server_args import get_global_server_args
|
|
25
26
|
from sglang.srt.utils import is_cuda, is_flashinfer_available
|
|
27
|
+
from sglang.srt.utils.common import cached_triton_kernel
|
|
26
28
|
|
|
27
29
|
if is_flashinfer_available():
|
|
28
30
|
import flashinfer
|
|
@@ -48,6 +50,153 @@ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
|
|
|
48
50
|
# compute the LCM with other padding constraints.
|
|
49
51
|
TRTLLM_BLOCK_CONSTRAINT = 128
|
|
50
52
|
|
|
53
|
+
|
|
54
|
+
@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
|
|
55
|
+
@triton.jit
|
|
56
|
+
def pad_draft_extend_query_kernel(
|
|
57
|
+
q_ptr, # Input query tensor [total_seq_len, num_heads, head_dim]
|
|
58
|
+
padded_q_ptr, # Output padded query tensor [batch_size, max_seq_len, num_heads, head_dim]
|
|
59
|
+
seq_lens_q_ptr, # Sequence lengths for each sequence [batch_size]
|
|
60
|
+
cumsum_ptr, # Cumulative sum of accept lengths [batch_size + 1]
|
|
61
|
+
batch_size,
|
|
62
|
+
max_seq_len,
|
|
63
|
+
num_heads,
|
|
64
|
+
head_dim,
|
|
65
|
+
BLOCK_SIZE: tl.constexpr,
|
|
66
|
+
):
|
|
67
|
+
"""Triton kernel for padding draft extended query tensor with parallelized head and dim processing."""
|
|
68
|
+
# Use 3D program IDs: (batch_seq, head_block, dim_block)
|
|
69
|
+
batch_seq_pid = tl.program_id(0)
|
|
70
|
+
head_pid = tl.program_id(1)
|
|
71
|
+
dim_pid = tl.program_id(2)
|
|
72
|
+
|
|
73
|
+
batch_id = batch_seq_pid // max_seq_len
|
|
74
|
+
seq_pos = batch_seq_pid % max_seq_len
|
|
75
|
+
|
|
76
|
+
if batch_id >= batch_size:
|
|
77
|
+
return
|
|
78
|
+
|
|
79
|
+
# Load accept length for this batch
|
|
80
|
+
seq_len = tl.load(seq_lens_q_ptr + batch_id)
|
|
81
|
+
|
|
82
|
+
if seq_pos >= seq_len:
|
|
83
|
+
return
|
|
84
|
+
|
|
85
|
+
# Load cumulative sum to get start position in input tensor
|
|
86
|
+
input_start = tl.load(cumsum_ptr + batch_id)
|
|
87
|
+
input_pos = input_start + seq_pos
|
|
88
|
+
|
|
89
|
+
# Calculate head and dim block ranges
|
|
90
|
+
head_start = head_pid * BLOCK_SIZE
|
|
91
|
+
head_end = tl.minimum(head_start + BLOCK_SIZE, num_heads)
|
|
92
|
+
head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)
|
|
93
|
+
|
|
94
|
+
dim_start = dim_pid * BLOCK_SIZE
|
|
95
|
+
dim_end = tl.minimum(dim_start + BLOCK_SIZE, head_dim)
|
|
96
|
+
dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)
|
|
97
|
+
|
|
98
|
+
# Calculate input offset
|
|
99
|
+
input_offset = (
|
|
100
|
+
input_pos * num_heads * head_dim
|
|
101
|
+
+ (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim
|
|
102
|
+
+ (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
# Load data
|
|
106
|
+
data = tl.load(
|
|
107
|
+
q_ptr + input_offset,
|
|
108
|
+
mask=head_mask[:, None] & dim_mask[None, :],
|
|
109
|
+
other=0.0,
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
# Calculate output offset
|
|
113
|
+
output_offset = (
|
|
114
|
+
batch_id * max_seq_len * num_heads * head_dim
|
|
115
|
+
+ seq_pos * num_heads * head_dim
|
|
116
|
+
+ (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * head_dim
|
|
117
|
+
+ (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
# Store data
|
|
121
|
+
tl.store(
|
|
122
|
+
padded_q_ptr + output_offset,
|
|
123
|
+
data,
|
|
124
|
+
mask=head_mask[:, None] & dim_mask[None, :],
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
|
|
129
|
+
@triton.jit
|
|
130
|
+
def unpad_draft_extend_output_kernel(
|
|
131
|
+
raw_out_ptr, # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
|
|
132
|
+
output_ptr, # Output tensor (-1, tp_q_head_num, v_head_dim)
|
|
133
|
+
accept_length_ptr, # Accept lengths for each sequence [batch_size]
|
|
134
|
+
cumsum_ptr, # Cumulative sum of accept lengths [batch_size + 1]
|
|
135
|
+
batch_size,
|
|
136
|
+
token_per_batch,
|
|
137
|
+
tp_q_head_num,
|
|
138
|
+
v_head_dim,
|
|
139
|
+
BLOCK_SIZE: tl.constexpr,
|
|
140
|
+
):
|
|
141
|
+
"""Triton kernel for unpadding draft extended output tensor with parallelized head and dim processing."""
|
|
142
|
+
batch_seq_pid = tl.program_id(0)
|
|
143
|
+
head_pid = tl.program_id(1)
|
|
144
|
+
dim_pid = tl.program_id(2)
|
|
145
|
+
|
|
146
|
+
batch_id = batch_seq_pid // token_per_batch
|
|
147
|
+
seq_pos = batch_seq_pid % token_per_batch
|
|
148
|
+
|
|
149
|
+
if batch_id >= batch_size:
|
|
150
|
+
return
|
|
151
|
+
|
|
152
|
+
# Load accept length for this batch
|
|
153
|
+
accept_len = tl.load(accept_length_ptr + batch_id)
|
|
154
|
+
|
|
155
|
+
if seq_pos >= accept_len:
|
|
156
|
+
return
|
|
157
|
+
|
|
158
|
+
# Load cumulative sum to get start position in output tensor
|
|
159
|
+
output_start = tl.load(cumsum_ptr + batch_id)
|
|
160
|
+
output_pos = output_start + seq_pos
|
|
161
|
+
|
|
162
|
+
# Calculate head and dim block ranges
|
|
163
|
+
head_start = head_pid * BLOCK_SIZE
|
|
164
|
+
head_end = tl.minimum(head_start + BLOCK_SIZE, tp_q_head_num)
|
|
165
|
+
head_mask = tl.arange(0, BLOCK_SIZE) < (head_end - head_start)
|
|
166
|
+
|
|
167
|
+
dim_start = dim_pid * BLOCK_SIZE
|
|
168
|
+
dim_end = tl.minimum(dim_start + BLOCK_SIZE, v_head_dim)
|
|
169
|
+
dim_mask = tl.arange(0, BLOCK_SIZE) < (dim_end - dim_start)
|
|
170
|
+
|
|
171
|
+
# Calculate input offset: (batch_id, seq_pos, head_id, dim_id)
|
|
172
|
+
input_offset = (
|
|
173
|
+
batch_id * token_per_batch * tp_q_head_num * v_head_dim
|
|
174
|
+
+ seq_pos * tp_q_head_num * v_head_dim
|
|
175
|
+
+ (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim
|
|
176
|
+
+ (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
# Load data
|
|
180
|
+
data = tl.load(
|
|
181
|
+
raw_out_ptr + input_offset,
|
|
182
|
+
mask=head_mask[:, None] & dim_mask[None, :],
|
|
183
|
+
other=0.0,
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
output_offset = (
|
|
187
|
+
output_pos * tp_q_head_num * v_head_dim
|
|
188
|
+
+ (head_start + tl.arange(0, BLOCK_SIZE))[:, None] * v_head_dim
|
|
189
|
+
+ (dim_start + tl.arange(0, BLOCK_SIZE))[None, :]
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
# Store data
|
|
193
|
+
tl.store(
|
|
194
|
+
output_ptr + output_offset,
|
|
195
|
+
data,
|
|
196
|
+
mask=head_mask[:, None] & dim_mask[None, :],
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
|
|
51
200
|
global_zero_init_workspace_buffer = None
|
|
52
201
|
|
|
53
202
|
|
|
@@ -65,7 +214,11 @@ class TRTLLMMLADecodeMetadata:
|
|
|
65
214
|
"""Metadata for TRTLLM MLA decode operations."""
|
|
66
215
|
|
|
67
216
|
block_kv_indices: Optional[torch.Tensor] = None
|
|
68
|
-
|
|
217
|
+
max_seq_len_k: Optional[int] = None
|
|
218
|
+
max_seq_len_q: Optional[int] = None
|
|
219
|
+
sum_seq_lens_q: Optional[int] = None
|
|
220
|
+
cu_seqlens_q: Optional[torch.Tensor] = None
|
|
221
|
+
seq_lens_q: Optional[torch.Tensor] = None
|
|
69
222
|
|
|
70
223
|
|
|
71
224
|
class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
@@ -120,12 +273,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
120
273
|
# CUDA graph state
|
|
121
274
|
self.decode_cuda_graph_metadata = {}
|
|
122
275
|
self.decode_cuda_graph_kv_indices = None
|
|
276
|
+
self.padded_q_buffer = None
|
|
277
|
+
self.unpad_output_buffer = None
|
|
123
278
|
self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
|
|
124
279
|
self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
|
|
125
280
|
|
|
126
|
-
self.disable_chunked_prefix_cache =
|
|
127
|
-
|
|
128
|
-
|
|
281
|
+
self.disable_chunked_prefix_cache = (
|
|
282
|
+
get_global_server_args().disable_chunked_prefix_cache
|
|
283
|
+
)
|
|
129
284
|
|
|
130
285
|
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
|
131
286
|
|
|
@@ -143,9 +298,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
143
298
|
|
|
144
299
|
# Apply dual constraints (take LCM to satisfy both):
|
|
145
300
|
# 1. TRT-LLM: block_num % (128 / page_size) == 0
|
|
146
|
-
# 2. Triton:
|
|
301
|
+
# 2. Triton: number of pages per block
|
|
147
302
|
trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
|
|
148
|
-
|
|
303
|
+
triton_constraint = get_num_page_per_block_flashmla(self.page_size)
|
|
304
|
+
constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
|
|
149
305
|
|
|
150
306
|
if blocks % constraint_lcm != 0:
|
|
151
307
|
blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
|
|
@@ -184,7 +340,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
184
340
|
block_kv_indices,
|
|
185
341
|
self.req_to_token.stride(0),
|
|
186
342
|
max_blocks,
|
|
187
|
-
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
|
188
343
|
PAGED_SIZE=self.page_size,
|
|
189
344
|
)
|
|
190
345
|
|
|
@@ -203,6 +358,21 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
203
358
|
self.decode_cuda_graph_kv_indices = torch.full(
|
|
204
359
|
(max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
|
|
205
360
|
)
|
|
361
|
+
num_tokens_per_bs = max_num_tokens // max_bs
|
|
362
|
+
|
|
363
|
+
# Buffer for padded query: (max_bs, max_draft_tokens, num_q_heads, v_head_dim)
|
|
364
|
+
self.padded_q_buffer = torch.zeros(
|
|
365
|
+
(max_bs, num_tokens_per_bs, self.num_q_heads, self.kv_cache_dim),
|
|
366
|
+
dtype=self.data_type,
|
|
367
|
+
device=self.device,
|
|
368
|
+
)
|
|
369
|
+
|
|
370
|
+
# Buffer for unpadded output: (max_num_tokens, num_q_heads, v_head_dim)
|
|
371
|
+
self.unpad_output_buffer = torch.zeros(
|
|
372
|
+
(max_num_tokens, self.num_q_heads, 512),
|
|
373
|
+
dtype=self.data_type,
|
|
374
|
+
device=self.device,
|
|
375
|
+
)
|
|
206
376
|
|
|
207
377
|
super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
|
|
208
378
|
|
|
@@ -219,7 +389,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
219
389
|
"""Initialize metadata for CUDA graph capture."""
|
|
220
390
|
|
|
221
391
|
# Delegate to parent unless the mode is decode/idle, target-verify, or draft-extend.
|
|
222
|
-
if
|
|
392
|
+
if (
|
|
393
|
+
not forward_mode.is_decode_or_idle()
|
|
394
|
+
and not forward_mode.is_target_verify()
|
|
395
|
+
and not forward_mode.is_draft_extend(include_v2=True)
|
|
396
|
+
):
|
|
223
397
|
return super().init_forward_metadata_capture_cuda_graph(
|
|
224
398
|
bs,
|
|
225
399
|
num_tokens,
|
|
@@ -246,7 +420,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
246
420
|
block_kv_indices,
|
|
247
421
|
self.req_to_token.stride(0),
|
|
248
422
|
max_blocks_per_seq,
|
|
249
|
-
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
|
250
423
|
PAGED_SIZE=self.page_size,
|
|
251
424
|
)
|
|
252
425
|
|
|
@@ -259,6 +432,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
259
432
|
block_kv_indices,
|
|
260
433
|
max_seq_len_val,
|
|
261
434
|
)
|
|
435
|
+
if forward_mode.is_draft_extend(include_v2=True):
|
|
436
|
+
num_tokens_per_bs = num_tokens // bs
|
|
437
|
+
metadata.max_seq_len_q = num_tokens_per_bs + 1
|
|
438
|
+
metadata.sum_seq_lens_q = num_tokens_per_bs * bs
|
|
439
|
+
metadata.cu_seqlens_q = torch.arange(
|
|
440
|
+
0,
|
|
441
|
+
bs * num_tokens_per_bs + 1,
|
|
442
|
+
num_tokens_per_bs,
|
|
443
|
+
dtype=torch.int32,
|
|
444
|
+
device=seq_lens.device,
|
|
445
|
+
)
|
|
446
|
+
metadata.seq_lens_q = torch.full(
|
|
447
|
+
(bs,), num_tokens_per_bs, dtype=torch.int32, device=seq_lens.device
|
|
448
|
+
)
|
|
262
449
|
self.decode_cuda_graph_metadata[bs] = metadata
|
|
263
450
|
self.forward_decode_metadata = metadata
|
|
264
451
|
|
|
@@ -275,7 +462,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
275
462
|
):
|
|
276
463
|
"""Replay CUDA graph with new inputs."""
|
|
277
464
|
# Delegate to parent unless the mode is decode/idle, target-verify, or draft-extend.
|
|
278
|
-
if
|
|
465
|
+
if (
|
|
466
|
+
not forward_mode.is_decode_or_idle()
|
|
467
|
+
and not forward_mode.is_target_verify()
|
|
468
|
+
and not forward_mode.is_draft_extend(include_v2=True)
|
|
469
|
+
):
|
|
279
470
|
return super().init_forward_metadata_replay_cuda_graph(
|
|
280
471
|
bs,
|
|
281
472
|
req_pool_indices,
|
|
@@ -293,6 +484,19 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
293
484
|
|
|
294
485
|
metadata = self.decode_cuda_graph_metadata[bs]
|
|
295
486
|
|
|
487
|
+
if forward_mode.is_draft_extend(include_v2=True):
|
|
488
|
+
accept_length = spec_info.accept_length[:bs]
|
|
489
|
+
if spec_info.accept_length_cpu:
|
|
490
|
+
metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
|
|
491
|
+
metadata.sum_seq_lens_q = sum(spec_info.accept_length_cpu[:bs])
|
|
492
|
+
else:
|
|
493
|
+
metadata.max_seq_len_q = 1
|
|
494
|
+
metadata.sum_seq_lens_q = bs
|
|
495
|
+
metadata.cu_seqlens_q[1:].copy_(
|
|
496
|
+
torch.cumsum(accept_length, dim=0, dtype=torch.int32)
|
|
497
|
+
)
|
|
498
|
+
metadata.seq_lens_q.copy_(accept_length)
|
|
499
|
+
|
|
296
500
|
# Update block indices for new sequences.
|
|
297
501
|
create_flashmla_kv_indices_triton[(bs,)](
|
|
298
502
|
self.req_to_token,
|
|
@@ -302,7 +506,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
302
506
|
metadata.block_kv_indices,
|
|
303
507
|
self.req_to_token.stride(0),
|
|
304
508
|
metadata.block_kv_indices.shape[1],
|
|
305
|
-
NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
|
306
509
|
PAGED_SIZE=self.page_size,
|
|
307
510
|
)
|
|
308
511
|
|
|
@@ -323,7 +526,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
323
526
|
if (
|
|
324
527
|
forward_batch.forward_mode.is_extend()
|
|
325
528
|
and not forward_batch.forward_mode.is_target_verify()
|
|
326
|
-
and not forward_batch.forward_mode.is_draft_extend()
|
|
529
|
+
and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
|
|
327
530
|
):
|
|
328
531
|
if self.disable_chunked_prefix_cache:
|
|
329
532
|
super().init_forward_metadata(forward_batch)
|
|
@@ -344,6 +547,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
344
547
|
elif (
|
|
345
548
|
forward_batch.forward_mode.is_decode_or_idle()
|
|
346
549
|
or forward_batch.forward_mode.is_target_verify()
|
|
550
|
+
or forward_batch.forward_mode.is_draft_extend(include_v2=True)
|
|
347
551
|
):
|
|
348
552
|
bs = forward_batch.batch_size
|
|
349
553
|
|
|
@@ -372,6 +576,23 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
372
576
|
self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
|
|
373
577
|
block_kv_indices, max_seq_len_val
|
|
374
578
|
)
|
|
579
|
+
if forward_batch.forward_mode.is_draft_extend(include_v2=True):
|
|
580
|
+
max_seq = forward_batch.seq_lens_cpu.max().item()
|
|
581
|
+
|
|
582
|
+
sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
|
|
583
|
+
max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
|
|
584
|
+
cu_seqlens_q = torch.nn.functional.pad(
|
|
585
|
+
torch.cumsum(
|
|
586
|
+
forward_batch.extend_seq_lens, dim=0, dtype=torch.int32
|
|
587
|
+
),
|
|
588
|
+
(1, 0),
|
|
589
|
+
)
|
|
590
|
+
|
|
591
|
+
self.forward_decode_metadata.max_seq_len_q = max_seq_len_q
|
|
592
|
+
self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q
|
|
593
|
+
self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q
|
|
594
|
+
self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens
|
|
595
|
+
|
|
375
596
|
forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
|
|
376
597
|
else:
|
|
377
598
|
return super().init_forward_metadata(forward_batch)
|
|
@@ -457,6 +678,86 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
457
678
|
|
|
458
679
|
return q_out, k_nope_out, k_rope_out
|
|
459
680
|
|
|
681
|
+
def pad_draft_extend_query(
    self,
    q: torch.Tensor,
    padded_q: torch.Tensor,
    seq_lens_q: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
) -> torch.Tensor:
    """Fill ``padded_q`` from ``q`` by launching ``pad_draft_extend_query_kernel``.

    Args:
        q: Query tensor handed to the kernel as ``q_ptr``
            (presumably packed per token -- confirm against the kernel).
        padded_q: Destination tensor. Its dimensions 1, 2 and 3 are read as
            the padded sequence length, the head count and the head size.
        seq_lens_q: Per-request query lengths, passed as ``seq_lens_q_ptr``.
        cu_seqlens_q: Cumulative query lengths, passed as ``cumsum_ptr``.
            The batch size is taken to be ``cu_seqlens_q.shape[0] - 1``.

    Returns:
        The same ``padded_q`` object that was passed in.
    """
    # Tile edge used for both the head axis and the head-dim axis.
    tile = 64

    n_reqs = cu_seqlens_q.shape[0] - 1
    padded_len = padded_q.shape[1]
    n_heads = padded_q.shape[2]
    dim = padded_q.shape[3]

    # 3D launch grid: one program per (request, position) slot, then one per
    # tile of heads and one per tile of the head dimension.
    launch_grid = (
        n_reqs * padded_len,
        triton.cdiv(n_heads, tile),
        triton.cdiv(dim, tile),
    )

    pad_draft_extend_query_kernel[launch_grid](
        q_ptr=q,
        padded_q_ptr=padded_q,
        seq_lens_q_ptr=seq_lens_q,
        cumsum_ptr=cu_seqlens_q,
        batch_size=n_reqs,
        max_seq_len=padded_len,
        num_heads=n_heads,
        head_dim=dim,
        BLOCK_SIZE=tile,
    )
    return padded_q
|
|
712
|
+
|
|
713
|
+
def unpad_draft_extend_output(
    self,
    raw_out: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    seq_lens_q: torch.Tensor,
    sum_seq_lens_q: int,
) -> torch.Tensor:
    """Compact padded attention output by launching ``unpad_draft_extend_output_kernel``.

    Args:
        raw_out: Padded output; indexed as
            (batch, tokens_per_request, num_heads, head_dim).
        cu_seqlens_q: Cumulative query lengths, passed as ``cumsum_ptr``.
        seq_lens_q: Per-request query lengths, passed as
            ``accept_length_ptr``. Its first dimension is the batch size.
        sum_seq_lens_q: Number of rows in the returned tensor.

    Returns:
        A tensor whose first dimension is ``sum_seq_lens_q``. It is either a
        slice of ``self.unpad_output_buffer`` converted to ``raw_out.dtype``,
        or a freshly allocated tensor when no buffer is set.
    """
    # Tile edge used for both the head axis and the head-dim axis.
    tile = 64

    n_reqs = seq_lens_q.shape[0]
    padded_len = raw_out.shape[1]
    n_heads = raw_out.shape[2]
    dim = raw_out.shape[3]

    staging = self.unpad_output_buffer
    if staging is None:
        # No pre-allocated buffer (non-CUDA-graph path): allocate on demand.
        packed = torch.empty(
            (sum_seq_lens_q, n_heads, dim),
            dtype=raw_out.dtype,
            device=raw_out.device,
        )
    else:
        # Pre-allocated buffer (CUDA-graph path): take the leading rows.
        packed = staging[:sum_seq_lens_q, :, :].to(dtype=raw_out.dtype)

    # 3D launch grid: one program per padded (request, position) slot, then
    # one per tile of heads and one per tile of the head dimension.
    launch_grid = (
        n_reqs * padded_len,
        triton.cdiv(n_heads, tile),
        triton.cdiv(dim, tile),
    )

    unpad_draft_extend_output_kernel[launch_grid](
        raw_out_ptr=raw_out,
        output_ptr=packed,
        accept_length_ptr=seq_lens_q,
        cumsum_ptr=cu_seqlens_q,
        batch_size=n_reqs,
        token_per_batch=padded_len,
        tp_q_head_num=n_heads,
        v_head_dim=dim,
        BLOCK_SIZE=tile,
    )
    return packed[:sum_seq_lens_q, :, :]
|
|
760
|
+
|
|
460
761
|
def forward_decode(
|
|
461
762
|
self,
|
|
462
763
|
q: torch.Tensor, # q_nope
|
|
@@ -550,7 +851,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
550
851
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
|
551
852
|
block_tables=metadata.block_kv_indices,
|
|
552
853
|
seq_lens=forward_batch.seq_lens.to(torch.int32),
|
|
553
|
-
max_seq_len=metadata.
|
|
854
|
+
max_seq_len=metadata.max_seq_len_k,
|
|
554
855
|
bmm1_scale=bmm1_scale,
|
|
555
856
|
)
|
|
556
857
|
|
|
@@ -571,11 +872,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
571
872
|
cos_sin_cache: Optional[torch.Tensor] = None,
|
|
572
873
|
is_neox: Optional[bool] = False,
|
|
573
874
|
) -> torch.Tensor:
|
|
574
|
-
if forward_batch.forward_mode.is_draft_extend():
|
|
575
|
-
return super().forward_extend(
|
|
576
|
-
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
|
577
|
-
)
|
|
578
|
-
|
|
579
875
|
# TODO refactor to avoid code duplication
|
|
580
876
|
merge_query = q_rope is not None
|
|
581
877
|
if (
|
|
@@ -627,7 +923,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
627
923
|
|
|
628
924
|
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
|
|
629
925
|
|
|
630
|
-
if
|
|
926
|
+
if (
|
|
927
|
+
forward_batch.forward_mode.is_target_verify()
|
|
928
|
+
or forward_batch.forward_mode.is_draft_extend(include_v2=True)
|
|
929
|
+
):
|
|
631
930
|
metadata = (
|
|
632
931
|
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
|
|
633
932
|
or self.forward_decode_metadata
|
|
@@ -635,7 +934,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
635
934
|
|
|
636
935
|
# Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
|
|
637
936
|
bs = forward_batch.batch_size
|
|
638
|
-
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
|
|
639
937
|
|
|
640
938
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
|
641
939
|
kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
|
|
@@ -646,17 +944,42 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
646
944
|
if getattr(layer, "k_scale_float", None) is not None
|
|
647
945
|
else 1.0
|
|
648
946
|
)
|
|
947
|
+
q = q.to(self.data_type)
|
|
649
948
|
|
|
650
949
|
bmm1_scale = q_scale * k_scale * layer.scaling
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
950
|
+
if forward_batch.forward_mode.is_target_verify():
|
|
951
|
+
seq_lens = (
|
|
952
|
+
forward_batch.seq_lens.to(torch.int32)
|
|
953
|
+
+ forward_batch.spec_info.draft_token_num
|
|
954
|
+
)
|
|
955
|
+
max_seq_len = (
|
|
956
|
+
metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
|
|
957
|
+
)
|
|
958
|
+
else:
|
|
959
|
+
seq_lens = forward_batch.seq_lens.to(torch.int32)
|
|
960
|
+
max_seq_len = metadata.max_seq_len_k
|
|
961
|
+
# Check if we're in CUDA graph mode (buffers are pre-allocated)
|
|
962
|
+
if self.padded_q_buffer is not None:
|
|
963
|
+
# Use pre-allocated buffer for CUDA graph compatibility
|
|
964
|
+
padded_q = self.padded_q_buffer[
|
|
965
|
+
:bs, : metadata.max_seq_len_q, :, :
|
|
966
|
+
].to(dtype=q.dtype)
|
|
967
|
+
else:
|
|
968
|
+
# Dynamic allocation for non-CUDA graph mode
|
|
969
|
+
padded_q = torch.zeros(
|
|
970
|
+
bs,
|
|
971
|
+
metadata.max_seq_len_q,
|
|
972
|
+
layer.tp_q_head_num,
|
|
973
|
+
layer.head_dim,
|
|
974
|
+
dtype=q.dtype,
|
|
975
|
+
device=q.device,
|
|
976
|
+
)
|
|
977
|
+
q = self.pad_draft_extend_query(
|
|
978
|
+
q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q
|
|
979
|
+
)
|
|
657
980
|
|
|
658
981
|
# TODO may use `mla_rope_quantize_fp8` fusion
|
|
659
|
-
q = q.
|
|
982
|
+
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
|
|
660
983
|
assert kv_cache.dtype == self.data_type
|
|
661
984
|
|
|
662
985
|
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
|
|
@@ -673,6 +996,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|
|
673
996
|
)
|
|
674
997
|
|
|
675
998
|
# Reshape output directly without slicing
|
|
999
|
+
|
|
1000
|
+
if forward_batch.forward_mode.is_draft_extend(include_v2=True):
|
|
1001
|
+
raw_out = self.unpad_draft_extend_output(
|
|
1002
|
+
raw_out,
|
|
1003
|
+
metadata.cu_seqlens_q,
|
|
1004
|
+
metadata.seq_lens_q,
|
|
1005
|
+
metadata.sum_seq_lens_q,
|
|
1006
|
+
)
|
|
676
1007
|
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
|
677
1008
|
return output
|
|
678
1009
|
|
|
@@ -735,7 +1066,7 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
|
|
|
735
1066
|
):
|
|
736
1067
|
super().__init__(model_runner, topk, speculative_num_steps)
|
|
737
1068
|
|
|
738
|
-
for i in range(self.speculative_num_steps):
|
|
1069
|
+
for i in range(self.speculative_num_steps - 1):
|
|
739
1070
|
self.attn_backends[i] = TRTLLMMLABackend(
|
|
740
1071
|
model_runner,
|
|
741
1072
|
skip_prefill=True,
|
|
@@ -1,10 +1,8 @@
|
|
|
1
1
|
import triton
|
|
2
2
|
import triton.language as tl
|
|
3
3
|
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
# Exposed here so other Python modules can import it instead of hard-coding 64.
|
|
7
|
-
TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
|
|
4
|
+
_FLASHMLA_CREATE_KV_BLOCK_SIZE = 4096
|
|
5
|
+
FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON = tl.constexpr(_FLASHMLA_CREATE_KV_BLOCK_SIZE)
|
|
8
6
|
|
|
9
7
|
|
|
10
8
|
@triton.jit
|
|
@@ -46,6 +44,11 @@ def create_flashinfer_kv_indices_triton(
|
|
|
46
44
|
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
|
|
47
45
|
|
|
48
46
|
|
|
47
|
+
def get_num_page_per_block_flashmla(page_size: int = 64) -> int:
    """Return the number of ``page_size``-token pages in one KV-index creation block.

    The value is ``_FLASHMLA_CREATE_KV_BLOCK_SIZE`` floor-divided by
    ``page_size``.
    """
    return _FLASHMLA_CREATE_KV_BLOCK_SIZE // page_size
|
|
50
|
+
|
|
51
|
+
|
|
49
52
|
@triton.jit
|
|
50
53
|
def create_flashmla_kv_indices_triton(
|
|
51
54
|
req_to_token_ptr, # [max_batch, max_context_len]
|
|
@@ -55,10 +58,11 @@ def create_flashmla_kv_indices_triton(
|
|
|
55
58
|
kv_indices_ptr,
|
|
56
59
|
req_to_token_ptr_stride: tl.constexpr,
|
|
57
60
|
kv_indices_ptr_stride: tl.constexpr,
|
|
58
|
-
NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
|
59
61
|
PAGED_SIZE: tl.constexpr = 64,
|
|
60
62
|
):
|
|
61
|
-
|
|
63
|
+
NUM_PAGE_PER_BLOCK: tl.constexpr = (
|
|
64
|
+
FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON // PAGED_SIZE
|
|
65
|
+
)
|
|
62
66
|
pid = tl.program_id(axis=0)
|
|
63
67
|
|
|
64
68
|
# find the req pool idx, this is for batch to token
|
|
@@ -73,7 +77,7 @@ def create_flashmla_kv_indices_triton(
|
|
|
73
77
|
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
|
|
74
78
|
|
|
75
79
|
num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)
|
|
76
|
-
num_pages_loop = tl.cdiv(kv_end - kv_start,
|
|
80
|
+
num_pages_loop = tl.cdiv(kv_end - kv_start, FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON)
|
|
77
81
|
|
|
78
82
|
for i in range(num_pages_loop):
|
|
79
83
|
# index into req_to_token_ptr needs to be int64
|
|
@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
|
|
|
45
45
|
)
|
|
46
46
|
from sglang.srt.layers.quantization import QuantizationConfig
|
|
47
47
|
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
|
|
48
|
-
from sglang.srt.
|
|
48
|
+
from sglang.srt.server_args import get_global_server_args
|
|
49
49
|
from sglang.srt.utils import add_prefix
|
|
50
50
|
|
|
51
51
|
ROTARY_EMBED_CLASSES = {
|
|
@@ -468,7 +468,7 @@ class VisionAttention(nn.Module):
|
|
|
468
468
|
_passed_backend = qkv_backend
|
|
469
469
|
qkv_backend = self._determine_attention_backend(_passed_backend)
|
|
470
470
|
if (
|
|
471
|
-
|
|
471
|
+
get_global_server_args().mm_attention_backend is None
|
|
472
472
|
and _passed_backend is None
|
|
473
473
|
):
|
|
474
474
|
print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
|
|
@@ -528,7 +528,7 @@ class VisionAttention(nn.Module):
|
|
|
528
528
|
- CUDA: "triton_attn"
|
|
529
529
|
- Non-CUDA: "sdpa"
|
|
530
530
|
"""
|
|
531
|
-
override_backend =
|
|
531
|
+
override_backend = get_global_server_args().mm_attention_backend
|
|
532
532
|
if override_backend is not None:
|
|
533
533
|
backend = override_backend
|
|
534
534
|
elif passed_backend is not None:
|