sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -20,8 +20,9 @@ from sglang.srt.layers.attention.utils import (
     create_flashmla_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_cuda, is_flashinfer_available

 if is_flashinfer_available():
     import flashinfer
@@ -29,7 +30,12 @@ if is_flashinfer_available():
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import
+    from sglang.srt.speculative.spec_info import SpecInput
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import concat_mla_absorb_q

 # Constants
 DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
@@ -45,11 +51,19 @@ TRTLLM_BLOCK_CONSTRAINT = 128
 global_zero_init_workspace_buffer = None


+@dataclass
+class TRTLLMMLAPrefillMetadata:
+    """Metadata for TRTLLM MLA prefill operations."""
+
+    max_seq_len: int
+    cum_seq_lens: torch.Tensor
+    seq_lens: torch.Tensor
+
+
 @dataclass
 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""

-    workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
     max_seq_len: Optional[int] = None

@@ -64,7 +78,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indptr_buf: Optional[torch.Tensor] = None,
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
-        super().__init__(
+        super().__init__(
+            model_runner,
+            skip_prefill,
+            kv_indptr_buf,
+            q_indptr_decode_buf,
+        )

         config = model_runner.model_config

@@ -101,7 +120,14 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
         self.decode_cuda_graph_kv_indices = None
-        self.
+        self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
+        self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+
+        self.disable_chunked_prefix_cache = global_server_args_dict[
+            "disable_chunked_prefix_cache"
+        ]
+
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
@@ -177,9 +203,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.decode_cuda_graph_workspace = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )

         super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)

@@ -191,12 +214,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         """Initialize metadata for CUDA graph capture."""

         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -207,6 +230,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 spec_info,
             )

+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+
         # Custom fast-path for decode/idle.
         # Capture with full width so future longer sequences are safe during replay
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
@@ -230,12 +256,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         max_seq_len_val = int(seq_lens.max().item())

         metadata = TRTLLMMLADecodeMetadata(
-            self.decode_cuda_graph_workspace,
             block_kv_indices,
             max_seq_len_val,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
-        self.
+        self.forward_decode_metadata = metadata

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -245,12 +270,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -262,6 +287,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 seq_lens_cpu,
             )

+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+            del seq_lens_sum  # not handle "num_draft_tokens" but we do not need it
+
         metadata = self.decode_cuda_graph_metadata[bs]

         # Update block indices for new sequences.
@@ -291,31 +320,64 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
         # Delegate to parent for non-decode modes.
-        if
-
-
-
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
+            if self.disable_chunked_prefix_cache:
+                super().init_forward_metadata(forward_batch)
+
+            seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
+            cum_seq_lens_q = torch.cat(
+                (
+                    torch.tensor([0], device=forward_batch.seq_lens.device),
+                    torch.cumsum(seq_lens, dim=0),
+                )
+            ).int()
+            max_seq_len = max(forward_batch.extend_seq_lens_cpu)
+            self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
+                max_seq_len,
+                cum_seq_lens_q,
+                seq_lens,
+            )
+        elif (
+            forward_batch.forward_mode.is_decode_or_idle()
+            or forward_batch.forward_mode.is_target_verify()
+        ):
+            bs = forward_batch.batch_size
+
+            # Get maximum sequence length.
+            if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+                max_seq = forward_batch.seq_lens_cpu.max().item()
+            else:
+                max_seq = forward_batch.seq_lens.max().item()
+
+            seq_lens = forward_batch.seq_lens
+
+            if forward_batch.forward_mode.is_target_verify():
+                max_seq = max_seq + self.num_draft_tokens
+                seq_lens = seq_lens + self.num_draft_tokens
+
+            max_seqlen_pad = self._calc_padded_blocks(max_seq)
+            block_kv_indices = self._create_block_kv_indices(
+                bs,
+                max_seqlen_pad,
+                forward_batch.req_pool_indices,
+                seq_lens,
+                seq_lens.device,
+            )

-
-
-
+            max_seq_len_val = int(max_seq)
+            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
+                block_kv_indices, max_seq_len_val
+            )
+            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
         else:
-
-
-            max_seqlen_pad = self._calc_padded_blocks(max_seq)
-            block_kv_indices = self._create_block_kv_indices(
-                bs,
-                max_seqlen_pad,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens.device,
-            )
+            return super().init_forward_metadata(forward_batch)

