sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +54 -37
- sglang/bench_one_batch_server.py +340 -34
- sglang/bench_serving.py +340 -159
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +9 -2
- sglang/profiler.py +20 -3
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +309 -0
- sglang/srt/configs/load_config.py +33 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +284 -118
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +576 -0
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +6 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -15
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +268 -98
- sglang/srt/disaggregation/decode.py +172 -39
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +203 -555
- sglang/srt/disaggregation/nixl/conn.py +217 -63
- sglang/srt/disaggregation/prefill.py +113 -270
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +203 -97
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +85 -65
- sglang/srt/entrypoints/grpc_server.py +632 -305
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +169 -17
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +327 -34
- sglang/srt/entrypoints/openai/serving_base.py +74 -8
- sglang/srt/entrypoints/openai/serving_chat.py +202 -118
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +20 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +47 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +323 -0
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +21 -16
- sglang/srt/function_call/glm4_moe_detector.py +4 -8
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +61 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +98 -7
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/grpc_request_manager.py +915 -0
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
- sglang/srt/layers/activation.py +11 -7
- sglang/srt/layers/attention/aiter_backend.py +17 -18
- sglang/srt/layers/attention/ascend_backend.py +125 -10
- sglang/srt/layers/attention/attention_registry.py +226 -0
- sglang/srt/layers/attention/base_attn_backend.py +32 -4
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +52 -15
- sglang/srt/layers/attention/flashinfer_backend.py +357 -212
- sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
- sglang/srt/layers/attention/flashmla_backend.py +9 -7
- sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
- sglang/srt/layers/attention/mamba/mamba.py +514 -1
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +23 -0
- sglang/srt/layers/attention/nsa_backend.py +1201 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +249 -42
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
- sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +61 -3
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +19 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +28 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +47 -15
- sglang/srt/layers/linear.py +30 -5
- sglang/srt/layers/logits_processor.py +161 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
- sglang/srt/layers/moe/ep_moe/layer.py +243 -448
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +27 -1
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +86 -20
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +43 -15
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +141 -81
- sglang/srt/layers/quantization/mxfp4.py +17 -34
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -24
- sglang/srt/layers/quantization/w8a8_int8.py +45 -27
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +750 -46
- sglang/srt/layers/sampler.py +84 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +23 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +9 -4
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +33 -7
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +41 -17
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +83 -152
- sglang/srt/managers/data_parallel_controller.py +156 -87
- sglang/srt/managers/detokenizer_manager.py +51 -24
- sglang/srt/managers/io_struct.py +223 -129
- sglang/srt/managers/mm_utils.py +49 -10
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +130 -0
- sglang/srt/managers/schedule_batch.py +340 -529
- sglang/srt/managers/schedule_policy.py +158 -18
- sglang/srt/managers/scheduler.py +665 -620
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
- sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
- sglang/srt/managers/tokenizer_manager.py +462 -226
- sglang/srt/managers/tp_worker.py +217 -156
- sglang/srt/managers/utils.py +79 -47
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +42 -28
- sglang/srt/mem_cache/base_prefix_cache.py +3 -3
- sglang/srt/mem_cache/chunk_cache.py +20 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +38 -0
- sglang/srt/mem_cache/hicache_storage.py +44 -2
- sglang/srt/mem_cache/hiradix_cache.py +134 -34
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +602 -208
- sglang/srt/mem_cache/memory_pool_host.py +134 -183
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +263 -78
- sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +115 -58
- sglang/srt/metrics/collector.py +113 -120
- sglang/srt/metrics/func_timer.py +3 -8
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +81 -36
- sglang/srt/model_executor/forward_batch_info.py +40 -50
- sglang/srt/model_executor/model_runner.py +507 -319
- sglang/srt/model_executor/npu_graph_runner.py +11 -5
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +438 -37
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +200 -27
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +40 -56
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +25 -4
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +793 -235
- sglang/srt/models/dots_ocr.py +171 -0
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +570 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -3
- sglang/srt/models/glm4_moe.py +17 -40
- sglang/srt/models/glm4_moe_nextn.py +4 -4
- sglang/srt/models/glm4v.py +3 -2
- sglang/srt/models/glm4v_moe.py +6 -6
- sglang/srt/models/gpt_oss.py +12 -35
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +4 -2
- sglang/srt/models/llama.py +6 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +6 -23
- sglang/srt/models/longcat_flash_nextn.py +4 -15
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +27 -6
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +5 -5
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +70 -4
- sglang/srt/models/qwen2_vl.py +6 -3
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +50 -38
- sglang/srt/models/qwen3_next.py +43 -21
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +791 -0
- sglang/srt/models/qwen3_vl_moe.py +343 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +268 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +61 -0
- sglang/srt/multimodal/processors/base_processor.py +21 -9
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +2 -4
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +20 -10
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +83 -17
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +36 -23
- sglang/srt/sampling/sampling_params.py +75 -0
- sglang/srt/server_args.py +1300 -338
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +161 -0
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
- sglang/srt/speculative/eagle_info.py +786 -0
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +113 -1270
- sglang/srt/speculative/eagle_worker.py +120 -285
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/ngram_info.py +433 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +49 -0
- sglang/srt/speculative/spec_utils.py +641 -0
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +35 -18
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/{utils.py → utils/common.py} +583 -113
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +120 -11
- sglang/test/runners.py +3 -1
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +8 -2
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +3 -4
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +430 -0
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +93 -1
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +432 -16
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
- sglang/srt/entrypoints/grpc_request_manager.py +0 -580
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -9,28 +9,21 @@ and uses BatchMLAPaged wrapper for decoding.
|
|
|
9
9
|
More details can be found in https://docs.flashinfer.ai/api/mla.html
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
|
-
import os
|
|
13
12
|
from dataclasses import dataclass
|
|
14
13
|
from functools import partial
|
|
15
14
|
from typing import TYPE_CHECKING, Callable, Optional, Union
|
|
16
15
|
|
|
17
16
|
import torch
|
|
18
17
|
|
|
19
|
-
|
|
20
|
-
import logging
|
|
21
|
-
|
|
22
|
-
torch._logging.set_logs(dynamo=logging.ERROR)
|
|
23
|
-
torch._dynamo.config.suppress_errors = True
|
|
24
|
-
|
|
25
|
-
from sglang.global_config import global_config
|
|
18
|
+
from sglang.srt.environ import envs
|
|
26
19
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
|
27
20
|
from sglang.srt.layers.attention.flashinfer_backend import (
|
|
28
21
|
create_flashinfer_kv_indices_triton,
|
|
29
22
|
)
|
|
30
23
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
|
31
|
-
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
32
24
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
|
33
|
-
from sglang.srt.
|
|
25
|
+
from sglang.srt.server_args import get_global_server_args
|
|
26
|
+
from sglang.srt.speculative.spec_info import SpecInput
|
|
34
27
|
from sglang.srt.utils import (
|
|
35
28
|
is_flashinfer_available,
|
|
36
29
|
is_sm100_supported,
|
|
@@ -38,9 +31,18 @@ from sglang.srt.utils import (
|
|
|
38
31
|
)
|
|
39
32
|
|
|
40
33
|
if TYPE_CHECKING:
|
|
34
|
+
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
|
35
|
+
FlashInferMlaAttnBackend,
|
|
36
|
+
)
|
|
41
37
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
42
38
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
43
|
-
from sglang.srt.speculative.spec_info import
|
|
39
|
+
from sglang.srt.speculative.spec_info import SpecInput
|
|
40
|
+
|
|
41
|
+
if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
|
|
42
|
+
import logging
|
|
43
|
+
|
|
44
|
+
torch._logging.set_logs(dynamo=logging.ERROR)
|
|
45
|
+
torch._dynamo.config.suppress_errors = True
|
|
44
46
|
|
|
45
47
|
if is_flashinfer_available():
|
|
46
48
|
from flashinfer import (
|
|
@@ -66,7 +68,7 @@ global_workspace_buffer = None
|
|
|
66
68
|
|
|
67
69
|
class FlashInferMhaChunkKVRunner:
|
|
68
70
|
def __init__(
|
|
69
|
-
self, model_runner: ModelRunner, attn_backend:
|
|
71
|
+
self, model_runner: ModelRunner, attn_backend: FlashInferMlaAttnBackend
|
|
70
72
|
):
|
|
71
73
|
# Parse Constants
|
|
72
74
|
self.num_local_heads = (
|
|
@@ -193,9 +195,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|
|
193
195
|
self.skip_prefill = skip_prefill
|
|
194
196
|
self.enable_chunk_kv = (
|
|
195
197
|
not skip_prefill
|
|
196
|
-
and
|
|
197
|
-
and not
|
|
198
|
-
and not
|
|
198
|
+
and get_global_server_args().disaggregation_mode != "decode"
|
|
199
|
+
and not get_global_server_args().disable_chunked_prefix_cache
|
|
200
|
+
and not get_global_server_args().flashinfer_mla_disable_ragged
|
|
199
201
|
)
|
|
200
202
|
self.page_size = model_runner.page_size
|
|
201
203
|
|
|
@@ -204,7 +206,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|
|
204
206
|
if global_workspace_buffer is None:
|
|
205
207
|
# different from flashinfer zero_init_global_workspace_buffer
|
|
206
208
|
global_workspace_buffer = torch.empty(
|
|
207
|
-
|
|
209
|
+
envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
|
|
208
210
|
dtype=torch.uint8,
|
|
209
211
|
device=model_runner.device,
|
|
210
212
|
)
|
|
@@ -306,7 +308,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|
|
306
308
|
prefix_lens = forward_batch.extend_prefix_lens
|
|
307
309
|
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
|
308
310
|
use_ragged = (
|
|
309
|
-
not
|
|
311
|
+
not get_global_server_args().flashinfer_mla_disable_ragged
|
|
310
312
|
and extend_no_prefix
|
|
311
313
|
)
|
|
312
314
|
|
|
@@ -361,7 +363,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|
|
361
363
|
seq_lens: torch.Tensor,
|
|
362
364
|
encoder_lens: Optional[torch.Tensor],
|
|
363
365
|
forward_mode: ForwardMode,
|
|
364
|
-
spec_info: Optional[
|
|
366
|
+
spec_info: Optional[SpecInput],
|
|
365
367
|
):
|
|
366
368
|
if forward_mode.is_decode_or_idle():
|
|
367
369
|
decode_wrapper = BatchMLAPagedAttentionWrapper(
|
|
@@ -441,7 +443,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
|
|
441
443
|
seq_lens_sum: int,
|
|
442
444
|
encoder_lens: Optional[torch.Tensor],
|
|
443
445
|
forward_mode: ForwardMode,
|
|
444
|
-
spec_info: Optional[
|
|
446
|
+
spec_info: Optional[SpecInput],
|
|
445
447
|
seq_lens_cpu: Optional[torch.Tensor],
|
|
446
448
|
):
|
|
447
449
|
if forward_mode.is_decode_or_idle():
|
|
@@ -663,7 +665,7 @@ class FlashInferMLAIndicesUpdaterDecode:
|
|
|
663
665
|
seq_lens_sum: int,
|
|
664
666
|
decode_wrapper: BatchMLAPagedAttentionWrapper,
|
|
665
667
|
init_metadata_replay: bool = False,
|
|
666
|
-
spec_info: Optional[
|
|
668
|
+
spec_info: Optional[SpecInput] = None,
|
|
667
669
|
**fast_decode_kwargs,
|
|
668
670
|
):
|
|
669
671
|
decode_wrapper = decode_wrapper or self.decode_wrapper
|
|
@@ -688,7 +690,7 @@ class FlashInferMLAIndicesUpdaterDecode:
|
|
|
688
690
|
q_indptr: torch.Tensor,
|
|
689
691
|
kv_indptr: torch.Tensor,
|
|
690
692
|
init_metadata_replay: bool = False,
|
|
691
|
-
spec_info: Optional[
|
|
693
|
+
spec_info: Optional[SpecInput] = None,
|
|
692
694
|
**fast_decode_kwargs,
|
|
693
695
|
):
|
|
694
696
|
bs = len(req_pool_indices)
|
|
@@ -776,7 +778,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
|
|
776
778
|
prefix_lens: torch.Tensor,
|
|
777
779
|
prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
|
|
778
780
|
use_ragged: bool,
|
|
779
|
-
spec_info: Optional[
|
|
781
|
+
spec_info: Optional[SpecInput] = None,
|
|
780
782
|
):
|
|
781
783
|
if use_ragged:
|
|
782
784
|
paged_kernel_lens = prefix_lens
|
|
@@ -811,7 +813,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
|
|
811
813
|
kv_indptr: torch.Tensor,
|
|
812
814
|
qo_indptr: torch.Tensor,
|
|
813
815
|
use_ragged: bool,
|
|
814
|
-
spec_info: Optional[
|
|
816
|
+
spec_info: Optional[SpecInput] = None,
|
|
815
817
|
):
|
|
816
818
|
bs = len(seq_lens)
|
|
817
819
|
sm_scale = self.scaling
|
|
@@ -838,9 +840,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
|
|
838
840
|
qo_indptr = qo_indptr[: bs + 1]
|
|
839
841
|
custom_mask = None
|
|
840
842
|
else:
|
|
841
|
-
assert isinstance(spec_info,
|
|
842
|
-
spec_info, EagleVerifyInput
|
|
843
|
-
)
|
|
843
|
+
assert isinstance(spec_info, SpecInput)
|
|
844
844
|
# TODO: Support topk > 1 with custom mask
|
|
845
845
|
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
|
846
846
|
spec_info.generate_attn_arg_prefill(
|
|
@@ -894,7 +894,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
|
|
894
894
|
topk: int,
|
|
895
895
|
speculative_num_steps: int,
|
|
896
896
|
):
|
|
897
|
-
from sglang.srt.speculative.
|
|
897
|
+
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
|
|
898
898
|
|
|
899
899
|
if topk > 1:
|
|
900
900
|
raise ValueError(
|
|
@@ -918,7 +918,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
|
|
918
918
|
)
|
|
919
919
|
|
|
920
920
|
self.attn_backends = []
|
|
921
|
-
for i in range(self.speculative_num_steps):
|
|
921
|
+
for i in range(self.speculative_num_steps - 1):
|
|
922
922
|
self.attn_backends.append(
|
|
923
923
|
FlashInferMLAAttnBackend(
|
|
924
924
|
model_runner,
|
|
@@ -963,7 +963,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
|
|
963
963
|
)
|
|
964
964
|
|
|
965
965
|
assert forward_batch.spec_info is not None
|
|
966
|
-
assert
|
|
966
|
+
assert forward_batch.spec_info.is_draft_input()
|
|
967
967
|
|
|
968
968
|
for i in range(self.speculative_num_steps - 1):
|
|
969
969
|
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
|
@@ -983,8 +983,6 @@ class FlashInferMLAMultiStepDraftBackend:
|
|
|
983
983
|
)
|
|
984
984
|
|
|
985
985
|
def call_fn(i, forward_batch):
|
|
986
|
-
assert forward_batch.spec_info is not None
|
|
987
|
-
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
|
988
986
|
forward_batch.spec_info.kv_indptr = (
|
|
989
987
|
forward_batch.spec_info.kv_indptr.clone()
|
|
990
988
|
)
|
|
@@ -1002,7 +1000,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
|
|
1002
1000
|
device="cuda",
|
|
1003
1001
|
)
|
|
1004
1002
|
|
|
1005
|
-
for i in range(self.speculative_num_steps):
|
|
1003
|
+
for i in range(self.speculative_num_steps - 1):
|
|
1006
1004
|
self.attn_backends[i].init_cuda_graph_state(
|
|
1007
1005
|
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
|
1008
1006
|
)
|
|
@@ -1064,7 +1062,7 @@ def fast_mla_decode_plan(
|
|
|
1064
1062
|
|
|
1065
1063
|
try:
|
|
1066
1064
|
# Standard version with just the required arguments (no use_profiler)
|
|
1067
|
-
self._cached_module.plan
|
|
1065
|
+
self._cached_module.plan(
|
|
1068
1066
|
self._float_workspace_buffer,
|
|
1069
1067
|
self._int_workspace_buffer,
|
|
1070
1068
|
self._pin_memory_int_workspace_buffer,
|
|
@@ -19,7 +19,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
20
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
21
21
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
22
|
-
from sglang.srt.speculative.spec_info import
|
|
22
|
+
from sglang.srt.speculative.spec_info import SpecInput
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
# FlashMLA only supports pagesize=64
|
|
@@ -187,7 +187,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|
|
187
187
|
seq_lens: torch.Tensor,
|
|
188
188
|
encoder_lens: Optional[torch.Tensor],
|
|
189
189
|
forward_mode: ForwardMode,
|
|
190
|
-
spec_info: Optional[
|
|
190
|
+
spec_info: Optional[SpecInput],
|
|
191
191
|
):
|
|
192
192
|
if forward_mode.is_decode_or_idle():
|
|
193
193
|
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
|
|
@@ -201,9 +201,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|
|
201
201
|
self.req_to_token.stride(0),
|
|
202
202
|
self.cuda_graph_kv_indices.stride(0),
|
|
203
203
|
)
|
|
204
|
+
num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
|
|
204
205
|
mla_metadata, num_splits = get_mla_metadata(
|
|
205
206
|
seq_lens.to(torch.int32),
|
|
206
|
-
|
|
207
|
+
num_q_heads,
|
|
207
208
|
1,
|
|
208
209
|
)
|
|
209
210
|
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
|
@@ -257,7 +258,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|
|
257
258
|
seq_lens_sum: int,
|
|
258
259
|
encoder_lens: Optional[torch.Tensor],
|
|
259
260
|
forward_mode: ForwardMode,
|
|
260
|
-
spec_info: Optional[
|
|
261
|
+
spec_info: Optional[SpecInput],
|
|
261
262
|
seq_lens_cpu: Optional[torch.Tensor],
|
|
262
263
|
):
|
|
263
264
|
|
|
@@ -275,9 +276,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|
|
275
276
|
self.req_to_token.stride(0),
|
|
276
277
|
self.cuda_graph_kv_indices.stride(0),
|
|
277
278
|
)
|
|
279
|
+
num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
|
|
278
280
|
mla_metadata, num_splits = get_mla_metadata(
|
|
279
281
|
seq_lens.to(torch.int32),
|
|
280
|
-
|
|
282
|
+
num_q_heads,
|
|
281
283
|
1,
|
|
282
284
|
)
|
|
283
285
|
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
|
@@ -476,7 +478,7 @@ class FlashMLAMultiStepDraftBackend:
|
|
|
476
478
|
)
|
|
477
479
|
|
|
478
480
|
self.attn_backends = []
|
|
479
|
-
for i in range(self.speculative_num_steps):
|
|
481
|
+
for i in range(self.speculative_num_steps - 1):
|
|
480
482
|
self.attn_backends.append(
|
|
481
483
|
FlashMLABackend(
|
|
482
484
|
model_runner,
|
|
@@ -504,7 +506,7 @@ class FlashMLAMultiStepDraftBackend:
|
|
|
504
506
|
self.common_template(forward_batch, call_fn)
|
|
505
507
|
|
|
506
508
|
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
|
507
|
-
for i in range(self.speculative_num_steps):
|
|
509
|
+
for i in range(self.speculative_num_steps - 1):
|
|
508
510
|
self.attn_backends[i].init_cuda_graph_state(
|
|
509
511
|
max_bs, max_num_tokens, block_kv_indices=None
|
|
510
512
|
)
|
|
@@ -1,12 +1,13 @@
|
|
|
1
|
-
from typing import Optional
|
|
1
|
+
from typing import Optional
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
5
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
|
6
|
+
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
|
|
6
7
|
from sglang.srt.layers.radix_attention import RadixAttention
|
|
7
8
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
|
8
9
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
9
|
-
from sglang.srt.speculative.
|
|
10
|
+
from sglang.srt.speculative.spec_info import SpecInput
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
class HybridAttnBackend(AttentionBackend):
|
|
@@ -21,6 +22,7 @@ class HybridAttnBackend(AttentionBackend):
|
|
|
21
22
|
self.model_runner = model_runner
|
|
22
23
|
self.prefill_backend = prefill_backend
|
|
23
24
|
self.decode_backend = decode_backend
|
|
25
|
+
self.data_type = model_runner.kv_cache_dtype
|
|
24
26
|
|
|
25
27
|
def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
|
|
26
28
|
"""
|
|
@@ -70,7 +72,7 @@ class HybridAttnBackend(AttentionBackend):
|
|
|
70
72
|
seq_lens: torch.Tensor,
|
|
71
73
|
encoder_lens: Optional[torch.Tensor],
|
|
72
74
|
forward_mode: ForwardMode,
|
|
73
|
-
spec_info: Optional[
|
|
75
|
+
spec_info: Optional[SpecInput],
|
|
74
76
|
):
|
|
75
77
|
backend = self._select_backend(forward_mode)
|
|
76
78
|
backend.init_forward_metadata_capture_cuda_graph(
|
|
@@ -91,7 +93,7 @@ class HybridAttnBackend(AttentionBackend):
|
|
|
91
93
|
seq_lens_sum: int,
|
|
92
94
|
encoder_lens: Optional[torch.Tensor],
|
|
93
95
|
forward_mode: ForwardMode,
|
|
94
|
-
spec_info: Optional[
|
|
96
|
+
spec_info: Optional[SpecInput],
|
|
95
97
|
seq_lens_cpu: Optional[torch.Tensor],
|
|
96
98
|
):
|
|
97
99
|
backend = self._select_backend(forward_mode)
|
|
@@ -137,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
|
|
|
137
139
|
return backend.forward_extend(
|
|
138
140
|
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
|
139
141
|
)
|
|
142
|
+
|
|
143
|
+
def get_indexer_metadata(
|
|
144
|
+
self, layer_id: int, forward_batch: ForwardBatch
|
|
145
|
+
) -> Optional[BaseIndexerMetadata]:
|
|
146
|
+
backend = self._select_backend(forward_batch.forward_mode)
|
|
147
|
+
return backend.get_indexer_metadata(layer_id, forward_batch)
|