sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import enum
|
4
|
+
|
3
5
|
# Copyright 2023-2024 SGLang Team
|
4
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
7
|
# you may not use this file except in compliance with the License.
|
@@ -35,6 +37,7 @@ import copy
|
|
35
37
|
import dataclasses
|
36
38
|
import logging
|
37
39
|
import threading
|
40
|
+
import time
|
38
41
|
from enum import Enum, auto
|
39
42
|
from http import HTTPStatus
|
40
43
|
from itertools import chain
|
@@ -51,18 +54,18 @@ from sglang.srt.disaggregation.base import BaseKVSender
|
|
51
54
|
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
52
55
|
ScheduleBatchDisaggregationDecodeMixin,
|
53
56
|
)
|
57
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode
|
54
58
|
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
55
|
-
from sglang.srt.layers.moe import is_tbo_enabled
|
56
59
|
from sglang.srt.mem_cache.allocator import (
|
57
60
|
BaseTokenToKVPoolAllocator,
|
58
61
|
SWATokenToKVPoolAllocator,
|
59
62
|
)
|
60
63
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
61
64
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
62
|
-
from sglang.srt.mem_cache.
|
63
|
-
from sglang.srt.mem_cache.
|
65
|
+
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
|
66
|
+
from sglang.srt.mem_cache.radix_cache import RadixKey
|
64
67
|
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
65
|
-
from sglang.srt.metrics.collector import TimeStats
|
68
|
+
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
66
69
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
67
70
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
68
71
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
@@ -71,8 +74,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
|
|
71
74
|
|
72
75
|
if TYPE_CHECKING:
|
73
76
|
from sglang.srt.configs.model_config import ModelConfig
|
74
|
-
from sglang.srt.speculative.
|
75
|
-
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
77
|
+
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
|
76
78
|
|
77
79
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
78
80
|
|
@@ -87,6 +89,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|
87
89
|
"disable_flashinfer_cutlass_moe_fp4_allgather",
|
88
90
|
"disable_radix_cache",
|
89
91
|
"enable_dp_lm_head",
|
92
|
+
"enable_fp32_lm_head",
|
90
93
|
"flashinfer_mxfp4_moe_precision",
|
91
94
|
"enable_flashinfer_allreduce_fusion",
|
92
95
|
"moe_dense_tp_size",
|
@@ -99,15 +102,18 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|
99
102
|
"sampling_backend",
|
100
103
|
"speculative_accept_threshold_single",
|
101
104
|
"speculative_accept_threshold_acc",
|
105
|
+
"speculative_attention_mode",
|
102
106
|
"torchao_config",
|
103
107
|
"triton_attention_reduce_in_fp32",
|
104
108
|
"num_reserved_decode_tokens",
|
105
109
|
"weight_loader_disable_mmap",
|
106
110
|
"enable_multimodal",
|
107
111
|
"enable_symm_mem",
|
108
|
-
"quantization",
|
109
112
|
"enable_custom_logit_processor",
|
110
113
|
"disaggregation_mode",
|
114
|
+
"enable_deterministic_inference",
|
115
|
+
"nsa_prefill",
|
116
|
+
"nsa_decode",
|
111
117
|
]
|
112
118
|
|
113
119
|
# Put some global args for easy access
|
@@ -408,6 +414,23 @@ class MultimodalInputs:
|
|
408
414
|
# other args would be kept intact
|
409
415
|
|
410
416
|
|
417
|
+
class RequestStage(str, enum.Enum):
|
418
|
+
# prefill
|
419
|
+
PREFILL_WAITING = "prefill_waiting"
|
420
|
+
|
421
|
+
# disaggregation prefill
|
422
|
+
PREFILL_PREPARE = "prefill_prepare"
|
423
|
+
PREFILL_BOOTSTRAP = "prefill_bootstrap"
|
424
|
+
PREFILL_FORWARD = "prefill_forward"
|
425
|
+
PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
|
426
|
+
|
427
|
+
# disaggregation decode
|
428
|
+
DECODE_PREPARE = "decode_prepare"
|
429
|
+
DECODE_BOOTSTRAP = "decode_bootstrap"
|
430
|
+
DECODE_WAITING = "decode_waiting"
|
431
|
+
DECODE_TRANSFERRED = "decode_transferred"
|
432
|
+
|
433
|
+
|
411
434
|
class Req:
|
412
435
|
"""The input and output status of a request."""
|
413
436
|
|
@@ -432,8 +455,12 @@ class Req:
|
|
432
455
|
bootstrap_host: Optional[str] = None,
|
433
456
|
bootstrap_port: Optional[int] = None,
|
434
457
|
bootstrap_room: Optional[int] = None,
|
458
|
+
disagg_mode: Optional[DisaggregationMode] = None,
|
435
459
|
data_parallel_rank: Optional[int] = None,
|
436
460
|
vocab_size: Optional[int] = None,
|
461
|
+
priority: Optional[int] = None,
|
462
|
+
metrics_collector: Optional[SchedulerMetricsCollector] = None,
|
463
|
+
extra_key: Optional[str] = None,
|
437
464
|
):
|
438
465
|
# Input and output info
|
439
466
|
self.rid = rid
|
@@ -466,6 +493,14 @@ class Req:
|
|
466
493
|
self.sampling_params = sampling_params
|
467
494
|
self.custom_logit_processor = custom_logit_processor
|
468
495
|
self.return_hidden_states = return_hidden_states
|
496
|
+
|
497
|
+
# extra key for classifying the request (e.g. cache_salt)
|
498
|
+
if lora_id is not None:
|
499
|
+
extra_key = (
|
500
|
+
extra_key or ""
|
501
|
+
) + lora_id # lora_id is concatenated to the extra key
|
502
|
+
|
503
|
+
self.extra_key = extra_key
|
469
504
|
self.lora_id = lora_id
|
470
505
|
|
471
506
|
# Memory pool info
|
@@ -484,6 +519,7 @@ class Req:
|
|
484
519
|
self.stream = stream
|
485
520
|
self.eos_token_ids = eos_token_ids
|
486
521
|
self.vocab_size = vocab_size
|
522
|
+
self.priority = priority
|
487
523
|
|
488
524
|
# For incremental decoding
|
489
525
|
# ----- | --------- read_ids -------|
|
@@ -513,6 +549,8 @@ class Req:
|
|
513
549
|
self.host_hit_length = 0
|
514
550
|
# The node to lock until for swa radix tree lock ref
|
515
551
|
self.swa_uuid_for_lock: Optional[int] = None
|
552
|
+
# The prefix length of the last prefix matching
|
553
|
+
self.last_matched_prefix_len: int = 0
|
516
554
|
|
517
555
|
# Whether or not if it is chunked. It increments whenever
|
518
556
|
# it is chunked, and decrement whenever chunked request is
|
@@ -561,7 +599,10 @@ class Req:
|
|
561
599
|
# shape: (bs, k)
|
562
600
|
self.output_top_logprobs_val = []
|
563
601
|
self.output_top_logprobs_idx = []
|
564
|
-
|
602
|
+
# Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
|
603
|
+
self.output_token_ids_logprobs_val: List[
|
604
|
+
Union[List[float], torch.Tensor]
|
605
|
+
] = []
|
565
606
|
self.output_token_ids_logprobs_idx = []
|
566
607
|
else:
|
567
608
|
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
|
@@ -571,6 +612,8 @@ class Req:
|
|
571
612
|
) = None
|
572
613
|
self.hidden_states: List[List[float]] = []
|
573
614
|
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
|
615
|
+
self.output_topk_p = None
|
616
|
+
self.output_topk_index = None
|
574
617
|
|
575
618
|
# Embedding (return values)
|
576
619
|
self.embedding = None
|
@@ -588,10 +631,10 @@ class Req:
|
|
588
631
|
self.spec_verify_ct = 0
|
589
632
|
|
590
633
|
# For metrics
|
591
|
-
self.
|
634
|
+
self.metrics_collector = metrics_collector
|
635
|
+
self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
|
592
636
|
self.has_log_time_stats: bool = False
|
593
|
-
self.
|
594
|
-
self.queue_time_end = None
|
637
|
+
self.last_tic = time.monotonic()
|
595
638
|
|
596
639
|
# For disaggregation
|
597
640
|
self.bootstrap_host: str = bootstrap_host
|
@@ -619,6 +662,25 @@ class Req:
|
|
619
662
|
def seqlen(self):
|
620
663
|
return len(self.origin_input_ids) + len(self.output_ids)
|
621
664
|
|
665
|
+
@property
|
666
|
+
def is_prefill_only(self) -> bool:
|
667
|
+
"""Check if this request is prefill-only (no token generation needed)."""
|
668
|
+
# NOTE: when spec is enabled, prefill_only optimizations are disabled
|
669
|
+
return (
|
670
|
+
self.sampling_params.max_new_tokens == 0
|
671
|
+
and global_server_args_dict["speculative_algorithm"] is None
|
672
|
+
)
|
673
|
+
|
674
|
+
def add_latency(self, stage: RequestStage):
|
675
|
+
if self.metrics_collector is None:
|
676
|
+
return
|
677
|
+
|
678
|
+
now = time.monotonic()
|
679
|
+
self.metrics_collector.observe_per_stage_req_latency(
|
680
|
+
stage.value, now - self.last_tic
|
681
|
+
)
|
682
|
+
self.last_tic = now
|
683
|
+
|
622
684
|
def extend_image_inputs(self, image_inputs):
|
623
685
|
if self.multimodal_inputs is None:
|
624
686
|
self.multimodal_inputs = image_inputs
|
@@ -635,26 +697,17 @@ class Req:
|
|
635
697
|
):
|
636
698
|
self.fill_ids = self.origin_input_ids + self.output_ids
|
637
699
|
if tree_cache is not None:
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
else:
|
650
|
-
(
|
651
|
-
self.prefix_indices,
|
652
|
-
self.last_node,
|
653
|
-
self.last_host_node,
|
654
|
-
self.host_hit_length,
|
655
|
-
) = tree_cache.match_prefix(
|
656
|
-
key=self.adjust_max_prefix_ids(),
|
657
|
-
)
|
700
|
+
(
|
701
|
+
self.prefix_indices,
|
702
|
+
self.last_node,
|
703
|
+
self.last_host_node,
|
704
|
+
self.host_hit_length,
|
705
|
+
) = tree_cache.match_prefix(
|
706
|
+
key=RadixKey(
|
707
|
+
token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
|
708
|
+
),
|
709
|
+
)
|
710
|
+
self.last_matched_prefix_len = len(self.prefix_indices)
|
658
711
|
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
659
712
|
|
660
713
|
def adjust_max_prefix_ids(self):
|
@@ -684,9 +737,15 @@ class Req:
|
|
684
737
|
self.surr_offset = max(
|
685
738
|
self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
|
686
739
|
)
|
740
|
+
self.surr_and_decode_ids = (
|
741
|
+
self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
|
742
|
+
)
|
743
|
+
self.cur_decode_ids_len = len(self.output_ids)
|
744
|
+
else:
|
745
|
+
self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
|
746
|
+
self.cur_decode_ids_len = len(self.output_ids)
|
687
747
|
|
688
|
-
|
689
|
-
return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
|
748
|
+
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
|
690
749
|
|
691
750
|
def check_finished(self):
|
692
751
|
if self.finished():
|
@@ -781,10 +840,10 @@ class Req:
|
|
781
840
|
return
|
782
841
|
|
783
842
|
if self.bootstrap_room is not None:
|
784
|
-
prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.
|
843
|
+
prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
|
785
844
|
else:
|
786
|
-
prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.
|
787
|
-
logger.info(f"{prefix}: {self.time_stats}")
|
845
|
+
prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
|
846
|
+
logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
|
788
847
|
self.has_log_time_stats = True
|
789
848
|
|
790
849
|
def set_finish_with_abort(self, error_msg: str):
|
@@ -807,10 +866,6 @@ class Req:
|
|
807
866
|
)
|
808
867
|
|
809
868
|
|
810
|
-
# Batch id
|
811
|
-
bid = 0
|
812
|
-
|
813
|
-
|
814
869
|
@dataclasses.dataclass
|
815
870
|
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
816
871
|
"""Store all information of a batch on the scheduler."""
|
@@ -847,6 +902,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
847
902
|
token_type_ids: torch.Tensor = None # shape: [b], int64
|
848
903
|
req_pool_indices: torch.Tensor = None # shape: [b], int64
|
849
904
|
seq_lens: torch.Tensor = None # shape: [b], int64
|
905
|
+
seq_lens_cpu: torch.Tensor = None # shape: [b], int64
|
850
906
|
# The output locations of the KV cache
|
851
907
|
out_cache_loc: torch.Tensor = None # shape: [b], int64
|
852
908
|
output_ids: torch.Tensor = None # shape: [b], int64
|
@@ -902,7 +958,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
902
958
|
|
903
959
|
# Speculative decoding
|
904
960
|
spec_algorithm: SpeculativeAlgorithm = None
|
905
|
-
spec_info: Optional[
|
961
|
+
# spec_info: Optional[SpecInput] = None
|
962
|
+
spec_info: Optional[SpecInput] = None
|
906
963
|
|
907
964
|
# Whether to return hidden states
|
908
965
|
return_hidden_states: bool = False
|
@@ -911,7 +968,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
911
968
|
is_prefill_only: bool = False
|
912
969
|
|
913
970
|
# hicache pointer for synchronizing data loading from CPU to GPU
|
914
|
-
hicache_consumer_index: int =
|
971
|
+
hicache_consumer_index: int = -1
|
915
972
|
|
916
973
|
@classmethod
|
917
974
|
def init_new(
|
@@ -950,9 +1007,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
950
1007
|
device=req_to_token_pool.device,
|
951
1008
|
spec_algorithm=spec_algorithm,
|
952
1009
|
return_hidden_states=any(req.return_hidden_states for req in reqs),
|
953
|
-
is_prefill_only=all(
|
954
|
-
req.sampling_params.max_new_tokens == 0 for req in reqs
|
955
|
-
),
|
1010
|
+
is_prefill_only=all(req.is_prefill_only for req in reqs),
|
956
1011
|
chunked_req=chunked_req,
|
957
1012
|
)
|
958
1013
|
|
@@ -962,8 +1017,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
962
1017
|
def is_empty(self):
|
963
1018
|
return len(self.reqs) == 0
|
964
1019
|
|
965
|
-
def alloc_req_slots(self, num_reqs: int):
|
966
|
-
|
1020
|
+
def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
|
1021
|
+
if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
|
1022
|
+
req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
|
1023
|
+
else:
|
1024
|
+
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
967
1025
|
if req_pool_indices is None:
|
968
1026
|
raise RuntimeError(
|
969
1027
|
"alloc_req_slots runs out of memory. "
|
@@ -1000,7 +1058,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1000
1058
|
def alloc_paged_token_slots_extend(
|
1001
1059
|
self,
|
1002
1060
|
prefix_lens: torch.Tensor,
|
1061
|
+
prefix_lens_cpu: torch.Tensor,
|
1003
1062
|
seq_lens: torch.Tensor,
|
1063
|
+
seq_lens_cpu: torch.Tensor,
|
1004
1064
|
last_loc: torch.Tensor,
|
1005
1065
|
extend_num_tokens: int,
|
1006
1066
|
backup_state: bool = False,
|
@@ -1008,7 +1068,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1008
1068
|
# Over estimate the number of tokens: assume each request needs a new page.
|
1009
1069
|
num_tokens = (
|
1010
1070
|
extend_num_tokens
|
1011
|
-
+ len(
|
1071
|
+
+ len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
|
1012
1072
|
)
|
1013
1073
|
self._evict_tree_cache_if_needed(num_tokens)
|
1014
1074
|
|
@@ -1016,7 +1076,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1016
1076
|
state = self.token_to_kv_pool_allocator.backup_state()
|
1017
1077
|
|
1018
1078
|
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
1019
|
-
prefix_lens,
|
1079
|
+
prefix_lens,
|
1080
|
+
prefix_lens_cpu,
|
1081
|
+
seq_lens,
|
1082
|
+
seq_lens_cpu,
|
1083
|
+
last_loc,
|
1084
|
+
extend_num_tokens,
|
1020
1085
|
)
|
1021
1086
|
if out_cache_loc is None:
|
1022
1087
|
error_msg = (
|
@@ -1035,6 +1100,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1035
1100
|
def alloc_paged_token_slots_decode(
|
1036
1101
|
self,
|
1037
1102
|
seq_lens: torch.Tensor,
|
1103
|
+
seq_lens_cpu: torch.Tensor,
|
1038
1104
|
last_loc: torch.Tensor,
|
1039
1105
|
backup_state: bool = False,
|
1040
1106
|
):
|
@@ -1045,7 +1111,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1045
1111
|
if backup_state:
|
1046
1112
|
state = self.token_to_kv_pool_allocator.backup_state()
|
1047
1113
|
|
1048
|
-
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
|
1114
|
+
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
|
1115
|
+
seq_lens, seq_lens_cpu, last_loc
|
1116
|
+
)
|
1049
1117
|
if out_cache_loc is None:
|
1050
1118
|
error_msg = (
|
1051
1119
|
f"Decode out of memory. Try to lower your batch size.\n"
|
@@ -1114,6 +1182,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1114
1182
|
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
|
1115
1183
|
self.device, non_blocking=True
|
1116
1184
|
)
|
1185
|
+
self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
|
1117
1186
|
|
1118
1187
|
if not decoder_out_cache_loc:
|
1119
1188
|
self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
|
@@ -1138,7 +1207,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1138
1207
|
|
1139
1208
|
# Allocate req slots
|
1140
1209
|
bs = len(self.reqs)
|
1141
|
-
req_pool_indices = self.alloc_req_slots(bs)
|
1210
|
+
req_pool_indices = self.alloc_req_slots(bs, self.reqs)
|
1142
1211
|
|
1143
1212
|
# Init tensors
|
1144
1213
|
reqs = self.reqs
|
@@ -1162,12 +1231,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1162
1231
|
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
|
1163
1232
|
self.device, non_blocking=True
|
1164
1233
|
)
|
1234
|
+
seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
|
1165
1235
|
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
|
1166
1236
|
self.device, non_blocking=True
|
1167
1237
|
)
|
1168
1238
|
prefix_lens_tensor = torch.tensor(
|
1169
1239
|
prefix_lens, dtype=torch.int64, device=self.device
|
1170
1240
|
)
|
1241
|
+
prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
|
1171
1242
|
|
1172
1243
|
token_type_ids_tensor = None
|
1173
1244
|
if len(token_type_ids) > 0:
|
@@ -1207,13 +1278,36 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1207
1278
|
req.is_retracted = False
|
1208
1279
|
|
1209
1280
|
# Compute the relative logprob_start_len in an extend batch
|
1281
|
+
#
|
1282
|
+
# Key variables:
|
1283
|
+
# - logprob_start_len: Absolute position in full sequence where logprob computation begins
|
1284
|
+
# - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
|
1285
|
+
# - extend_input_len: Number of tokens that need to be processed in this extend batch
|
1286
|
+
# (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
|
1287
|
+
# and prefix_indices are the cached/shared prefix tokens)
|
1288
|
+
#
|
1210
1289
|
if req.logprob_start_len >= pre_len:
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1290
|
+
# Optimization for prefill-only requests: When we only need logprobs at
|
1291
|
+
# positions beyond the input sequence (to score next-token likelihood), skip all
|
1292
|
+
# input logprob computation during prefill since no generation will occur.
|
1293
|
+
if self.is_prefill_only and req.logprob_start_len == len(
|
1294
|
+
req.origin_input_ids
|
1295
|
+
):
|
1296
|
+
# Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
|
1297
|
+
req.extend_logprob_start_len = req.extend_input_len
|
1298
|
+
else:
|
1299
|
+
# Convert absolute logprob_start_len to relative extend_logprob_start_len
|
1300
|
+
#
|
1301
|
+
# Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
|
1302
|
+
# Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
|
1303
|
+
# This means: "compute logprobs from position 3 onwards in extend batch"
|
1304
|
+
req.extend_logprob_start_len = min(
|
1305
|
+
req.logprob_start_len - pre_len,
|
1306
|
+
req.extend_input_len,
|
1307
|
+
req.seqlen - 1,
|
1308
|
+
)
|
1216
1309
|
else:
|
1310
|
+
# logprob_start_len is before the current extend batch, so start from beginning
|
1217
1311
|
req.extend_logprob_start_len = 0
|
1218
1312
|
|
1219
1313
|
if self.return_logprob:
|
@@ -1271,13 +1365,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1271
1365
|
prefix_lens_tensor,
|
1272
1366
|
)
|
1273
1367
|
out_cache_loc = self.alloc_paged_token_slots_extend(
|
1274
|
-
prefix_lens_tensor,
|
1368
|
+
prefix_lens_tensor,
|
1369
|
+
prefix_lens_cpu_tensor,
|
1370
|
+
seq_lens_tensor,
|
1371
|
+
seq_lens_cpu,
|
1372
|
+
last_loc,
|
1373
|
+
extend_num_tokens,
|
1275
1374
|
)
|
1276
1375
|
|
1277
1376
|
# Set fields
|
1278
1377
|
self.input_ids = input_ids_tensor
|
1279
1378
|
self.req_pool_indices = req_pool_indices_tensor
|
1280
1379
|
self.seq_lens = seq_lens_tensor
|
1380
|
+
self.seq_lens_cpu = seq_lens_cpu
|
1281
1381
|
self.orig_seq_lens = orig_seq_lens_tensor
|
1282
1382
|
self.out_cache_loc = out_cache_loc
|
1283
1383
|
self.input_embeds = (
|
@@ -1372,21 +1472,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1372
1472
|
# TODO (lianmin): Revisit this. It should be seq_len - 1
|
1373
1473
|
self.extend_logprob_start_lens.extend([0] * running_bs)
|
1374
1474
|
|
1375
|
-
def new_page_count_next_decode(self):
|
1475
|
+
def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
|
1376
1476
|
page_size = self.token_to_kv_pool_allocator.page_size
|
1477
|
+
requests = (
|
1478
|
+
self.reqs
|
1479
|
+
if selected_indices is None
|
1480
|
+
else [self.reqs[i] for i in selected_indices]
|
1481
|
+
)
|
1377
1482
|
if page_size == 1:
|
1378
|
-
return len(
|
1483
|
+
return len(requests)
|
1379
1484
|
# In the decoding phase, the length of a request's KV cache should be
|
1380
1485
|
# the total length of the request minus 1
|
1381
1486
|
return (
|
1382
|
-
sum(1 for req in
|
1487
|
+
sum(1 for req in requests if req.seqlen % page_size == 0)
|
1383
1488
|
if self.enable_overlap
|
1384
|
-
else sum(1 for req in
|
1489
|
+
else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
|
1385
1490
|
)
|
1386
1491
|
|
1387
|
-
def check_decode_mem(
|
1492
|
+
def check_decode_mem(
|
1493
|
+
self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
|
1494
|
+
):
|
1388
1495
|
num_tokens = (
|
1389
|
-
self.new_page_count_next_decode()
|
1496
|
+
self.new_page_count_next_decode(selected_indices)
|
1390
1497
|
* buf_multiplier
|
1391
1498
|
* self.token_to_kv_pool_allocator.page_size
|
1392
1499
|
)
|
@@ -1412,34 +1519,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1412
1519
|
reverse=True,
|
1413
1520
|
)
|
1414
1521
|
|
1415
|
-
def get_required_tokens(num_reqs: int):
|
1416
|
-
headroom_for_spec_decode = 0
|
1417
|
-
if server_args.speculative_algorithm:
|
1418
|
-
headroom_for_spec_decode += (
|
1419
|
-
num_reqs
|
1420
|
-
* server_args.speculative_eagle_topk
|
1421
|
-
* server_args.speculative_num_steps
|
1422
|
-
+ num_reqs * server_args.speculative_num_draft_tokens
|
1423
|
-
)
|
1424
|
-
return (
|
1425
|
-
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
1426
|
-
)
|
1427
|
-
|
1428
|
-
def _get_available_size():
|
1429
|
-
if self.is_hybrid:
|
1430
|
-
return min(
|
1431
|
-
self.token_to_kv_pool_allocator.full_available_size(),
|
1432
|
-
self.token_to_kv_pool_allocator.swa_available_size(),
|
1433
|
-
)
|
1434
|
-
else:
|
1435
|
-
return self.token_to_kv_pool_allocator.available_size()
|
1436
|
-
|
1437
1522
|
retracted_reqs = []
|
1438
|
-
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
1439
1523
|
first_iter = True
|
1440
|
-
while (
|
1441
|
-
|
1442
|
-
or first_iter
|
1524
|
+
while first_iter or (
|
1525
|
+
not self.check_decode_mem(selected_indices=sorted_indices)
|
1443
1526
|
):
|
1444
1527
|
if len(sorted_indices) == 1:
|
1445
1528
|
# Corner case: only one request left
|
@@ -1463,41 +1546,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1463
1546
|
idx = sorted_indices.pop()
|
1464
1547
|
req = self.reqs[idx]
|
1465
1548
|
retracted_reqs.append(req)
|
1466
|
-
|
1467
|
-
if server_args.disaggregation_mode == "decode":
|
1468
|
-
req.offload_kv_cache(
|
1469
|
-
self.req_to_token_pool, self.token_to_kv_pool_allocator
|
1470
|
-
)
|
1471
|
-
|
1472
|
-
if isinstance(self.tree_cache, ChunkCache):
|
1473
|
-
# ChunkCache does not have eviction
|
1474
|
-
token_indices = self.req_to_token_pool.req_to_token[
|
1475
|
-
req.req_pool_idx, : seq_lens_cpu[idx]
|
1476
|
-
]
|
1477
|
-
self.token_to_kv_pool_allocator.free(token_indices)
|
1478
|
-
self.req_to_token_pool.free(req.req_pool_idx)
|
1479
|
-
else:
|
1480
|
-
# TODO: apply more fine-grained retraction
|
1481
|
-
last_uncached_pos = (
|
1482
|
-
len(req.prefix_indices) // server_args.page_size
|
1483
|
-
) * server_args.page_size
|
1484
|
-
token_indices = self.req_to_token_pool.req_to_token[
|
1485
|
-
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
1486
|
-
]
|
1487
|
-
self.token_to_kv_pool_allocator.free(token_indices)
|
1488
|
-
self.req_to_token_pool.free(req.req_pool_idx)
|
1489
|
-
|
1490
|
-
# release the last node
|
1491
|
-
if self.is_hybrid:
|
1492
|
-
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
1493
|
-
else:
|
1494
|
-
self.tree_cache.dec_lock_ref(req.last_node)
|
1495
|
-
|
1496
|
-
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
1497
|
-
num_tokens = len(sorted_indices) * global_config.retract_decode_steps
|
1498
|
-
self._evict_tree_cache_if_needed(num_tokens)
|
1499
|
-
|
1500
|
-
req.reset_for_retract()
|
1549
|
+
self.release_req(idx, len(sorted_indices), server_args)
|
1501
1550
|
|
1502
1551
|
if len(retracted_reqs) == 0:
|
1503
1552
|
# Corner case: only one request left
|
@@ -1516,7 +1565,45 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1516
1565
|
) / total_max_new_tokens
|
1517
1566
|
new_estimate_ratio = min(1.0, new_estimate_ratio)
|
1518
1567
|
|
1519
|
-
return retracted_reqs, new_estimate_ratio
|
1568
|
+
return retracted_reqs, new_estimate_ratio, []
|
1569
|
+
|
1570
|
+
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
|
1571
|
+
req = self.reqs[idx]
|
1572
|
+
seq_lens_cpu = self.seq_lens_cpu.numpy()
|
1573
|
+
|
1574
|
+
if server_args.disaggregation_mode == "decode":
|
1575
|
+
req.offload_kv_cache(
|
1576
|
+
self.req_to_token_pool, self.token_to_kv_pool_allocator
|
1577
|
+
)
|
1578
|
+
if isinstance(self.tree_cache, ChunkCache):
|
1579
|
+
# ChunkCache does not have eviction
|
1580
|
+
token_indices = self.req_to_token_pool.req_to_token[
|
1581
|
+
req.req_pool_idx, : seq_lens_cpu[idx]
|
1582
|
+
]
|
1583
|
+
self.token_to_kv_pool_allocator.free(token_indices)
|
1584
|
+
self.req_to_token_pool.free(req.req_pool_idx)
|
1585
|
+
else:
|
1586
|
+
# TODO: apply more fine-grained retraction
|
1587
|
+
last_uncached_pos = (
|
1588
|
+
len(req.prefix_indices) // server_args.page_size
|
1589
|
+
) * server_args.page_size
|
1590
|
+
token_indices = self.req_to_token_pool.req_to_token[
|
1591
|
+
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
1592
|
+
]
|
1593
|
+
self.token_to_kv_pool_allocator.free(token_indices)
|
1594
|
+
self.req_to_token_pool.free(req.req_pool_idx)
|
1595
|
+
|
1596
|
+
# release the last node
|
1597
|
+
if self.is_hybrid:
|
1598
|
+
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
1599
|
+
else:
|
1600
|
+
self.tree_cache.dec_lock_ref(req.last_node)
|
1601
|
+
|
1602
|
+
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
1603
|
+
num_tokens = remaing_req_count * global_config.retract_decode_steps
|
1604
|
+
self._evict_tree_cache_if_needed(num_tokens)
|
1605
|
+
|
1606
|
+
req.reset_for_retract()
|
1520
1607
|
|
1521
1608
|
def prepare_encoder_info_decode(self):
|
1522
1609
|
# Reset the encoder cached status
|
@@ -1526,6 +1613,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1526
1613
|
self.forward_mode = ForwardMode.IDLE
|
1527
1614
|
self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
|
1528
1615
|
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
|
1616
|
+
self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
|
1529
1617
|
self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
|
1530
1618
|
self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
|
1531
1619
|
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
|
@@ -1540,7 +1628,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1540
1628
|
self.forward_mode = ForwardMode.DECODE
|
1541
1629
|
bs = len(self.reqs)
|
1542
1630
|
|
1543
|
-
if
|
1631
|
+
if (
|
1632
|
+
self.spec_algorithm.is_eagle()
|
1633
|
+
or self.spec_algorithm.is_standalone()
|
1634
|
+
or self.spec_algorithm.is_ngram()
|
1635
|
+
):
|
1544
1636
|
# if spec decoding is used, the decode batch is prepared inside
|
1545
1637
|
# `forward_batch_speculative_generation` after running draft models.
|
1546
1638
|
return
|
@@ -1581,10 +1673,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1581
1673
|
if self.enable_overlap:
|
1582
1674
|
# Do not use in-place operations in the overlap mode
|
1583
1675
|
self.seq_lens = self.seq_lens + 1
|
1676
|
+
self.seq_lens_cpu = self.seq_lens_cpu + 1
|
1584
1677
|
self.orig_seq_lens = self.orig_seq_lens + 1
|
1585
1678
|
else:
|
1586
1679
|
# A faster in-place version
|
1587
1680
|
self.seq_lens.add_(1)
|
1681
|
+
self.seq_lens_cpu.add_(1)
|
1588
1682
|
self.orig_seq_lens.add_(1)
|
1589
1683
|
self.seq_lens_sum += bs
|
1590
1684
|
|
@@ -1603,7 +1697,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1603
1697
|
self.req_pool_indices, self.seq_lens - 2
|
1604
1698
|
]
|
1605
1699
|
self.out_cache_loc = self.alloc_paged_token_slots_decode(
|
1606
|
-
self.seq_lens, last_loc
|
1700
|
+
self.seq_lens, self.seq_lens_cpu, last_loc
|
1607
1701
|
)
|
1608
1702
|
|
1609
1703
|
self.req_to_token_pool.write(
|
@@ -1649,6 +1743,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1649
1743
|
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
|
1650
1744
|
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
|
1651
1745
|
self.seq_lens = self.seq_lens[keep_indices_device]
|
1746
|
+
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
|
1652
1747
|
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
|
1653
1748
|
self.out_cache_loc = None
|
1654
1749
|
self.seq_lens_sum = self.seq_lens.sum().item()
|
@@ -1666,7 +1761,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1666
1761
|
|
1667
1762
|
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
|
1668
1763
|
if self.spec_info:
|
1669
|
-
|
1764
|
+
if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
|
1765
|
+
has_been_filtered = False
|
1766
|
+
else:
|
1767
|
+
has_been_filtered = True
|
1768
|
+
self.spec_info.filter_batch(
|
1769
|
+
new_indices=keep_indices_device,
|
1770
|
+
has_been_filtered=has_been_filtered,
|
1771
|
+
)
|
1670
1772
|
|
1671
1773
|
def merge_batch(self, other: "ScheduleBatch"):
|
1672
1774
|
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
@@ -1682,6 +1784,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1682
1784
|
[self.req_pool_indices, other.req_pool_indices]
|
1683
1785
|
)
|
1684
1786
|
self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
|
1787
|
+
self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
|
1685
1788
|
self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
|
1686
1789
|
self.out_cache_loc = None
|
1687
1790
|
self.seq_lens_sum += other.seq_lens_sum
|
@@ -1725,15 +1828,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1725
1828
|
self.sampling_info.grammars = None
|
1726
1829
|
|
1727
1830
|
seq_lens_cpu = (
|
1728
|
-
seq_lens_cpu_cache
|
1729
|
-
if seq_lens_cpu_cache is not None
|
1730
|
-
else self.seq_lens.cpu()
|
1831
|
+
seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
|
1731
1832
|
)
|
1732
1833
|
|
1733
|
-
global bid
|
1734
|
-
bid += 1
|
1735
1834
|
return ModelWorkerBatch(
|
1736
|
-
bid=bid,
|
1737
1835
|
forward_mode=self.forward_mode,
|
1738
1836
|
input_ids=self.input_ids,
|
1739
1837
|
req_pool_indices=self.req_pool_indices,
|
@@ -1780,6 +1878,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1780
1878
|
),
|
1781
1879
|
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
|
1782
1880
|
launch_done=self.launch_done,
|
1881
|
+
is_prefill_only=self.is_prefill_only,
|
1783
1882
|
)
|
1784
1883
|
|
1785
1884
|
def copy(self):
|
@@ -1852,8 +1951,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1852
1951
|
|
1853
1952
|
@dataclasses.dataclass
|
1854
1953
|
class ModelWorkerBatch:
|
1855
|
-
# The batch id
|
1856
|
-
bid: int
|
1857
1954
|
# The forward mode
|
1858
1955
|
forward_mode: ForwardMode
|
1859
1956
|
# The input ids
|
@@ -1914,14 +2011,19 @@ class ModelWorkerBatch:
|
|
1914
2011
|
|
1915
2012
|
# Speculative decoding
|
1916
2013
|
spec_algorithm: SpeculativeAlgorithm = None
|
1917
|
-
|
2014
|
+
|
2015
|
+
spec_info: Optional[SpecInput] = None
|
2016
|
+
|
1918
2017
|
# If set, the output of the batch contains the hidden states of the run.
|
1919
2018
|
capture_hidden_mode: CaptureHiddenMode = None
|
1920
|
-
hicache_consumer_index: int =
|
2019
|
+
hicache_consumer_index: int = -1
|
1921
2020
|
|
1922
2021
|
# Overlap event
|
1923
2022
|
launch_done: Optional[threading.Event] = None
|
1924
2023
|
|
2024
|
+
# Whether this batch is prefill-only (no token generation needed)
|
2025
|
+
is_prefill_only: bool = False
|
2026
|
+
|
1925
2027
|
|
1926
2028
|
@triton.jit
|
1927
2029
|
def write_req_to_token_pool_triton(
|