-
-
-                self.workspace_buffer, block_kv_indices, max_seq_len_val
-            )
-            forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)

     def quantize_and_rope_for_fp8(
         self,
@@ -443,7 +505,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-            query =
+            query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
         else:
             # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
@@ -459,7 +521,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Get metadata
         metadata = (
             getattr(forward_batch, "decode_trtllm_mla_metadata", None)
-            or self.
+            or self.forward_decode_metadata
         )

         # Scale computation for TRTLLM MLA kernel BMM1 operation:
@@ -482,7 +544,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
             query=query,
             kv_cache=kv_cache,
-            workspace_buffer=
+            workspace_buffer=self.workspace_buffer,
             qk_nope_head_dim=self.qk_nope_head_dim,
             kv_lora_rank=self.kv_lora_rank,
             qk_rope_head_dim=self.qk_rope_head_dim,
@@ -496,6 +558,174 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         return output

+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
+    ) -> torch.Tensor:
+        if forward_batch.forward_mode.is_draft_extend():
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+
+        # TODO refactor to avoid code duplication
+        merge_query = q_rope is not None
+        if (
+            self.data_type == torch.float8_e4m3fn
+        ) and forward_batch.forward_mode.is_target_verify():
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
+        # Save KV cache if requested
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+
+        # TODO refactor to avoid code duplication
+        # Prepare query tensor inline
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope_reshaped = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+            q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
+        else:
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        if k_rope is not None:
+            k = torch.cat([k, k_rope], dim=-1)
+        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+
+        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+
+        if forward_batch.forward_mode.is_target_verify():
+            metadata = (
+                getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+                or self.forward_decode_metadata
+            )
+
+            # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
+            bs = forward_batch.batch_size
+            q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
+
+            q_scale = 1.0
+            k_scale = (
+                layer.k_scale_float
+                if getattr(layer, "k_scale_float", None) is not None
+                else 1.0
+            )
+
+            bmm1_scale = q_scale * k_scale * layer.scaling
+
+            seq_lens = (
+                forward_batch.seq_lens.to(torch.int32)
+                + forward_batch.spec_info.draft_token_num
+            )
+            max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
+
+            # TODO may use `mla_rope_quantize_fp8` fusion
+            q = q.to(self.data_type)
+            assert kv_cache.dtype == self.data_type
+
+            raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+                query=q,
+                kv_cache=kv_cache,
+                workspace_buffer=self.workspace_buffer,
+                qk_nope_head_dim=self.qk_nope_head_dim,
+                kv_lora_rank=self.kv_lora_rank,
+                qk_rope_head_dim=self.qk_rope_head_dim,
+                block_tables=metadata.block_kv_indices,
+                seq_lens=seq_lens,
+                max_seq_len=max_seq_len,
+                bmm1_scale=bmm1_scale,
+            )
+
+            # Reshape output directly without slicing
+            output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+            return output
+
+        if forward_batch.attn_attend_prefix_cache:
+            # MHA for chunked prefix kv cache when running model with MLA
+            assert forward_batch.prefix_chunk_idx is not None
+            assert forward_batch.prefix_chunk_cu_seq_lens is not None
+            assert q_rope is None
+            assert k_rope is None
+            chunk_idx = forward_batch.prefix_chunk_idx
+
+            output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+            return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q,
+                key=k,
+                value=v,
+                workspace_buffer=self.workspace_buffer,
+                seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
+                max_q_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                bmm1_scale=layer.scaling,
+                bmm2_scale=1.0,
+                o_sf_scale=-1.0,
+                batch_size=forward_batch.batch_size,
+                window_left=-1,
+                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                enable_pdl=False,
+                is_causal=False,
+                return_lse=True,
+                out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
+            )
+
+        return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self.workspace_buffer,
+            seq_lens=self.forward_prefill_metadata.seq_lens,
+            max_q_len=self.forward_prefill_metadata.max_seq_len,
+            max_kv_len=self.forward_prefill_metadata.max_seq_len,
+            bmm1_scale=layer.scaling,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=forward_batch.batch_size,
+            window_left=-1,
+            cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+            cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+            enable_pdl=False,
+            is_causal=True,
+            return_lse=forward_batch.mha_return_lse,
+        )
+

 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
@@ -512,3 +742,10 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
                 kv_indptr_buf=self.kv_indptr[i],
                 q_indptr_decode_buf=self.q_indptr_decode,
             )
