sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import enum
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -35,6 +37,7 @@ import copy
 import dataclasses
 import logging
 import threading
+import time
 from enum import Enum, auto
 from http import HTTPStatus
 from itertools import chain
@@ -51,18 +54,18 @@ from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.moe import is_tbo_enabled
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
-from sglang.srt.mem_cache.
-from sglang.srt.mem_cache.
+from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
+from sglang.srt.mem_cache.radix_cache import RadixKey
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.metrics.collector import TimeStats
+from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -71,8 +74,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
 
 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
-    from sglang.srt.speculative.
-    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -87,6 +89,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_flashinfer_cutlass_moe_fp4_allgather",
     "disable_radix_cache",
     "enable_dp_lm_head",
+    "enable_fp32_lm_head",
     "flashinfer_mxfp4_moe_precision",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
@@ -94,20 +97,24 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "ep_num_redundant_experts",
     "enable_nan_detection",
     "flashinfer_mla_disable_ragged",
-    "
+    "pp_max_micro_batch_size",
     "disable_shared_experts_fusion",
     "sampling_backend",
     "speculative_accept_threshold_single",
     "speculative_accept_threshold_acc",
+    "speculative_attention_mode",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
    "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_multimodal",
     "enable_symm_mem",
-    "quantization",
     "enable_custom_logit_processor",
     "disaggregation_mode",
+    "enable_deterministic_inference",
+    "nsa_prefill",
+    "nsa_decode",
+    "multi_item_scoring_delimiter",
 ]
 
 # Put some global args for easy access
@@ -408,6 +415,23 @@ class MultimodalInputs:
         # other args would be kept intact
 
 
+class RequestStage(str, enum.Enum):
+    # prefill
+    PREFILL_WAITING = "prefill_waiting"
+
+    # disaggregation prefill
+    PREFILL_PREPARE = "prefill_prepare"
+    PREFILL_BOOTSTRAP = "prefill_bootstrap"
+    PREFILL_FORWARD = "prefill_forward"
+    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
+
+    # disaggregation decode
+    DECODE_PREPARE = "decode_prepare"
+    DECODE_BOOTSTRAP = "decode_bootstrap"
+    DECODE_WAITING = "decode_waiting"
+    DECODE_TRANSFERRED = "decode_transferred"
+
+
 class Req:
     """The input and output status of a request."""
 
@@ -432,8 +456,12 @@
         bootstrap_host: Optional[str] = None,
         bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
+        disagg_mode: Optional[DisaggregationMode] = None,
         data_parallel_rank: Optional[int] = None,
        vocab_size: Optional[int] = None,
+        priority: Optional[int] = None,
+        metrics_collector: Optional[SchedulerMetricsCollector] = None,
+        extra_key: Optional[str] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -466,6 +494,14 @@
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
+
+        # extra key for classifying the request (e.g. cache_salt)
+        if lora_id is not None:
+            extra_key = (
+                extra_key or ""
+            ) + lora_id  # lora_id is concatenated to the extra key
+
+        self.extra_key = extra_key
         self.lora_id = lora_id
 
         # Memory pool info
@@ -484,6 +520,7 @@
         self.stream = stream
         self.eos_token_ids = eos_token_ids
         self.vocab_size = vocab_size
+        self.priority = priority
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -503,7 +540,7 @@
 
         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices: torch.Tensor =
+        self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
@@ -513,6 +550,8 @@
         self.host_hit_length = 0
         # The node to lock until for swa radix tree lock ref
         self.swa_uuid_for_lock: Optional[int] = None
+        # The prefix length of the last prefix matching
+        self.last_matched_prefix_len: int = 0
 
         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -561,7 +600,10 @@
             # shape: (bs, k)
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
-            self.output_token_ids_logprobs_val = []
+            # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
+            self.output_token_ids_logprobs_val: List[
+                Union[List[float], torch.Tensor]
+            ] = []
             self.output_token_ids_logprobs_idx = []
         else:
             self.output_token_logprobs_val = self.output_token_logprobs_idx = (
@@ -571,6 +613,8 @@
             ) = None
         self.hidden_states: List[List[float]] = []
         self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
+        self.output_topk_p = None
+        self.output_topk_index = None
 
         # Embedding (return values)
         self.embedding = None
@@ -588,10 +632,10 @@
         self.spec_verify_ct = 0
 
         # For metrics
-        self.time_stats: TimeStats = TimeStats()
+        self.metrics_collector = metrics_collector
+        self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
         self.has_log_time_stats: bool = False
-        self.queue_time_start = None
-        self.queue_time_end = None
+        self.last_tic = time.monotonic()
 
         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
@@ -619,6 +663,27 @@
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
 
+    @property
+    def is_prefill_only(self) -> bool:
+        """Check if this request is prefill-only (no token generation needed)."""
+        # NOTE: when spec is enabled, prefill_only optimizations are disabled
+        from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+
+        spec_alg = global_server_args_dict["speculative_algorithm"]
+        return self.sampling_params.max_new_tokens == 0 and (
+            spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
+        )
+
+    def add_latency(self, stage: RequestStage):
+        if self.metrics_collector is None:
+            return
+
+        now = time.monotonic()
+        self.metrics_collector.observe_per_stage_req_latency(
+            stage.value, now - self.last_tic
+        )
+        self.last_tic = now
+
     def extend_image_inputs(self, image_inputs):
         if self.multimodal_inputs is None:
             self.multimodal_inputs = image_inputs
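The new `RequestStage` enum and `Req.add_latency` implement per-stage latency accounting: each call reports the time elapsed since the previous stage boundary to the metrics collector, then resets the clock. A minimal standalone sketch of the pattern (the `StubCollector` below is a stand-in for the real `SchedulerMetricsCollector`, which does expose `observe_per_stage_req_latency` per this diff):

    import enum
    import time

    class RequestStage(str, enum.Enum):
        PREFILL_WAITING = "prefill_waiting"
        PREFILL_FORWARD = "prefill_forward"

    class StubCollector:
        # Stand-in exposing only the method used here.
        def observe_per_stage_req_latency(self, stage: str, latency: float) -> None:
            print(f"{stage}: {latency * 1e3:.1f} ms")

    class Req:
        def __init__(self, collector=None):
            self.metrics_collector = collector
            self.last_tic = time.monotonic()  # reset at every stage boundary

        def add_latency(self, stage: RequestStage) -> None:
            if self.metrics_collector is None:
                return
            now = time.monotonic()
            # The observed value is the duration of the stage that just ended.
            self.metrics_collector.observe_per_stage_req_latency(
                stage.value, now - self.last_tic
            )
            self.last_tic = now

    req = Req(StubCollector())
    time.sleep(0.01)                               # request waits in the queue
    req.add_latency(RequestStage.PREFILL_WAITING)
    time.sleep(0.02)                               # prefill forward pass runs
    req.add_latency(RequestStage.PREFILL_FORWARD)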
@@ -629,51 +694,27 @@
         # Whether request reached finished condition
         return self.finished_reason is not None
 
-    def init_next_round_input(
-        self,
-        tree_cache: Optional[BasePrefixCache] = None,
-    ):
-        self.fill_ids = self.origin_input_ids + self.output_ids
-        if tree_cache is not None:
-            if isinstance(tree_cache, LoRARadixCache):
-                (
-                    self.prefix_indices,
-                    self.last_node,
-                    self.last_host_node,
-                    self.host_hit_length,
-                ) = tree_cache.match_prefix_with_lora_id(
-                    key=LoRAKey(
-                        lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
-                    ),
-                )
-            else:
-                (
-                    self.prefix_indices,
-                    self.last_node,
-                    self.last_host_node,
-                    self.host_hit_length,
-                ) = tree_cache.match_prefix(
-                    key=self.adjust_max_prefix_ids(),
-                )
-        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
-
-    def adjust_max_prefix_ids(self):
+    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
         input_len = len(self.fill_ids)
-
-        # FIXME: To work around some bugs in logprob computation, we need to ensure each
-        # request has at least one token. Later, we can relax this requirement and use `input_len`.
+        # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
         max_prefix_len = input_len - 1
-
-        if self.sampling_params.max_new_tokens > 0:
-            # Need at least one token to compute logits
-            max_prefix_len = min(max_prefix_len, input_len - 1)
-
         if self.return_logprob:
             max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
         max_prefix_len = max(max_prefix_len, 0)
-
+        token_ids = self.fill_ids[:max_prefix_len]
+
+        if tree_cache is not None:
+            (
+                self.prefix_indices,
+                self.last_node,
+                self.last_host_node,
+                self.host_hit_length,
+            ) = tree_cache.match_prefix(
+                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
+            )
+        self.last_matched_prefix_len = len(self.prefix_indices)
+        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
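The rewrite above folds the old `LoRARadixCache` special case into a single `match_prefix` call keyed by `RadixKey(token_ids, extra_key)`, where `extra_key` carries a `cache_salt` and/or the LoRA id. Requests with different extra keys can no longer hit each other's cached prefixes. A toy dict-based illustration of that partitioning (the real cache is a radix tree; only the keying semantics sketched here match the diff):

    from typing import Dict, List, Optional, Tuple

    # Toy prefix store keyed by (extra_key, token prefix).
    _store: Dict[Tuple[Optional[str], Tuple[int, ...]], List[int]] = {}

    def insert(token_ids: List[int], kv_indices: List[int], extra_key: Optional[str] = None):
        _store[(extra_key, tuple(token_ids))] = kv_indices

    def match_prefix(token_ids: List[int], extra_key: Optional[str] = None) -> List[int]:
        # Longest stored prefix under the same extra_key; other keys never match.
        for end in range(len(token_ids), 0, -1):
            hit = _store.get((extra_key, tuple(token_ids[:end])))
            if hit is not None:
                return hit
        return []

    insert([1, 2, 3], kv_indices=[10, 11, 12])                   # base-model entry
    insert([1, 2, 3], kv_indices=[20, 21, 22], extra_key="loraA")
    assert match_prefix([1, 2, 3, 4]) == [10, 11, 12]            # base hits base
    assert match_prefix([1, 2, 3, 4], "loraA") == [20, 21, 22]   # loraA hits loraA
    assert match_prefix([1, 2, 3, 4], "loraB") == []             # no cross-key hits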
@@ -684,9 +725,15 @@
             self.surr_offset = max(
                 self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
             )
+            self.surr_and_decode_ids = (
+                self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
+            )
+            self.cur_decode_ids_len = len(self.output_ids)
+        else:
+            self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
+            self.cur_decode_ids_len = len(self.output_ids)
 
-
-        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
+        return self.surr_and_decode_ids, self.read_offset - self.surr_offset
 
     def check_finished(self):
         if self.finished():
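This hunk makes `init_incremental_detokenize` incremental across calls: instead of re-slicing and re-concatenating all ids on every decode step, it caches `surr_and_decode_ids` and appends only the tokens produced since the last call. A minimal sketch of that bookkeeping (names mirror the diff; `surr_offset`/`read_offset` handling is omitted):

    class IncrementalIds:
        def __init__(self, input_ids, surr_offset):
            self.input_ids = input_ids
            self.surr_offset = surr_offset
            self.output_ids = []
            self.surr_and_decode_ids = None
            self.cur_decode_ids_len = 0

        def get_ids(self):
            if self.surr_and_decode_ids is None:  # first call: build the list once
                self.surr_and_decode_ids = self.input_ids[self.surr_offset:] + self.output_ids
            else:  # later calls: append only the new tokens, O(new) not O(total)
                self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len:])
            self.cur_decode_ids_len = len(self.output_ids)
            return self.surr_and_decode_ids

    ids = IncrementalIds([1, 2, 3, 4], surr_offset=2)
    ids.output_ids.append(9)
    assert ids.get_ids() == [3, 4, 9]
    ids.output_ids.append(8)
    assert ids.get_ids() == [3, 4, 9, 8]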
@@ -749,7 +796,7 @@
             return
 
     def reset_for_retract(self):
-        self.prefix_indices =
+        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
         self.last_node = None
         self.swa_uuid_for_lock = None
         self.extend_input_len = 0
@@ -781,10 +828,10 @@
             return
 
         if self.bootstrap_room is not None:
-            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.
+            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
         else:
-            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.
-        logger.info(f"{prefix}: {self.time_stats}")
+            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
+        logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
         self.has_log_time_stats = True
 
     def set_finish_with_abort(self, error_msg: str):
@@ -807,10 +854,6 @@
         )
 
 
-# Batch id
-bid = 0
-
-
 @dataclasses.dataclass
 class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     """Store all information of a batch on the scheduler."""
@@ -831,15 +874,11 @@
     # This is an optimization to reduce the overhead of the prefill check.
     batch_is_full: bool = False
 
-    # Events
-    launch_done: Optional[threading.Event] = None
-
     # For chunked prefill in PP
     chunked_req: Optional[Req] = None
 
     # Sampling info
     sampling_info: SamplingBatchInfo = None
-    next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None  # shape: [b], int64
@@ -847,6 +886,7 @@
     token_type_ids: torch.Tensor = None  # shape: [b], int64
     req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
+    seq_lens_cpu: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64
@@ -902,7 +942,8 @@
 
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[
+    # spec_info: Optional[SpecInput] = None
+    spec_info: Optional[SpecInput] = None
 
     # Whether to return hidden states
     return_hidden_states: bool = False
@@ -911,7 +952,7 @@
     is_prefill_only: bool = False
 
     # hicache pointer for synchronizing data loading from CPU to GPU
-    hicache_consumer_index: int =
+    hicache_consumer_index: int = -1
 
     @classmethod
     def init_new(
@@ -950,9 +991,7 @@
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
-            is_prefill_only=all(
-                req.sampling_params.max_new_tokens == 0 for req in reqs
-            ),
+            is_prefill_only=all(req.is_prefill_only for req in reqs),
             chunked_req=chunked_req,
         )
@@ -962,8 +1001,11 @@
     def is_empty(self):
         return len(self.reqs) == 0
 
-    def alloc_req_slots(self, num_reqs: int):
-        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+    def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
+        if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
+            req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
+        else:
+            req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
                 "alloc_req_slots runs out of memory. "
@@ -1000,7 +1042,9 @@
     def alloc_paged_token_slots_extend(
         self,
         prefix_lens: torch.Tensor,
+        prefix_lens_cpu: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
         backup_state: bool = False,
@@ -1008,7 +1052,7 @@
         # Over estimate the number of tokens: assume each request needs a new page.
         num_tokens = (
             extend_num_tokens
-            + len(
+            + len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
         )
         self._evict_tree_cache_if_needed(num_tokens)
 
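The over-estimate above keeps the same worst-case bound, now computed from the CPU-side lengths: every request in the batch may open one fresh page, so eviction is checked against `extend_num_tokens + num_reqs * page_size` slots before allocating. A small sketch of the arithmetic, assuming `len(seq_lens_cpu)` equals the batch size:

    def overestimate_extend_tokens(extend_num_tokens: int, num_reqs: int, page_size: int) -> int:
        # Worst case: each request's extension starts a new, partially filled page.
        return extend_num_tokens + num_reqs * page_size

    assert overestimate_extend_tokens(100, num_reqs=3, page_size=64) == 292
    assert overestimate_extend_tokens(100, num_reqs=3, page_size=1) == 103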
@@ -1016,7 +1060,12 @@
             state = self.token_to_kv_pool_allocator.backup_state()
 
         out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
-            prefix_lens, seq_lens, last_loc, extend_num_tokens
+            prefix_lens,
+            prefix_lens_cpu,
+            seq_lens,
+            seq_lens_cpu,
+            last_loc,
+            extend_num_tokens,
         )
         if out_cache_loc is None:
             error_msg = (
@@ -1035,6 +1084,7 @@
     def alloc_paged_token_slots_decode(
         self,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         backup_state: bool = False,
     ):
@@ -1045,7 +1095,9 @@
         if backup_state:
             state = self.token_to_kv_pool_allocator.backup_state()
 
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
+            seq_lens, seq_lens_cpu, last_loc
+        )
         if out_cache_loc is None:
             error_msg = (
                 f"Decode out of memory. Try to lower your batch size.\n"
@@ -1060,6 +1112,47 @@
         else:
             return out_cache_loc
 
+    def write_cache_indices(
+        self,
+        req_pool_indices: List[int],
+        prefix_lens: List[int],
+        seq_lens: List[int],
+        extend_lens: List[int],
+        out_cache_loc: torch.Tensor,
+        req_pool_indices_tensor: torch.Tensor,
+        prefix_lens_tensor: torch.Tensor,
+        seq_lens_tensor: torch.Tensor,
+        extend_lens_tensor: torch.Tensor,
+        prefix_tensors: list[torch.Tensor],
+    ):
+        if support_triton(global_server_args_dict.get("attention_backend")):
+            prefix_pointers = torch.tensor(
+                [t.data_ptr() for t in prefix_tensors], device=self.device
+            )
+            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+            write_req_to_token_pool_triton[(len(req_pool_indices),)](
+                self.req_to_token_pool.req_to_token,
+                req_pool_indices_tensor,
+                prefix_pointers,
+                prefix_lens_tensor,
+                seq_lens_tensor,
+                extend_lens_tensor,
+                out_cache_loc,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+        else:
+            pt = 0
+            for i in range(len(req_pool_indices)):
+                self.req_to_token_pool.write(
+                    (req_pool_indices[i], slice(0, prefix_lens[i])),
+                    prefix_tensors[i],
+                )
+                self.req_to_token_pool.write(
+                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
+                    out_cache_loc[pt : pt + extend_lens[i]],
+                )
+                pt += extend_lens[i]
+
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
         self.encoder_cached = []
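The new `write_cache_indices` helper centralizes writing both the cached prefix indices and the freshly allocated slots into the request-to-token map; the Triton path now receives per-request prefix pointers because prefix tensors can differ in length. A toy version of the Python fallback path (simplified signature; the real method also takes the GPU-side tensors used by the Triton kernel):

    import torch

    def write_cache_indices(req_to_token, req_pool_indices, prefix_tensors, extend_lens, out_cache_loc):
        # Per request: copy the shared-prefix KV indices first, then the newly
        # allocated slots for the extended tokens (flattened in out_cache_loc).
        pt = 0
        for i, req_idx in enumerate(req_pool_indices):
            pre = len(prefix_tensors[i])
            req_to_token[req_idx, :pre] = prefix_tensors[i]
            req_to_token[req_idx, pre : pre + extend_lens[i]] = out_cache_loc[pt : pt + extend_lens[i]]
            pt += extend_lens[i]

    req_to_token = torch.zeros((2, 8), dtype=torch.int64)
    write_cache_indices(
        req_to_token,
        req_pool_indices=[0, 1],
        prefix_tensors=[torch.tensor([5, 6]), torch.tensor([], dtype=torch.int64)],
        extend_lens=[2, 3],
        out_cache_loc=torch.tensor([10, 11, 20, 21, 22]),
    )
    assert req_to_token[0, :4].tolist() == [5, 6, 10, 11]
    assert req_to_token[1, :3].tolist() == [20, 21, 22]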
@@ -1114,6 +1207,7 @@
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
 
         if not decoder_out_cache_loc:
             self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
@@ -1136,10 +1230,6 @@
     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
-        # Allocate req slots
-        bs = len(self.reqs)
-        req_pool_indices = self.alloc_req_slots(bs)
-
         # Init tensors
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -1153,21 +1243,20 @@
             r.token_type_ids for r in reqs if r.token_type_ids is not None
         ]
 
-        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
         input_ids_tensor = torch.tensor(
             list(chain.from_iterable(input_ids)), dtype=torch.int64
         ).to(self.device, non_blocking=True)
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
         orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
+        prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
 
         token_type_ids_tensor = None
         if len(token_type_ids) > 0:
@@ -1177,7 +1266,49 @@
 
         extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
-        #
+        # Allocate req slots
+        bs = len(self.reqs)
+        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
+        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+
+        # Allocate memory
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+        else:
+            last_loc = [
+                (
+                    r.prefix_indices[-1:]
+                    if len(r.prefix_indices) > 0
+                    else torch.tensor([-1], device=self.device)
+                )
+                for r in self.reqs
+            ]
+            out_cache_loc = self.alloc_paged_token_slots_extend(
+                prefix_lens_tensor,
+                prefix_lens_cpu_tensor,
+                seq_lens_tensor,
+                seq_lens_cpu,
+                torch.cat(last_loc),
+                extend_num_tokens,
+            )
+
+        # Write allocated tokens to req_to_token_pool
+        self.write_cache_indices(
+            req_pool_indices,
+            prefix_lens,
+            seq_lens,
+            extend_lens,
+            out_cache_loc,
+            req_pool_indices_tensor,
+            prefix_lens_tensor,
+            seq_lens_tensor,
+            extend_lens_tensor,
+            [r.prefix_indices for r in reqs],
+        )
+
+        # Set fields
         input_embeds = []
         extend_input_logprob_token_ids = []
         multimodal_inputs = []
@@ -1187,9 +1318,6 @@
             assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
-                self.req_to_token_pool.write(
-                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
-                )
                 if isinstance(self.tree_cache, SWAChunkCache):
                     self.tree_cache.evict_swa(
                         req, pre_len, self.model_config.attention_chunk_size
@@ -1207,13 +1335,36 @@
             req.is_retracted = False
 
             # Compute the relative logprob_start_len in an extend batch
+            #
+            # Key variables:
+            # - logprob_start_len: Absolute position in full sequence where logprob computation begins
+            # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
+            # - extend_input_len: Number of tokens that need to be processed in this extend batch
+            #   (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
+            #   and prefix_indices are the cached/shared prefix tokens)
+            #
             if req.logprob_start_len >= pre_len:
-                req.extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len,
-                    req.extend_input_len,
-                    req.seqlen - 1,
-                )
+                # Optimization for prefill-only requests: When we only need logprobs at
+                # positions beyond the input sequence (to score next-token likelihood), skip all
+                # input logprob computation during prefill since no generation will occur.
+                if self.is_prefill_only and req.logprob_start_len == len(
+                    req.origin_input_ids
+                ):
+                    # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
+                    req.extend_logprob_start_len = req.extend_input_len
+                else:
+                    # Convert absolute logprob_start_len to relative extend_logprob_start_len
+                    #
+                    # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
+                    # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
+                    # This means: "compute logprobs from position 3 onwards in extend batch"
+                    req.extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len,
+                        req.extend_input_len,
+                        req.seqlen - 1,
+                    )
             else:
+                # logprob_start_len is before the current extend batch, so start from beginning
                 req.extend_logprob_start_len = 0
 
             if self.return_logprob:
@@ -1261,23 +1412,10 @@
         else:
             extend_input_logprob_token_ids = None
 
-        # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-        else:
-            last_loc = get_last_loc(
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-            )
-            out_cache_loc = self.alloc_paged_token_slots_extend(
-                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
-            )
-
-        # Set fields
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
+        self.seq_lens_cpu = seq_lens_cpu
         self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
@@ -1306,28 +1444,6 @@
         self.extend_lens = extend_lens
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
 
-        # Write to req_to_token_pool
-        if support_triton(global_server_args_dict.get("attention_backend")):
-            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
-
-            write_req_to_token_pool_triton[(bs,)](
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-                seq_lens_tensor,
-                extend_lens_tensor,
-                out_cache_loc,
-                self.req_to_token_pool.req_to_token.shape[1],
-            )
-        else:
-            pt = 0
-            for i in range(bs):
-                self.req_to_token_pool.write(
-                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
-                    out_cache_loc[pt : pt + extend_lens[i]],
-                )
-                pt += extend_lens[i]
-
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
 
@@ -1372,21 +1488,28 @@
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
-    def new_page_count_next_decode(self):
+    def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
         page_size = self.token_to_kv_pool_allocator.page_size
+        requests = (
+            self.reqs
+            if selected_indices is None
+            else [self.reqs[i] for i in selected_indices]
+        )
         if page_size == 1:
-            return len(self.reqs)
+            return len(requests)
         # In the decoding phase, the length of a request's KV cache should be
         # the total length of the request minus 1
         return (
-            sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+            sum(1 for req in requests if req.seqlen % page_size == 0)
            if self.enable_overlap
-            else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+            else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
         )
 
-    def check_decode_mem(self, buf_multiplier=1):
+    def check_decode_mem(
+        self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
+    ):
         num_tokens = (
-            self.new_page_count_next_decode()
+            self.new_page_count_next_decode(selected_indices)
             * buf_multiplier
             * self.token_to_kv_pool_allocator.page_size
         )
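In `new_page_count_next_decode`, one decode step appends one token per request, so a request needs a fresh page exactly when its current KV length fills its last page; with overlap scheduling the token being written is already counted in `seqlen`, hence the `seqlen % page_size` vs `(seqlen - 1) % page_size` split. A standalone sketch of the count:

    def new_page_count_next_decode(seqlens, page_size, enable_overlap=False):
        if page_size == 1:
            return len(seqlens)  # every decoded token occupies a fresh slot
        if enable_overlap:
            return sum(1 for s in seqlens if s % page_size == 0)
        return sum(1 for s in seqlens if (s - 1) % page_size == 0)

    # page_size=4: seqlen 9 means KV length 8, exactly filling two pages,
    # so the next decoded token opens a new page; 10 and 12 do not.
    assert new_page_count_next_decode([9, 10, 12], page_size=4) == 1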
@@ -1412,34 +1535,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1412
1535
|
reverse=True,
|
1413
1536
|
)
|
1414
1537
|
|
1415
|
-
def get_required_tokens(num_reqs: int):
|
1416
|
-
headroom_for_spec_decode = 0
|
1417
|
-
if server_args.speculative_algorithm:
|
1418
|
-
headroom_for_spec_decode += (
|
1419
|
-
num_reqs
|
1420
|
-
* server_args.speculative_eagle_topk
|
1421
|
-
* server_args.speculative_num_steps
|
1422
|
-
+ num_reqs * server_args.speculative_num_draft_tokens
|
1423
|
-
)
|
1424
|
-
return (
|
1425
|
-
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
1426
|
-
)
|
1427
|
-
|
1428
|
-
def _get_available_size():
|
1429
|
-
if self.is_hybrid:
|
1430
|
-
return min(
|
1431
|
-
self.token_to_kv_pool_allocator.full_available_size(),
|
1432
|
-
self.token_to_kv_pool_allocator.swa_available_size(),
|
1433
|
-
)
|
1434
|
-
else:
|
1435
|
-
return self.token_to_kv_pool_allocator.available_size()
|
1436
|
-
|
1437
1538
|
retracted_reqs = []
|
1438
|
-
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
1439
1539
|
first_iter = True
|
1440
|
-
while (
|
1441
|
-
|
1442
|
-
or first_iter
|
1540
|
+
while first_iter or (
|
1541
|
+
not self.check_decode_mem(selected_indices=sorted_indices)
|
1443
1542
|
):
|
1444
1543
|
if len(sorted_indices) == 1:
|
1445
1544
|
# Corner case: only one request left
|
@@ -1463,41 +1562,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
-
-            if server_args.disaggregation_mode == "decode":
-                req.offload_kv_cache(
-                    self.req_to_token_pool, self.token_to_kv_pool_allocator
-                )
-
-            if isinstance(self.tree_cache, ChunkCache):
-                # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, : seq_lens_cpu[idx]
-                ]
-                self.token_to_kv_pool_allocator.free(token_indices)
-                self.req_to_token_pool.free(req.req_pool_idx)
-            else:
-                # TODO: apply more fine-grained retraction
-                last_uncached_pos = (
-                    len(req.prefix_indices) // server_args.page_size
-                ) * server_args.page_size
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
-                ]
-                self.token_to_kv_pool_allocator.free(token_indices)
-                self.req_to_token_pool.free(req.req_pool_idx)
-
-            # release the last node
-            if self.is_hybrid:
-                self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
-            else:
-                self.tree_cache.dec_lock_ref(req.last_node)
-
-            # NOTE(lsyin): we should use the newly evictable memory instantly.
-            num_tokens = len(sorted_indices) * global_config.retract_decode_steps
-            self._evict_tree_cache_if_needed(num_tokens)
-
-            req.reset_for_retract()
+            self.release_req(idx, len(sorted_indices), server_args)
 
         if len(retracted_reqs) == 0:
             # Corner case: only one request left
@@ -1516,7 +1581,45 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         ) / total_max_new_tokens
         new_estimate_ratio = min(1.0, new_estimate_ratio)
 
-        return retracted_reqs, new_estimate_ratio
+        return retracted_reqs, new_estimate_ratio, []
+
+    def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
+        req = self.reqs[idx]
+        seq_lens_cpu = self.seq_lens_cpu.numpy()
+
+        if server_args.disaggregation_mode == "decode":
+            req.offload_kv_cache(
+                self.req_to_token_pool, self.token_to_kv_pool_allocator
+            )
+        if isinstance(self.tree_cache, ChunkCache):
+            # ChunkCache does not have eviction
+            token_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, : seq_lens_cpu[idx]
+            ]
+            self.token_to_kv_pool_allocator.free(token_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+        else:
+            # TODO: apply more fine-grained retraction
+            last_uncached_pos = (
+                len(req.prefix_indices) // server_args.page_size
+            ) * server_args.page_size
+            token_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
+            ]
+            self.token_to_kv_pool_allocator.free(token_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+
+        # release the last node
+        if self.is_hybrid:
+            self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
+        else:
+            self.tree_cache.dec_lock_ref(req.last_node)
+
+        # NOTE(lsyin): we should use the newly evictable memory instantly.
+        num_tokens = remaing_req_count * global_config.retract_decode_steps
+        self._evict_tree_cache_if_needed(num_tokens)
+
+        req.reset_for_retract()
 
     def prepare_encoder_info_decode(self):
         # Reset the encoder cached status
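The interesting line in `release_req` is the rounding of the cached prefix down to a page boundary: whole pages still owned by the radix cache are kept, and only slots from `last_uncached_pos` up to the current length are returned to the allocator. A standalone sketch of that slice (hypothetical `slots_to_free` helper; `req_to_token_row` plays the role of one row of `req_to_token`):

```python
from typing import List

def slots_to_free(req_to_token_row: List[int], prefix_len: int,
                  seq_len: int, page_size: int) -> List[int]:
    # Round the cached prefix down to a whole page; partially filled pages
    # past that point are owned by this request alone and can be freed.
    last_uncached_pos = (prefix_len // page_size) * page_size
    return req_to_token_row[last_uncached_pos:seq_len]

# page_size=16, 20 cached prefix tokens, 50 total tokens: the first full
# page (slots 0-15) stays with the cache; slots 16-49 are freed.
row = list(range(64))
assert slots_to_free(row, 20, 50, 16) == list(range(16, 50))
```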
@@ -1526,6 +1629,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
         self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
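Most hunks in this file thread a new `seq_lens_cpu` tensor alongside the device-side `seq_lens`: empty-initialized here, incremented per decode step, sliced on filter, and concatenated on merge, so host-side code can read lengths without the blocking `self.seq_lens.cpu()` copy the old retraction path used. A minimal sketch of the mirroring pattern (not the actual class, just the bookkeeping):

```python
import torch

class LensMirror:
    """Keep a CPU twin of a device tensor updated in lockstep, so reads on
    the host never force a device synchronization."""

    def __init__(self, device: str = "cpu"):
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=device)
        self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)  # host mirror

    def step(self):
        # Apply the same in-place +1 to both copies.
        self.seq_lens.add_(1)
        self.seq_lens_cpu.add_(1)

    def filter(self, keep_indices_cpu, keep_indices_device):
        # Index each copy with indices living on its own device.
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.seq_lens_cpu = self.seq_lens_cpu[keep_indices_cpu]
```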
@@ -1540,7 +1644,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.DECODE
         bs = len(self.reqs)
 
-        if …
+        if (
+            self.spec_algorithm.is_eagle()
+            or self.spec_algorithm.is_standalone()
+            or self.spec_algorithm.is_ngram()
+        ):
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
             return
@@ -1581,10 +1689,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.seq_lens_cpu = self.seq_lens_cpu + 1
             self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.seq_lens_cpu.add_(1)
             self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs
 
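The overlap branch rebinds (`x = x + 1`) instead of mutating because an in-flight batch may still hold a reference to the previous tensor; `add_` would change those values under it. A small demonstration of the difference:

```python
import torch

# Out-of-place: rebinding allocates a new tensor, so a reference captured
# earlier (e.g. by an overlapped forward pass) keeps its old values.
seq_lens = torch.tensor([5, 9])
snapshot = seq_lens
seq_lens = seq_lens + 1
assert snapshot.tolist() == [5, 9]

# In-place: faster, but every alias observes the mutation.
seq_lens = torch.tensor([5, 9])
snapshot = seq_lens
seq_lens.add_(1)
assert snapshot.tolist() == [6, 10]
```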
@@ -1603,7 +1713,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.req_pool_indices, self.seq_lens - 2
         ]
         self.out_cache_loc = self.alloc_paged_token_slots_decode(
-            self.seq_lens, last_loc
+            self.seq_lens, self.seq_lens_cpu, last_loc
         )
 
         self.req_to_token_pool.write(
@@ -1649,6 +1759,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
         self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
@@ -1666,7 +1777,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         self.sampling_info.filter_batch(keep_indices, keep_indices_device)
         if self.spec_info:
-            …
+            if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
+                has_been_filtered = False
+            else:
+                has_been_filtered = True
+            self.spec_info.filter_batch(
+                new_indices=keep_indices_device,
+                has_been_filtered=has_been_filtered,
+            )
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
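The new branch tells `spec_info.filter_batch` whether the speculative state was already filtered before this call: true on the ordinary path, false when requests are dropped only to exclude a chunked request. A sketch of just that predicate (hypothetical `spec_filter_kwargs` helper; the downstream meaning of the flag lives in the speculative-decoding worker):

```python
def spec_filter_kwargs(chunked_req_to_exclude, keep_indices_device):
    # Mirrors the diff: a non-empty exclusion set means the speculative
    # state has not been pre-filtered for this call.
    has_been_filtered = not (
        chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0
    )
    return {"new_indices": keep_indices_device, "has_been_filtered": has_been_filtered}

assert spec_filter_kwargs(None, [0, 1])["has_been_filtered"] is True
assert spec_filter_kwargs(["chunked_req"], [0, 1])["has_been_filtered"] is False
```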
@@ -1682,6 +1800,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
         self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
@@ -1725,15 +1844,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.sampling_info.grammars = None
 
         seq_lens_cpu = (
-            seq_lens_cpu_cache
-            if seq_lens_cpu_cache is not None
-            else self.seq_lens.cpu()
+            seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
         )
 
-        global bid
-        bid += 1
         return ModelWorkerBatch(
-            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
@@ -1779,7 +1893,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
-            …
+            is_prefill_only=self.is_prefill_only,
         )
 
     def copy(self):
@@ -1852,8 +1966,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
 @dataclasses.dataclass
 class ModelWorkerBatch:
-    # The batch id
-    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
@@ -1914,19 +2026,25 @@ class ModelWorkerBatch:
 
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-
+
+    spec_info: Optional[SpecInput] = None
+
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
-    hicache_consumer_index: int = …
+    hicache_consumer_index: int = -1
+
+    # Overlap scheduler related
+    delay_sample_launch: bool = False
 
-    # …
-    …
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
 
 
 @triton.jit
 def write_req_to_token_pool_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices,
+    prefix_tensors,
     pre_lens,
     seq_lens,
     extend_lens,
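Summarizing the hunk above: `ModelWorkerBatch` loses `bid` and gains `spec_info`, a `-1` default for `hicache_consumer_index`, and the overlap/prefill flags, while the Triton kernel grows a `prefix_tensors` argument used by the prefix-writing loop in the next hunk. A reduced sketch of the resulting field layout (field types stubbed with `object`, since the real `SpeculativeAlgorithm`/`SpecInput`/`CaptureHiddenMode` classes are not reproduced here):

```python
import dataclasses
from typing import Optional

@dataclasses.dataclass
class ModelWorkerBatchSketch:
    # Only the fields touched by this hunk; the real dataclass has many more.
    spec_algorithm: object = None
    spec_info: Optional[object] = None
    capture_hidden_mode: object = None
    hicache_consumer_index: int = -1
    # Overlap scheduler related
    delay_sample_launch: bool = False
    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False

batch = ModelWorkerBatchSketch()
assert batch.hicache_consumer_index == -1 and not batch.is_prefill_only
```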
@@ -1939,6 +2057,19 @@ def write_req_to_token_pool_triton(
     req_pool_index = tl.load(req_pool_indices + pid)
     pre_len = tl.load(pre_lens + pid)
     seq_len = tl.load(seq_lens + pid)
+    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
+
+    # write prefix
+    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < pre_len
+        value = tl.load(prefix_tensor + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
+            value,
+            mask=mask,
+        )
 
     # NOTE: This can be slow for large bs
     cumsum_start = tl.cast(0, tl.int64)
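The added prefix loop is the standard Triton blocked-copy idiom: walk `ceil(pre_len / BLOCK_SIZE)` blocks and mask off the ragged tail on both the load and the store. A self-contained toy kernel using the same pattern (a demonstration, not the sglang kernel; requires a CUDA device):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def blocked_copy(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # One program copies the whole buffer block by block, the way the
    # per-request prefix loop above copies pre_len slot indices.
    num_loop = tl.cdiv(n_elements, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < n_elements  # guard the partial last block
        value = tl.load(src_ptr + offset, mask=mask)
        tl.store(dst_ptr + offset, value, mask=mask)

src = torch.arange(100, device="cuda", dtype=torch.int64)
dst = torch.empty_like(src)
blocked_copy[(1,)](src, dst, src.numel(), BLOCK_SIZE=32)
assert torch.equal(src, dst)
```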