sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +54 -37
- sglang/bench_one_batch_server.py +340 -34
- sglang/bench_serving.py +340 -159
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +9 -2
- sglang/profiler.py +20 -3
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +309 -0
- sglang/srt/configs/load_config.py +33 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +284 -118
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +576 -0
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +6 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -15
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +268 -98
- sglang/srt/disaggregation/decode.py +172 -39
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +203 -555
- sglang/srt/disaggregation/nixl/conn.py +217 -63
- sglang/srt/disaggregation/prefill.py +113 -270
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +203 -97
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +85 -65
- sglang/srt/entrypoints/grpc_server.py +632 -305
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +169 -17
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +327 -34
- sglang/srt/entrypoints/openai/serving_base.py +74 -8
- sglang/srt/entrypoints/openai/serving_chat.py +202 -118
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +20 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +47 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +323 -0
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +21 -16
- sglang/srt/function_call/glm4_moe_detector.py +4 -8
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +61 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +98 -7
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/grpc_request_manager.py +915 -0
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
- sglang/srt/layers/activation.py +11 -7
- sglang/srt/layers/attention/aiter_backend.py +17 -18
- sglang/srt/layers/attention/ascend_backend.py +125 -10
- sglang/srt/layers/attention/attention_registry.py +226 -0
- sglang/srt/layers/attention/base_attn_backend.py +32 -4
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +52 -15
- sglang/srt/layers/attention/flashinfer_backend.py +357 -212
- sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
- sglang/srt/layers/attention/flashmla_backend.py +9 -7
- sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
- sglang/srt/layers/attention/mamba/mamba.py +514 -1
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +23 -0
- sglang/srt/layers/attention/nsa_backend.py +1201 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +249 -42
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
- sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +61 -3
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +19 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +28 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +47 -15
- sglang/srt/layers/linear.py +30 -5
- sglang/srt/layers/logits_processor.py +161 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
- sglang/srt/layers/moe/ep_moe/layer.py +243 -448
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +27 -1
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +86 -20
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +43 -15
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +141 -81
- sglang/srt/layers/quantization/mxfp4.py +17 -34
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -24
- sglang/srt/layers/quantization/w8a8_int8.py +45 -27
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +750 -46
- sglang/srt/layers/sampler.py +84 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +23 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +9 -4
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +33 -7
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +41 -17
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +83 -152
- sglang/srt/managers/data_parallel_controller.py +156 -87
- sglang/srt/managers/detokenizer_manager.py +51 -24
- sglang/srt/managers/io_struct.py +223 -129
- sglang/srt/managers/mm_utils.py +49 -10
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +130 -0
- sglang/srt/managers/schedule_batch.py +340 -529
- sglang/srt/managers/schedule_policy.py +158 -18
- sglang/srt/managers/scheduler.py +665 -620
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
- sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
- sglang/srt/managers/tokenizer_manager.py +462 -226
- sglang/srt/managers/tp_worker.py +217 -156
- sglang/srt/managers/utils.py +79 -47
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +42 -28
- sglang/srt/mem_cache/base_prefix_cache.py +3 -3
- sglang/srt/mem_cache/chunk_cache.py +20 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +38 -0
- sglang/srt/mem_cache/hicache_storage.py +44 -2
- sglang/srt/mem_cache/hiradix_cache.py +134 -34
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +602 -208
- sglang/srt/mem_cache/memory_pool_host.py +134 -183
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +263 -78
- sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +115 -58
- sglang/srt/metrics/collector.py +113 -120
- sglang/srt/metrics/func_timer.py +3 -8
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +81 -36
- sglang/srt/model_executor/forward_batch_info.py +40 -50
- sglang/srt/model_executor/model_runner.py +507 -319
- sglang/srt/model_executor/npu_graph_runner.py +11 -5
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +438 -37
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +200 -27
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +40 -56
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +25 -4
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +793 -235
- sglang/srt/models/dots_ocr.py +171 -0
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +570 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -3
- sglang/srt/models/glm4_moe.py +17 -40
- sglang/srt/models/glm4_moe_nextn.py +4 -4
- sglang/srt/models/glm4v.py +3 -2
- sglang/srt/models/glm4v_moe.py +6 -6
- sglang/srt/models/gpt_oss.py +12 -35
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +4 -2
- sglang/srt/models/llama.py +6 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +6 -23
- sglang/srt/models/longcat_flash_nextn.py +4 -15
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +27 -6
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +5 -5
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +70 -4
- sglang/srt/models/qwen2_vl.py +6 -3
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +50 -38
- sglang/srt/models/qwen3_next.py +43 -21
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +791 -0
- sglang/srt/models/qwen3_vl_moe.py +343 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +268 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +61 -0
- sglang/srt/multimodal/processors/base_processor.py +21 -9
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +2 -4
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +20 -10
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +83 -17
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +36 -23
- sglang/srt/sampling/sampling_params.py +75 -0
- sglang/srt/server_args.py +1300 -338
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +161 -0
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
- sglang/srt/speculative/eagle_info.py +786 -0
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +113 -1270
- sglang/srt/speculative/eagle_worker.py +120 -285
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/ngram_info.py +433 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +49 -0
- sglang/srt/speculative/spec_utils.py +641 -0
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +35 -18
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/{utils.py → utils/common.py} +583 -113
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +120 -11
- sglang/test/runners.py +3 -1
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +8 -2
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +3 -4
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +430 -0
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +93 -1
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +432 -16
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
- sglang/srt/entrypoints/grpc_request_manager.py +0 -580
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/fla/layernorm_gated.py:

```diff
@@ -5,7 +5,6 @@
 # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
 # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

-import math

 import torch
 import torch.nn.functional as F
@@ -13,6 +12,8 @@ import triton
 import triton.language as tl
 from einops import rearrange

+from sglang.srt.utils import device_context
+

 def rms_norm_ref(
     x,
@@ -158,7 +159,7 @@ def _layer_norm_fwd(
     # heuristics for number of warps
     num_warps = min(max(BLOCK_N // 256, 1), 8)
     grid = (M, ngroups)
-    with torch.cuda.device(x.device.index):
+    with device_context(x.device):
         _layer_norm_fwd_1pass_kernel[grid](
             x,
             out,
@@ -181,6 +182,45 @@ def _layer_norm_fwd(
     return out, mean, rstd


+def rms_norm_gated(
+    *,
+    x,
+    weight,
+    bias,
+    z=None,
+    eps=1e-6,
+    group_size=None,
+    norm_before_gate=True,
+    is_rms_norm=False,
+):
+    """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
+
+    x_shape_og = x.shape
+    # reshape input data into 2D tensor
+    x = x.reshape(-1, x.shape[-1])
+    if x.stride(-1) != 1:
+        x = x.contiguous()
+    if z is not None:
+        assert z.shape == x_shape_og
+        z = z.reshape(-1, z.shape[-1])
+        if z.stride(-1) != 1:
+            z = z.contiguous()
+    weight = weight.contiguous()
+    if bias is not None:
+        bias = bias.contiguous()
+    y, mean, rstd = _layer_norm_fwd(
+        x,
+        weight,
+        bias,
+        eps,
+        z=z,
+        group_size=group_size,
+        norm_before_gate=norm_before_gate,
+        is_rms_norm=is_rms_norm,
+    )
+    return y.reshape(x_shape_og)
+
+
 class LayerNormFn(torch.autograd.Function):

     @staticmethod
@@ -195,32 +235,16 @@ class LayerNormFn(torch.autograd.Function):
         norm_before_gate=True,
         is_rms_norm=False,
     ):
-        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
-
-        x_shape_og = x.shape
-        # reshape input data into 2D tensor
-        x = x.reshape(-1, x.shape[-1])
-        if x.stride(-1) != 1:
-            x = x.contiguous()
-        if z is not None:
-            assert z.shape == x_shape_og
-            z = z.reshape(-1, z.shape[-1])
-            if z.stride(-1) != 1:
-                z = z.contiguous()
-        weight = weight.contiguous()
-        if bias is not None:
-            bias = bias.contiguous()
-        y, mean, rstd = _layer_norm_fwd(
-            x,
-            weight,
-            bias,
-            eps,
+        return rms_norm_gated(
+            x=x,
+            weight=weight,
+            bias=bias,
+            eps=eps,
             z=z,
             group_size=group_size,
             norm_before_gate=norm_before_gate,
             is_rms_norm=is_rms_norm,
         )
-        return y.reshape(x_shape_og)


 def layernorm_fn(
@@ -238,14 +262,6 @@ def layernorm_fn(
     )


-def rmsnorm_fn(
-    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
-):
-    return LayerNormFn.apply(
-        x, weight, bias, z, eps, group_size, norm_before_gate, True
-    )
-
-
 class LayerNorm(torch.nn.Module):

     def __init__(
@@ -284,6 +300,7 @@ class LayerNorm(torch.nn.Module):
             group_size=self.group_size,
             eps=self.eps,
             norm_before_gate=self.norm_before_gate,
+            is_rms_norm=False,
         )


@@ -315,7 +332,7 @@ class RMSNorm(torch.nn.Module):

     def forward(self, x, z=None):
         """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
-        return rmsnorm_fn(
+        return layernorm_fn(
             x,
             self.weight,
             self.bias,
@@ -323,4 +340,5 @@ class RMSNorm(torch.nn.Module):
             eps=self.eps,
             group_size=self.group_size,
             norm_before_gate=self.norm_before_gate,
+            is_rms_norm=True,
         )
```
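
The `device_context` import replaces the CUDA-specific device guard around the Triton kernel launch, so the same wrapper can run on other accelerator backends. The shipped implementation lives in `sglang.srt.utils`; a minimal sketch of what such a helper can look like (an assumption for illustration, not the packaged code):

```python
import contextlib

import torch


def device_context(device: torch.device):
    # Enter the per-backend device context (torch.cuda.device, torch.xpu.device,
    # ...) so kernels launch on the tensor's device; no-op for CPU tensors.
    if device.type == "cpu":
        return contextlib.nullcontext()
    backend = getattr(torch, device.type)
    return backend.device(device)
```

The new `rms_norm_gated` helper hoists the body of `LayerNormFn.forward` into a reusable function, and its docstring fixes the semantics: with a gate `z`, the result is `norm(x) * silu(z)` when `norm_before_gate` is true, otherwise `norm(x * silu(z))`. A pure-PyTorch sketch of those semantics for the RMS-norm case (illustration only; it ignores the `bias` and `group_size` paths that the fused Triton kernel supports):

```python
import torch
import torch.nn.functional as F


def rms_norm_gated_ref(x, weight, z=None, eps=1e-6, norm_before_gate=True):
    # Gate before normalization: norm(x * silu(z))
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)
    # RMS normalization, computed in fp32 for numerical stability
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    out = (x.float() * rstd * weight.float()).to(x.dtype)
    # Gate after normalization: norm(x) * silu(z)
    if z is not None and norm_before_gate:
        out = out * F.silu(z)
    return out
```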
sglang/srt/layers/attention/fla/wy_fast.py:

```diff
@@ -9,8 +9,6 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
-from sglang.srt.layers.attention.fla.op import safe_exp
-from sglang.srt.layers.attention.fla.utils import check_shared_mem


 @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
```
sglang/srt/layers/attention/flashattention_backend.py:

```diff
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional

 import numpy as np
 import torch
@@ -10,10 +10,10 @@ import triton.language as tl

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.
-from sglang.srt.mem_cache.memory_pool import SWAKVPool
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.
+from sglang.srt.server_args import get_global_server_args
+from sglang.srt.speculative.spec_info import SpecInput

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -305,6 +305,7 @@ class FlashAttentionBackend(AttentionBackend):
         speculative_step_id=0,
         topk=0,
         speculative_num_steps=0,
+        fa_impl_ver=3,
     ):
         super().__init__()

@@ -338,6 +339,8 @@ class FlashAttentionBackend(AttentionBackend):
         )
         self.speculative_step_id = speculative_step_id

+        self.fa_impl_ver = fa_impl_ver
+
         # Local attention settings
         self.attention_chunk_size = (
             model_runner.attention_chunk_size
@@ -352,6 +355,13 @@ class FlashAttentionBackend(AttentionBackend):
             self.sliding_window_size is not None and self.sliding_window_size > -1
         )

+        # If num_splits == 0, we use a heuristic to automatically determine the number of splits.
+        # We set nums splits to 1 if deterministic inference is enabled.
+        # See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
+        self.num_splits = (
+            1 if model_runner.server_args.enable_deterministic_inference else 0
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
@@ -682,8 +692,13 @@ class FlashAttentionBackend(AttentionBackend):
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
-        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case
-        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
+        # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
+        if (
+            self.kv_cache_dtype_str != "auto"
+            and layer.head_dim <= 256
+            and self.fa_impl_ver != 4
+        ):
             if layer.k_scale is not None:
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
@@ -691,7 +706,9 @@ class FlashAttentionBackend(AttentionBackend):
                 q = q.to(self.kv_cache_dtype)
                 q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
                 k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
-        causal = not layer.is_cross_attention
+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False

         # Check if we should use local attention
         use_local_attn = (
@@ -712,6 +729,8 @@ class FlashAttentionBackend(AttentionBackend):

         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
+        if self.fa_impl_ver != 3:
+            kwargs["ver"] = self.fa_impl_ver
         if sinks is not None:
             kwargs["sinks"] = sinks

@@ -770,6 +789,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
                 **kwargs,
             )

@@ -791,6 +811,7 @@ class FlashAttentionBackend(AttentionBackend):
                         k_descale=k_descale,
                         v_descale=v_descale,
                         return_softmax_lse=True,
+                        num_splits=self.num_splits,
                         **kwargs,
                     )
                 o, _ = merge_state_v2_wrapper(
@@ -809,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend):
         ):
             # Do multi-head attention with chunked prefix cache
             if forward_batch.attn_attend_prefix_cache:
-                assert not
+                assert not get_global_server_args().disable_chunked_prefix_cache
                 # MHA for chunked prefix kv cache when running model with MLA
                 assert forward_batch.prefix_chunk_idx is not None
                 assert forward_batch.prefix_chunk_cu_seq_lens is not None
@@ -830,6 +851,7 @@ class FlashAttentionBackend(AttentionBackend):
                     softmax_scale=layer.scaling,
                     causal=False,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
             else:
                 # MHA for extend part of sequence without attending prefix kv cache
@@ -844,6 +866,7 @@ class FlashAttentionBackend(AttentionBackend):
                     softmax_scale=layer.scaling,
                     causal=True,
                     return_softmax_lse=forward_batch.mha_return_lse,
+                    **kwargs,
                 )
             if forward_batch.mha_return_lse:
                 output, lse, *rest = output
@@ -851,6 +874,7 @@ class FlashAttentionBackend(AttentionBackend):
                 return output, lse
             return output
         else:
+            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                 layer.layer_id
@@ -892,6 +916,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=use_cascade_attn,
+                    num_splits=self.num_splits,
                 )
                 if use_cascade_attn:
                     o, softmax_lse, *rest = result
@@ -913,6 +938,7 @@ class FlashAttentionBackend(AttentionBackend):
                             k_descale=k_descale,
                             v_descale=v_descale,
                             return_softmax_lse=True,
+                            num_splits=self.num_splits,
                         )
                     )
                     o, _ = merge_state_v2_wrapper(
@@ -939,6 +965,7 @@ class FlashAttentionBackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fa_impl_ver in [3], "Only FA3 support decoding"
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -981,10 +1008,14 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None and layer.sliding_window_size > -1
             else (-1, -1)
         )
-        causal = not layer.is_cross_attention
+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False

         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
+        if self.fa_impl_ver != 3:
+            kwargs["ver"] = self.fa_impl_ver
         if sinks is not None:
             kwargs["sinks"] = sinks

@@ -1030,6 +1061,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         elif use_local_attn:
@@ -1049,6 +1081,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         else:
@@ -1077,6 +1110,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=use_cascade_attn,
+                    num_splits=self.num_splits,
                     **kwargs,
                 )
                 if use_cascade_attn:
@@ -1098,6 +1132,7 @@ class FlashAttentionBackend(AttentionBackend):
                             k_descale=k_descale,
                             v_descale=v_descale,
                             return_softmax_lse=True,
+                            num_splits=self.num_splits,
                             **kwargs,
                         )
                     )
@@ -1153,6 +1188,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states
+                num_splits=self.num_splits,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1173,6 +1209,7 @@ class FlashAttentionBackend(AttentionBackend):
                         k_descale=k_descale,
                         v_descale=v_descale,
                         return_softmax_lse=True,
+                        num_splits=self.num_splits,
                     )
                 o, _ = merge_state_v2(
                     o,
@@ -1453,7 +1490,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
@@ -1688,7 +1725,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         out_cache_loc: Optional[torch.Tensor] = None,
     ):
@@ -2283,7 +2320,7 @@ class FlashAttentionMultiStepBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashAttentionBackend(
                     model_runner,
@@ -2298,7 +2335,7 @@ class FlashAttentionMultiStepBackend:
             self.attn_backends[i].init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
@@ -2306,7 +2343,7 @@
         forward_batch: ForwardBatch,
     ):
         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()

         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -2323,7 +2360,7 @@ class FlashAttentionMultiStepBackend:
         self, forward_batch: ForwardBatch, bs: int
    ):
         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()

         for i in range(self.speculative_num_steps - 1):
             # TODO: incrementally update the metadata for the later steps,
```
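
One pattern runs through the hunks above: the attention kernel invocations now pass `num_splits=self.num_splits`. With `num_splits=0` the FA3 kernel picks its split-KV factor heuristically, so the floating-point reduction order can vary with batch composition and outputs can differ across runs; pinning it to 1 when `enable_deterministic_inference` is set fixes the reduction order at some decode-throughput cost, per the comment and blog link in the diff. A self-contained sketch of the policy (using a hypothetical stand-in for sglang's real server-args class):

```python
from dataclasses import dataclass


@dataclass
class ServerArgsSketch:
    # Hypothetical stand-in for sglang's ServerArgs.
    enable_deterministic_inference: bool = False


def pick_num_splits(args: ServerArgsSketch) -> int:
    # 0: let the FA3 kernel choose the split-KV factor heuristically (fast).
    # 1: a single split fixes the reduction order, so repeated runs of the
    #    same request produce reproducible outputs.
    return 1 if args.enable_deterministic_inference else 0


assert pick_num_splits(ServerArgsSketch(enable_deterministic_inference=True)) == 1
assert pick_num_splits(ServerArgsSketch()) == 0
```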