sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/sglang/srt/layers/attention/fla/utils.py
@@ -0,0 +1,331 @@
+# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
+# -*- coding: utf-8 -*-
+
+import contextlib
+import functools
+import logging
+import os
+import sys
+from enum import Enum
+from functools import lru_cache
+from typing import Any, Callable, Dict, Literal, Optional, Tuple
+
+import torch
+import triton
+from packaging import version
+
+logger = logging.getLogger(__name__)
+
+COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
+FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
+
+
+@lru_cache(maxsize=1)
+def check_environments():
+    """
+    Checks the current operating system, Triton version, and Python version,
+    issuing warnings if they don't meet recommendations.
+    This function's body only runs once due to lru_cache.
+    """
+    # Check Operating System
+    if sys.platform == "win32":
+        logger.warning(
+            "Detected Windows operating system. Triton does not have an official Windows release, "
+            "thus FLA will not be adapted for Windows, and any potential errors will not be fixed. "
+            "Please consider using a Linux environment for compatibility."
+        )
+
+    triton_version = version.parse(triton.__version__)
+    required_triton_version = version.parse("3.2.0")
+
+    if triton_version < required_triton_version:
+        logger.warning(
+            f"Current Triton version {triton_version} is below the recommended 3.2.0 version. "
+            "Errors may occur and these issues will not be fixed. "
+            "Please consider upgrading Triton."
+        )
+
+    # Check Python version
+    py_version = version.parse(f"{sys.version_info.major}.{sys.version_info.minor}")
+    required_py_version = version.parse("3.11")
+
+    if py_version < required_py_version:
+        logger.warning(
+            f"Current Python version {py_version} is below the recommended 3.11 version. "
+            "It is recommended to upgrade to Python 3.11 or higher for the best experience."
+        )
+
+    return None
+
+
+check_environments()
+
+
+def get_abs_err(x, y):
+    return (x.detach() - y.detach()).flatten().abs().max().item()
+
+
+def get_err_ratio(x, y):
+    err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
+    base = (x.detach()).flatten().square().mean().sqrt().item()
+    return err / (base + 1e-8)
+
+
+def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
+    abs_atol = get_abs_err(ref, tri)
+    msg = f"{prefix} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
+    logger.info(msg)
+    error_rate = get_err_ratio(ref, tri)
+    if abs_atol <= err_atol:
+        return
+    if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
+        if error_rate > ratio:
+            import warnings
+
+            warnings.warn(msg)
+    else:
+        assert error_rate < ratio, msg
+
+
+SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
+
+
+def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+    """
+    A decorator that caches the most recent results of a function with tensor inputs.
+    This decorator will store the output of the decorated function for the most recent set of input tensors.
+    The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
+    Args:
+        fn (Callable[..., torch.Tensor]):
+            The function to be decorated. It should take tensor inputs and return tensor outputs.
+    Returns:
+        Callable[..., torch.Tensor]:
+            A wrapped version of the input function with single-entry caching.
+    """
+
+    cache_entries: Tuple[Optional[Tuple], Optional[Dict], Any] = []
+    cache_size = 4
+
+    @functools.wraps(fn)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        nonlocal cache_entries, cache_size
+        for i, entry in enumerate(cache_entries):
+            last_args, last_kwargs, last_result = entry
+            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
+                if all(a is b for a, b in zip(args, last_args)) and all(
+                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
+                ):
+                    cache_entries = (
+                        cache_entries[:i]
+                        + cache_entries[i + 1 :]
+                        + [(args, kwargs, last_result)]
+                    )
+                    return last_result
+
+        result = fn(*args, **kwargs)
+
+        if len(cache_entries) >= cache_size:
+            cache_entries = cache_entries[1:]
+        cache_entries.append((args, kwargs, result))
+        return result
+
+    return wrapper
+
+
+def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+    """
+    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
+    """
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        contiguous_args = (
+            i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
+        )
+        contiguous_kwargs = {
+            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
+            for k, v in kwargs.items()
+        }
+
+        tensor = None
+        for arg in args:
+            if isinstance(arg, torch.Tensor):
+                tensor = arg
+                break
+        if tensor is None:
+            for value in kwargs.values():
+                if isinstance(value, torch.Tensor):
+                    tensor = value
+                    break
+
+        if tensor is not None:
+            ctx = custom_device_ctx(tensor.device.index)
+        else:
+            ctx = contextlib.nullcontext()
+
+        with ctx:
+            return fn(*contiguous_args, **contiguous_kwargs)
+
+    return wrapper
+
+
+contiguous = input_guard
+
+
+def require_version(version, hint):
+    """
+    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
+    """
+
+    def decorator(fn):
+        @functools.wraps(fn)
+        def wrapper(ctx, *args, **kwargs):
+            from transformers.utils.versions import require_version
+
+            require_version(version, hint)
+            return fn(
+                ctx,
+                *(
+                    i if not isinstance(i, torch.Tensor) else i.contiguous()
+                    for i in args
+                ),
+                **{
+                    k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
+                    for k, v in kwargs.items()
+                },
+            )

+        return wrapper
+
+    return decorator
+
+
+def checkpoint(fn):
+    def wrapper(*args, **kwargs):
+        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
+
+    return wrapper
+
+
+@lru_cache(maxsize=None)
+def check_pytorch_version(version_s: str = "2.4") -> bool:
+    return version.parse(torch.__version__) >= version.parse(version_s)
+
+
+def _cpu_device_warning():
+    import warnings
+
+    warnings.warn(
+        ("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
+    )
+
+
+@lru_cache(maxsize=None)
+def get_multiprocessor_count(tensor_idx: int = 0) -> int:
+    try:
+        return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
+            "multiprocessor_count"
+        ]
+    except BaseException:
+        _cpu_device_warning()
+        return -1
+
+
+@lru_cache(maxsize=None)
+def get_available_device() -> str:
+    try:
+        return triton.runtime.driver.active.get_current_target().backend
+    except BaseException:
+        _cpu_device_warning()
+        return "cpu"
+
+
+@lru_cache(maxsize=None)
+def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
+    device = get_available_device()
+    if device == "cuda":
+        return "nvidia"
+    elif device == "hip":
+        return "amd"
+    elif device == "xpu":
+        return "intel"
+    else:
+        return device
+
+
+# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
+# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
+# Therefore, we need to check the triton backend to determine the actual GPU vendor.
+device = get_available_device() if get_available_device() != "hip" else "cuda"
+device_torch_lib = getattr(torch, device)
+device_platform = _check_platform()
+
+is_amd = device_platform == "amd"
+is_intel = device_platform == "intel"
+is_nvidia = device_platform == "nvidia"
+is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
+is_nvidia_hopper = is_nvidia and (
+    "NVIDIA H" in torch.cuda.get_device_name(0)
+    or torch.cuda.get_device_capability()[0] >= 9
+)
+use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
+
+# Nvidia Ampere or newer, haven't check AMD and intel yet.
+is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
+is_gather_supported = hasattr(triton.language, "gather")
+
+
+def get_all_max_shared_mem():
+    try:
+        return [
+            triton.runtime.driver.active.utils.get_device_properties(i)[
+                "max_shared_mem"
+            ]
+            for i in range(device_torch_lib.device_count())
+        ]
+    except BaseException:
+        _cpu_device_warning()
+        return [-1]
+
+
+class Backend(Enum):
+    ADA = 101376  # RTX 4090
+    AMPERE = 166912  # A100
+    HOPPER = 232448  # H100
+    DEFAULT = 102400  # Default
+
+    @classmethod
+    def get_shared_memory(cls, arch: str) -> int:
+        try:
+            return cls[arch.upper()].value
+        except KeyError:
+            return cls.DEFAULT.value
+
+
+@lru_cache(maxsize=None)
+def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
+    try:
+        device_shared_mem_list = get_all_max_shared_mem()
+        max_shared_memory = device_shared_mem_list[tensor_idx]
+        return max_shared_memory >= Backend.get_shared_memory(arch)
+    except Exception:
+        return False
+
+
+if check_pytorch_version("2.4"):
+    device = "cuda" if device == "cpu" else device
+    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
+    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
+
+    def custom_device_ctx(index: int):
+        return device_torch_lib.device(index)
+
+else:
+    assert (
+        device == "cuda"
+    ), "Only cuda device is supported for PyTorch version < 2.4.0."
+    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
+    autocast_custom_bwd = device_torch_lib.amp.custom_bwd
+
+    def custom_device_ctx(index: int):
+        return torch.cuda.device(index)
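The new fla/utils.py module above is mostly environment probing plus two decorators that the FLA Triton kernels rely on: tensor_cache memoizes the last few results keyed on tensor identity, and input_guard forces tensor arguments to be contiguous and enters the device context of the first tensor argument before calling the wrapped function. A minimal usage sketch of the caching behavior, assuming an environment where sglang and Triton import cleanly; the helper name cumulative_lengths below is illustrative and not part of the package:

import torch

from sglang.srt.layers.attention.fla.utils import tensor_cache


@tensor_cache
def cumulative_lengths(seqlens: torch.Tensor) -> torch.Tensor:
    # Illustrative only, not part of sglang. Recomputed only when a new tensor
    # object is passed; repeated calls with the same tensor hit the small
    # per-function cache (up to 4 entries).
    return torch.cat([seqlens.new_zeros(1), seqlens.cumsum(0)])


if __name__ == "__main__":
    lens = torch.tensor([3, 5, 2])
    a = cumulative_lengths(lens)
    b = cumulative_lengths(lens)
    print(a)  # tensor([ 0,  3,  8, 10])
    print(a is b)  # True: the second call is served from the cache
    print(cumulative_lengths(lens.clone()) is a)  # False: a new tensor object misses

Note that the cache compares arguments by identity (is), not by value, which is the useful behavior when the same cu_seqlens tensor is passed repeatedly within one forward pass.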
--- /dev/null
+++ b/sglang/srt/layers/attention/fla/wy_fast.py
@@ -0,0 +1,158 @@
+# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/wy_fast.py
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
+from sglang.srt.layers.attention.fla.op import safe_exp
+from sglang.srt.layers.attention.fla.utils import check_shared_mem
+
+
+@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
+# @triton.autotune(
+#     configs=[
+#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+#         for num_warps in [2, 4, 8]
+#         for num_stages in [2, 3, 4]
+#     ],
+#     key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
+# )
+@triton.jit(do_not_specialize=["T"])
+def recompute_w_u_fwd_kernel(
+    k,
+    v,
+    beta,
+    w,
+    u,
+    A,
+    g,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    H: tl.constexpr,
+    Hg: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
+            chunk_indices + i_t * 2 + 1
+        ).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
+            cu_seqlens + i_n + 1
+        ).to(tl.int32)
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+    p_beta = tl.make_block_ptr(
+        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
+    )
+    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
+    p_A = tl.make_block_ptr(
+        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
+    )
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    b_A = tl.load(p_A, boundary_check=(0, 1))
+    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
+
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(
+            v + (bos * H + i_h) * V,
+            (T, V),
+            (H * V, 1),
+            (i_t * BT, i_v * BV),
+            (BT, BV),
+            (1, 0),
+        )
+        p_u = tl.make_block_ptr(
+            u + (bos * H + i_h) * V,
+            (T, V),
+            (H * V, 1),
+            (i_t * BT, i_v * BV),
+            (BT, BV),
+            (1, 0),
+        )
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_k in range(tl.cdiv(K, BK)):
+        p_k = tl.make_block_ptr(
+            k + (bos * Hg + i_h // (H // Hg)) * K,
+            (T, K),
+            (Hg * K, 1),
+            (i_t * BT, i_k * BK),
+            (BT, BK),
+            (1, 0),
+        )
+        p_w = tl.make_block_ptr(
+            w + (bos * H + i_h) * K,
+            (T, K),
+            (H * K, 1),
+            (i_t * BT, i_k * BK),
+            (BT, BK),
+            (1, 0),
+        )
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
+        b_w = tl.dot(b_A, b_kb)
+        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+
+def recompute_w_u_fwd(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    g_cumsum: torch.Tensor,
+    A: torch.Tensor,
+    cu_seqlens: Optional[torch.LongTensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    B, T, Hg, K, V = *k.shape, v.shape[-1]
+    H = v.shape[-2]
+    BT = A.shape[-1]
+
+    chunk_indices = (
+        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    )
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+    BK = 64
+    BV = 64
+    u = torch.empty_like(v)
+    w = k.new_empty(B, T, H, K)
+    recompute_w_u_fwd_kernel[(NT, B * H)](
+        k=k,
+        v=v,
+        beta=beta,
+        w=w,
+        u=u,
+        A=A,
+        g=g_cumsum,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        T=T,
+        H=H,
+        Hg=Hg,
+        K=K,
+        V=V,
+        BT=BT,
+        BK=BK,
+        BV=BV,
+        num_warps=4,
+        num_stages=3,
+    )
+    return w, u
+
+
+fwd_recompute_w_u = recompute_w_u_fwd
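Stripped of the Triton block-pointer bookkeeping, recompute_w_u_fwd_kernel computes, per (chunk, head), two matrix products: u = A @ (v * beta) and w = A @ (k * beta * exp(g_cumsum)). A plain PyTorch restatement of that per-chunk math, useful as a mental model or CPU reference (illustrative only; the A matrix is assumed to be produced by the other new fla ops in this release, such as chunk_scaled_dot_kkt and solve_tril):

import torch


def recompute_w_u_reference(
    k: torch.Tensor,         # [BT, K]  keys for one chunk and head
    v: torch.Tensor,         # [BT, V]  values for one chunk and head
    beta: torch.Tensor,      # [BT]     per-token beta gate
    g_cumsum: torch.Tensor,  # [BT]     within-chunk cumulative log-gates
    A: torch.Tensor,         # [BT, BT] chunkwise transform matrix (assumed given)
) -> tuple[torch.Tensor, torch.Tensor]:
    # Mirrors the kernel body: b_u = dot(b_A, b_v * b_beta) and
    # b_w = dot(b_A, b_k * b_beta * exp(b_g)).
    u = A @ (v * beta[:, None])
    w = A @ (k * beta[:, None] * g_cumsum.exp()[:, None])
    return w, u


if __name__ == "__main__":
    BT, K, V = 64, 128, 128
    w, u = recompute_w_u_reference(
        torch.randn(BT, K),
        torch.randn(BT, V),
        torch.rand(BT),
        torch.zeros(BT),
        torch.tril(torch.randn(BT, BT)),
    )
    print(w.shape, u.shape)  # torch.Size([64, 128]) torch.Size([64, 128])

The Triton version additionally tiles K and V into BK/BV blocks, handles variable-length batches through cu_seqlens and chunk_indices, and supports grouped keys (Hg key heads shared across H value heads).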
--- a/sglang/srt/layers/attention/flashattention_backend.py
+++ b/sglang/srt/layers/attention/flashattention_backend.py
@@ -11,9 +11,8 @@ import triton.language as tl
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.
+from sglang.srt.speculative.spec_info import SpecInput
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -305,6 +304,7 @@ class FlashAttentionBackend(AttentionBackend):
         speculative_step_id=0,
         topk=0,
         speculative_num_steps=0,
+        fa_impl_ver=3,
     ):
         super().__init__()
 
@@ -338,6 +338,8 @@ class FlashAttentionBackend(AttentionBackend):
         )
         self.speculative_step_id = speculative_step_id
 
+        self.fa_impl_ver = fa_impl_ver
+
         # Local attention settings
         self.attention_chunk_size = (
             model_runner.attention_chunk_size
@@ -352,6 +354,13 @@ class FlashAttentionBackend(AttentionBackend):
             self.sliding_window_size is not None and self.sliding_window_size > -1
         )
 
+        # If num_splits == 0, we use a heuristic to automatically determine the number of splits.
+        # We set nums splits to 1 if deterministic inference is enabled.
+        # See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
+        self.num_splits = (
+            1 if model_runner.server_args.enable_deterministic_inference else 0
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
@@ -682,8 +691,13 @@ class FlashAttentionBackend(AttentionBackend):
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
-        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case
-        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
+        # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
+        if (
+            self.kv_cache_dtype_str != "auto"
+            and layer.head_dim <= 256
+            and self.fa_impl_ver != 4
+        ):
             if layer.k_scale is not None:
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
@@ -712,6 +726,8 @@ class FlashAttentionBackend(AttentionBackend):
 
         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
+        if self.fa_impl_ver != 3:
+            kwargs["ver"] = self.fa_impl_ver
         if sinks is not None:
             kwargs["sinks"] = sinks
 
@@ -770,6 +786,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
                 **kwargs,
             )
 
@@ -791,6 +808,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=True,
+                num_splits=self.num_splits,
                 **kwargs,
             )
             o, _ = merge_state_v2_wrapper(
@@ -830,6 +848,7 @@ class FlashAttentionBackend(AttentionBackend):
                     softmax_scale=layer.scaling,
                     causal=False,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
             else:
                 # MHA for extend part of sequence without attending prefix kv cache
@@ -844,6 +863,7 @@ class FlashAttentionBackend(AttentionBackend):
                     softmax_scale=layer.scaling,
                     causal=True,
                     return_softmax_lse=forward_batch.mha_return_lse,
+                    **kwargs,
                 )
                 if forward_batch.mha_return_lse:
                     output, lse, *rest = output
@@ -851,6 +871,7 @@ class FlashAttentionBackend(AttentionBackend):
                     return output, lse
                 return output
         else:
+            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                 layer.layer_id
@@ -892,6 +913,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -913,6 +935,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    num_splits=self.num_splits,
                 )
             )
             o, _ = merge_state_v2_wrapper(
@@ -939,6 +962,7 @@ class FlashAttentionBackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fa_impl_ver in [3], "Only FA3 support decoding"
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -985,6 +1009,8 @@ class FlashAttentionBackend(AttentionBackend):
 
         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
+        if self.fa_impl_ver != 3:
+            kwargs["ver"] = self.fa_impl_ver
         if sinks is not None:
             kwargs["sinks"] = sinks
 
@@ -1030,6 +1056,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         elif use_local_attn:
@@ -1049,6 +1076,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         else:
@@ -1077,6 +1105,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
                 **kwargs,
             )
             if use_cascade_attn:
@@ -1098,6 +1127,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    num_splits=self.num_splits,
                     **kwargs,
                 )
             )
@@ -1153,6 +1183,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states
+                num_splits=self.num_splits,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1173,6 +1204,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=True,
+                num_splits=self.num_splits,
             )
             o, _ = merge_state_v2(
                 o,
@@ -1453,7 +1485,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
@@ -1688,7 +1720,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         out_cache_loc: Optional[torch.Tensor] = None,
     ):
@@ -2306,7 +2338,7 @@ class FlashAttentionMultiStepBackend:
         forward_batch: ForwardBatch,
     ):
         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()
 
         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -2323,7 +2355,7 @@ class FlashAttentionMultiStepBackend:
         self, forward_batch: ForwardBatch, bs: int
    ):
         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()
 
         for i in range(self.speculative_num_steps - 1):
             # TODO: incrementally update the metadata for the later steps,