+
+
+def _concat_mla_absorb_q_general(q_nope, q_rope):
+    if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
+        return concat_mla_absorb_q(q_nope, q_rope)
+    else:
+        return torch.cat([q_nope, q_rope], dim=-1)
sglang/srt/layers/attention/vision.py

@@ -16,14 +16,19 @@ from sglang.srt.utils import (
     get_device_capability,
     is_blackwell,
     is_cuda,
+    is_npu,
     print_info_once,
 )

 _is_cuda = is_cuda()
+_is_npu = is_npu()

 if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func

+if _is_npu:
+    import torch_npu
+
 from sglang.srt.distributed import (
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
@@ -331,10 +336,63 @@ class VisionFlash3Attention(nn.Module):
         return output


+class VisionAscendAttention(nn.Module):
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        if not _is_npu:
+            raise Exception("VisionAscendAttention is only available for ascend npu")
+        super().__init__()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        bsz: int,
+        seq_len: int,
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        if seq_lens.is_npu:
+            # cu_seqlens must be on cpu because of operator restriction
+            seq_lens = seq_lens.to("cpu")
+        _, num_heads, head_size = q.shape
+        num_kv_heads = k.shape[1]
+        output = torch.empty_like(q)
+
+        # operator requires pta version >= 2.5.1
+        torch_npu._npu_flash_attention_unpad(
+            query=q,
+            key=k,
+            value=v,
+            seq_len=seq_lens.to(torch.int32),
+            scale_value=head_size**-0.5,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            out=output,
+        )
+
+        return output
+
+
 QKV_BACKEND_IMPL = {
     "triton_attn": VisionTritonAttention,
     "sdpa": VisionSdpaAttention,
     "fa3": VisionFlash3Attention,
+    "ascend_attn": VisionAscendAttention,
 }

sglang/srt/layers/attention/wave_backend.py

@@ -2,7 +2,7 @@ from __future__ import annotations

 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional

 import torch
 import triton
@@ -17,7 +17,7 @@ from sglang.srt.utils import get_bool_env_var, get_device_core_count
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.
+    from sglang.srt.speculative.spec_info import SpecInput

 logger = logging.getLogger(__name__)

@@ -393,7 +393,7 @@ class WaveAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         assert encoder_lens is None, "Not supported"

@@ -477,7 +477,7 @@ class WaveAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # NOTE: encoder_lens expected to be zeros or None
sglang/srt/layers/attention/wave_ops/decode_attention.py

@@ -64,8 +64,7 @@ def get_wave_kernel(
         subs=hyperparams_0,
         canonicalize=True,
         run_bench=False,
-
-        use_buffer_store_ops=True,
+        use_buffer_ops=True,
         waves_per_eu=2,
         dynamic_symbols=dynamic_symbols_0,
         wave_runtime=True,
@@ -77,8 +76,7 @@ def get_wave_kernel(
         subs=hyperparams_1,
         canonicalize=True,
         run_bench=False,
-
-        use_buffer_store_ops=False,
+        use_buffer_ops=False,
         waves_per_eu=4,
         dynamic_symbols=dynamic_symbols_1,
         wave_runtime=True,
sglang/srt/layers/attention/wave_ops/extend_attention.py

@@ -67,11 +67,9 @@ def get_wave_kernel(
         schedule=SchedulingType.NONE,
         use_scheduling_barriers=False,
         dynamic_symbols=dynamic_symbols,
-
-        use_buffer_store_ops=True,
+        use_buffer_ops=True,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
-        gpu_native_math_precision=True,
         wave_runtime=True,
     )
     options = set_default_run_config(options)
sglang/srt/layers/communicator.py

@@ -50,6 +50,7 @@ from sglang.srt.utils import (
     is_hip,
     is_sm90_supported,
     is_sm100_supported,
+    prepare_weight_cache,
 )

 _is_flashinfer_available = is_flashinfer_available()
@@ -275,7 +276,11 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
+        cache=None,
     ):
+        if cache is not None:
+            self._context.cache = cache
+
         return self._communicate_with_all_reduce_and_layer_norm_fn(
             hidden_states=hidden_states,
             residual=residual,
@@ -349,6 +354,7 @@ class CommunicateContext:
     attn_tp_size: int
     attn_dp_size: int
     tp_size: int
+    cache = None

     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -533,6 +539,8 @@ class CommunicateWithAllReduceAndLayerNormFn:
             )
         else:
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+            if context.cache is not None:
+                _ = prepare_weight_cache(hidden_states, context.cache)
         hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
