sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/tbo_backend.py

@@ -1,10 +1,10 @@
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional

 import torch

 from sglang.srt import two_batch_overlap
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.speculative.
+from sglang.srt.speculative.spec_info import SpecInput

 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -46,7 +46,7 @@ class TboAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: "ForwardMode",
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         self.primary.init_forward_metadata_capture_cuda_graph(
             bs=bs,
@@ -77,7 +77,7 @@ class TboAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: "ForwardMode",
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         self.primary.init_forward_metadata_replay_cuda_graph(
@@ -112,7 +112,7 @@ class TboAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: "ForwardMode",
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         # capture args
         capture_num_tokens: int = None,
         # replay args
@@ -196,7 +196,7 @@ def _init_forward_metadata_cuda_graph_split(
     seq_lens: torch.Tensor,
     encoder_lens: Optional[torch.Tensor],
     forward_mode: "ForwardMode",
-    spec_info: Optional[
+    spec_info: Optional[SpecInput],
     # capture args
     capture_num_tokens: int = None,
     # replay args
sglang/srt/layers/attention/torch_flex_backend.py (new file)

@@ -0,0 +1,325 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class TorchFlexAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata = None
+        self.device = model_runner.device
+        self.flex_attention = torch.compile(flex_attention, dynamic=True)
+        torch._dynamo.config.cache_size_limit = 1024
+        torch._dynamo.config.accumulated_cache_size_limit = 1024
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        # TODO: find a more elegant way to save memory
+        # Currently maintain the same memory as torch_native_backend
+        torch.cuda.empty_cache()
+
+        # Provide two block_mask Lists per seq_idx for lower latency, later will support per layer level mask generation
+        self.extend_block_masks = []
+        self.decode_block_masks = []
+
+        if forward_batch.forward_mode.is_extend():
+            for seq_idx in range(forward_batch.seq_lens.shape[0]):
+                seq_len_kv = forward_batch.seq_lens[seq_idx]
+                seq_len_q = seq_len_kv
+                self.extend_block_masks.append(
+                    create_block_mask(
+                        self._causal_mask,
+                        None,
+                        None,
+                        seq_len_q,
+                        seq_len_kv,
+                        device=self.device,
+                        _compile=False,
+                    )
+                )
+
+        elif forward_batch.forward_mode.is_decode():
+            for seq_idx in range(forward_batch.seq_lens.shape[0]):
+                seq_len_q = 1
+                seq_len_kv = forward_batch.seq_lens[seq_idx]
+
+                self.decode_block_masks.append(
+                    create_block_mask(
+                        self._decode_mask,
+                        None,
+                        None,
+                        seq_len_q,
+                        seq_len_kv,
+                        device=self.device,
+                        _compile=False,
+                    )
+                )
+
+    def _causal_mask(self, b, h, q_idx, kv_idx):
+        return q_idx >= kv_idx
+
+    def _decode_mask(self, b, h, q_idx, kv_idx):
+        return q_idx <= kv_idx
+
+    def _run_flex_forward_extend(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        extend_prefix_lens: torch.Tensor,
+        extend_seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the extend forward by using torch flex attention op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            extend_prefix_lens: [num_seqs]
+            extend_seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
+        assert seq_lens.shape[0] == extend_seq_lens.shape[0]
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop process a sequence per iter, this is inefficient.
+            # Need optimize the performance later.
+            extend_seq_len_q = extend_seq_lens[seq_idx]
+            prefill_seq_len_q = extend_prefix_lens[seq_idx]
+
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + extend_seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+            per_req_query_redundant = torch.empty(
+                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
+                dtype=per_req_query.dtype,
+                device=per_req_query.device,
+            )
+
+            per_req_query_redundant[:, prefill_seq_len_q:, :] = per_req_query
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            if not causal:
+                raise NotImplementedError("Non-causal mode is not yet implemented.")
+
+            per_req_out_redundant = (
+                self.flex_attention(
+                    per_req_query_redundant.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    block_mask=self.extend_block_masks[seq_idx],
+                    scale=scaling,
+                    enable_gqa=enable_gqa,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+            output[start_q:end_q, :, :] = per_req_out_redundant[
+                prefill_seq_len_q:, :, :
+            ]
+            start_q, start_kv = end_q, end_kv
+        return output
+
+    def _run_flex_forward_decode(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the decode forward by using torch flex attention op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop process a sequence per iter, this is inefficient.
+            # Need optimize the performance later.
+
+            seq_len_q = 1
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            per_req_out = (
+                self.flex_attention(
+                    per_req_query.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    block_mask=self.decode_block_masks[seq_idx],
+                    scale=scaling,
+                    enable_gqa=enable_gqa,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+
+            output[start_q:end_q, :, :] = per_req_out
+            start_q, start_kv = end_q, end_kv
+
+        return output
+
+    def forward_extend(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            raise NotImplementedError(
+                "TorchFlexAttnBackend does not support non-causal attention for now."
+            )
+
+        self._run_flex_forward_extend(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_prefix_lens,
+            forward_batch.extend_seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=causal,
+        )
+        return o
+
+    def forward_decode(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        self._run_flex_forward_decode(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=False,
+        )
+
+        return o
+
+    def support_triton(self):
+        return False
sglang/srt/layers/attention/torch_native_backend.py

@@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

sglang/srt/layers/attention/triton_backend.py

@@ -12,12 +12,17 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_device_core_count,
+    get_int_env_var,
+    next_power_of_2,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.
+    from sglang.srt.speculative.spec_info import SpecInput


 def logit_capping_mod(logit_capping_method, logit_cap):
@@ -80,7 +85,13 @@ class TritonAttnBackend(AttentionBackend):
         self.num_kv_head = model_runner.model_config.get_num_kv_heads(
             get_attention_tp_size()
         )
-
+        if model_runner.is_hybrid_gdn:
+            # For hybrid linear models, layer_id = 0 may not be full attention
+            self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
+        else:
+            self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[
+                -1
+            ]
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.device_core_count = get_device_core_count(model_runner.gpu_id)
@@ -89,6 +100,29 @@ class TritonAttnBackend(AttentionBackend):
         )
         self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits

+        # Decide whether enable deterministic inference with batch-invariant operations
+        self.enable_deterministic = (
+            model_runner.server_args.enable_deterministic_inference
+        )
+
+        # Configure deterministic inference settings
+        if self.enable_deterministic:
+            # Use fixed split tile size for batch invariance
+            self.split_tile_size = get_int_env_var(
+                "SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256
+            )
+            # Set static_kv_splits to False to use deterministic logic instead
+            self.static_kv_splits = False
+        else:
+            self.split_tile_size = (
+                model_runner.server_args.triton_attention_split_tile_size
+            )
+
+        if self.split_tile_size is not None:
+            self.max_kv_splits = (
+                self.max_context_len + self.split_tile_size - 1
+            ) // self.split_tile_size
+
         # Check arguments
         assert not (
             model_runner.sliding_window_size is not None
@@ -143,10 +177,26 @@ class TritonAttnBackend(AttentionBackend):
             num_group * num_seq == num_token
         ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"

-
+        # Legacy dynamic splitting logic (non-deterministic)
+        if (
+            self.static_kv_splits or self.device_core_count <= 0
+        ) and not self.enable_deterministic:
             num_kv_splits.fill_(self.max_kv_splits)
             return

+        # deterministic
+        if self.split_tile_size is not None and self.enable_deterministic:
+            # expand seq_lens to match num_token
+            if num_group > 1:
+                expanded_seq_lens = seq_lens.repeat_interleave(num_group)
+            else:
+                expanded_seq_lens = seq_lens
+
+            num_kv_splits[:] = (
+                expanded_seq_lens + self.split_tile_size - 1
+            ) // self.split_tile_size
+            return
+
         if num_seq < 256:
             SCHEDULE_SEQ = 256
         else:
@@ -432,7 +482,7 @@ class TritonAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         assert encoder_lens is None, "Not supported"
         window_kv_indptr = self.window_kv_indptr
@@ -588,7 +638,7 @@ class TritonAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # NOTE: encoder_lens expected to be zeros or None
@@ -833,7 +883,7 @@ class TritonMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
sglang/srt/layers/attention/trtllm_mha_backend.py

@@ -20,12 +20,10 @@ from sglang.srt.utils import is_flashinfer_available
 if is_flashinfer_available():
     import flashinfer

-    from sglang.srt.speculative.eagle_utils import EagleDraftInput
-
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import
+    from sglang.srt.speculative.spec_info import SpecInput

 # Constants
 DEFAULT_WORKSPACE_SIZE_MB = (
@@ -201,7 +199,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         """Initialize metadata for CUDA graph capture."""
         metadata = TRTLLMMHAMetadata()
@@ -314,7 +312,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
@@ -661,7 +659,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         forward_batch: ForwardBatch,
     ):
         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()

         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -678,7 +676,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         self, forward_batch: ForwardBatch, bs: int
     ):
         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()

         for i in range(self.speculative_num_steps - 1):