sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -36,7 +36,7 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
|
|
|
36
36
|
import copy
|
|
37
37
|
import dataclasses
|
|
38
38
|
import logging
|
|
39
|
-
import
|
|
39
|
+
import re
|
|
40
40
|
import time
|
|
41
41
|
from enum import Enum, auto
|
|
42
42
|
from http import HTTPStatus
|
|
@@ -45,10 +45,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
|
|
|
45
45
|
|
|
46
46
|
import numpy as np
|
|
47
47
|
import torch
|
|
48
|
-
import triton
|
|
49
|
-
import triton.language as tl
|
|
50
48
|
|
|
51
|
-
from sglang.global_config import global_config
|
|
52
49
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
|
53
50
|
from sglang.srt.disaggregation.base import BaseKVSender
|
|
54
51
|
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
|
@@ -56,68 +53,36 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
|
|
56
53
|
)
|
|
57
54
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
|
58
55
|
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
|
56
|
+
from sglang.srt.environ import envs
|
|
59
57
|
from sglang.srt.mem_cache.allocator import (
|
|
60
58
|
BaseTokenToKVPoolAllocator,
|
|
61
59
|
SWATokenToKVPoolAllocator,
|
|
62
60
|
)
|
|
63
61
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
|
64
|
-
from sglang.srt.mem_cache.chunk_cache import
|
|
65
|
-
from sglang.srt.mem_cache.
|
|
62
|
+
from sglang.srt.mem_cache.chunk_cache import SWAChunkCache
|
|
63
|
+
from sglang.srt.mem_cache.common import (
|
|
64
|
+
alloc_for_decode,
|
|
65
|
+
alloc_for_extend,
|
|
66
|
+
evict_from_tree_cache,
|
|
67
|
+
)
|
|
68
|
+
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
|
69
|
+
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
|
66
70
|
from sglang.srt.mem_cache.radix_cache import RadixKey
|
|
67
71
|
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
|
68
72
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
|
69
73
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
|
70
74
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
|
71
75
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
|
72
|
-
from sglang.srt.server_args import ServerArgs
|
|
73
|
-
from sglang.srt.utils import flatten_nested_list
|
|
76
|
+
from sglang.srt.server_args import ServerArgs, get_global_server_args
|
|
77
|
+
from sglang.srt.utils import flatten_nested_list
|
|
74
78
|
|
|
75
79
|
if TYPE_CHECKING:
|
|
76
80
|
from sglang.srt.configs.model_config import ModelConfig
|
|
81
|
+
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
|
77
82
|
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
|
|
78
83
|
|
|
79
84
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
|
80
85
|
|
|
81
|
-
GLOBAL_SERVER_ARGS_KEYS = [
|
|
82
|
-
"attention_backend",
|
|
83
|
-
"mm_attention_backend",
|
|
84
|
-
"debug_tensor_dump_inject",
|
|
85
|
-
"debug_tensor_dump_output_folder",
|
|
86
|
-
"chunked_prefill_size",
|
|
87
|
-
"device",
|
|
88
|
-
"disable_chunked_prefix_cache",
|
|
89
|
-
"disable_flashinfer_cutlass_moe_fp4_allgather",
|
|
90
|
-
"disable_radix_cache",
|
|
91
|
-
"enable_dp_lm_head",
|
|
92
|
-
"enable_fp32_lm_head",
|
|
93
|
-
"flashinfer_mxfp4_moe_precision",
|
|
94
|
-
"enable_flashinfer_allreduce_fusion",
|
|
95
|
-
"moe_dense_tp_size",
|
|
96
|
-
"ep_dispatch_algorithm",
|
|
97
|
-
"ep_num_redundant_experts",
|
|
98
|
-
"enable_nan_detection",
|
|
99
|
-
"flashinfer_mla_disable_ragged",
|
|
100
|
-
"max_micro_batch_size",
|
|
101
|
-
"disable_shared_experts_fusion",
|
|
102
|
-
"sampling_backend",
|
|
103
|
-
"speculative_accept_threshold_single",
|
|
104
|
-
"speculative_accept_threshold_acc",
|
|
105
|
-
"speculative_attention_mode",
|
|
106
|
-
"torchao_config",
|
|
107
|
-
"triton_attention_reduce_in_fp32",
|
|
108
|
-
"num_reserved_decode_tokens",
|
|
109
|
-
"weight_loader_disable_mmap",
|
|
110
|
-
"enable_multimodal",
|
|
111
|
-
"enable_symm_mem",
|
|
112
|
-
"enable_custom_logit_processor",
|
|
113
|
-
"disaggregation_mode",
|
|
114
|
-
"enable_deterministic_inference",
|
|
115
|
-
"nsa_prefill",
|
|
116
|
-
"nsa_decode",
|
|
117
|
-
]
|
|
118
|
-
|
|
119
|
-
# Put some global args for easy access
|
|
120
|
-
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
|
|
121
86
|
|
|
122
87
|
logger = logging.getLogger(__name__)
|
|
123
88
|
|
|
@@ -154,6 +119,18 @@ class FINISH_MATCHED_STR(BaseFinishReason):
|
|
|
154
119
|
}
|
|
155
120
|
|
|
156
121
|
|
|
122
|
+
class FINISHED_MATCHED_REGEX(BaseFinishReason):
|
|
123
|
+
def __init__(self, matched: str):
|
|
124
|
+
super().__init__()
|
|
125
|
+
self.matched = matched
|
|
126
|
+
|
|
127
|
+
def to_json(self):
|
|
128
|
+
return {
|
|
129
|
+
"type": "stop", # to match OpenAI API's return value
|
|
130
|
+
"matched": self.matched,
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
|
|
157
134
|
class FINISH_LENGTH(BaseFinishReason):
|
|
158
135
|
def __init__(self, length: int):
|
|
159
136
|
super().__init__()
|
|
@@ -461,6 +438,7 @@ class Req:
|
|
|
461
438
|
priority: Optional[int] = None,
|
|
462
439
|
metrics_collector: Optional[SchedulerMetricsCollector] = None,
|
|
463
440
|
extra_key: Optional[str] = None,
|
|
441
|
+
http_worker_ipc: Optional[str] = None,
|
|
464
442
|
):
|
|
465
443
|
# Input and output info
|
|
466
444
|
self.rid = rid
|
|
@@ -484,6 +462,9 @@ class Req:
|
|
|
484
462
|
# The length of KV that have been removed in local attention chunked prefill
|
|
485
463
|
self.evicted_seqlen_local = 0
|
|
486
464
|
|
|
465
|
+
# For multi-http worker
|
|
466
|
+
self.http_worker_ipc = http_worker_ipc
|
|
467
|
+
|
|
487
468
|
# Sampling info
|
|
488
469
|
if isinstance(sampling_params.custom_params, dict):
|
|
489
470
|
sampling_params = copy.copy(sampling_params)
|
|
@@ -505,10 +486,13 @@ class Req:
|
|
|
505
486
|
|
|
506
487
|
# Memory pool info
|
|
507
488
|
self.req_pool_idx: Optional[int] = None
|
|
489
|
+
self.mamba_pool_idx: Optional[torch.Tensor] = None # shape (1)
|
|
508
490
|
|
|
509
491
|
# Check finish
|
|
510
492
|
self.tokenizer = None
|
|
511
493
|
self.finished_reason = None
|
|
494
|
+
# finished position (in output_ids), used when checking stop conditions with speculative decoding
|
|
495
|
+
self.finished_len = None
|
|
512
496
|
# Whether this request has finished output
|
|
513
497
|
self.finished_output = None
|
|
514
498
|
# If we want to abort the request in the middle of the event loop, set this to true
|
|
@@ -539,7 +523,7 @@ class Req:
|
|
|
539
523
|
|
|
540
524
|
# Prefix info
|
|
541
525
|
# The indices to kv cache for the shared prefix.
|
|
542
|
-
self.prefix_indices: torch.Tensor =
|
|
526
|
+
self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
|
|
543
527
|
# Number of tokens to run prefill.
|
|
544
528
|
self.extend_input_len = 0
|
|
545
529
|
# The relative logprob_start_len in an extend batch
|
|
@@ -630,6 +614,10 @@ class Req:
|
|
|
630
614
|
# This is used to compute the average acceptance length per request.
|
|
631
615
|
self.spec_verify_ct = 0
|
|
632
616
|
|
|
617
|
+
# The number of accepted tokens in speculative decoding for this request.
|
|
618
|
+
# This is used to compute the acceptance rate and average acceptance length per request.
|
|
619
|
+
self.spec_accepted_tokens = 0
|
|
620
|
+
|
|
633
621
|
# For metrics
|
|
634
622
|
self.metrics_collector = metrics_collector
|
|
635
623
|
self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
|
|
@@ -666,10 +654,16 @@ class Req:
|
|
|
666
654
|
def is_prefill_only(self) -> bool:
|
|
667
655
|
"""Check if this request is prefill-only (no token generation needed)."""
|
|
668
656
|
# NOTE: when spec is enabled, prefill_only optimizations are disabled
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
657
|
+
|
|
658
|
+
spec_alg = get_global_server_args().speculative_algorithm
|
|
659
|
+
return self.sampling_params.max_new_tokens == 0 and spec_alg is None
|
|
660
|
+
|
|
661
|
+
@property
|
|
662
|
+
def output_ids_through_stop(self) -> List[int]:
|
|
663
|
+
"""Get the output ids through the stop condition. Stop position is included."""
|
|
664
|
+
if self.finished_len is not None:
|
|
665
|
+
return self.output_ids[: self.finished_len]
|
|
666
|
+
return self.output_ids
|
|
673
667
|
|
|
674
668
|
def add_latency(self, stage: RequestStage):
|
|
675
669
|
if self.metrics_collector is None:
|
|
@@ -691,11 +685,16 @@ class Req:
|
|
|
691
685
|
# Whether request reached finished condition
|
|
692
686
|
return self.finished_reason is not None
|
|
693
687
|
|
|
694
|
-
def init_next_round_input(
|
|
695
|
-
self,
|
|
696
|
-
tree_cache: Optional[BasePrefixCache] = None,
|
|
697
|
-
):
|
|
688
|
+
def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
|
|
698
689
|
self.fill_ids = self.origin_input_ids + self.output_ids
|
|
690
|
+
input_len = len(self.fill_ids)
|
|
691
|
+
# NOTE: the matched length is at most 1 less than the input length to enable logprob computation
|
|
692
|
+
max_prefix_len = input_len - 1
|
|
693
|
+
if self.return_logprob:
|
|
694
|
+
max_prefix_len = min(max_prefix_len, self.logprob_start_len)
|
|
695
|
+
max_prefix_len = max(max_prefix_len, 0)
|
|
696
|
+
token_ids = self.fill_ids[:max_prefix_len]
|
|
697
|
+
|
|
699
698
|
if tree_cache is not None:
|
|
700
699
|
(
|
|
701
700
|
self.prefix_indices,
|
|
@@ -703,51 +702,146 @@ class Req:
|
|
|
703
702
|
self.last_host_node,
|
|
704
703
|
self.host_hit_length,
|
|
705
704
|
) = tree_cache.match_prefix(
|
|
706
|
-
key=RadixKey(
|
|
707
|
-
|
|
705
|
+
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
|
|
706
|
+
**(
|
|
707
|
+
{"req": self, "cow_mamba": True}
|
|
708
|
+
if isinstance(tree_cache, MambaRadixCache)
|
|
709
|
+
else {}
|
|
708
710
|
),
|
|
709
711
|
)
|
|
710
712
|
self.last_matched_prefix_len = len(self.prefix_indices)
|
|
711
713
|
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
|
712
714
|
|
|
713
|
-
def adjust_max_prefix_ids(self):
|
|
714
|
-
self.fill_ids = self.origin_input_ids + self.output_ids
|
|
715
|
-
input_len = len(self.fill_ids)
|
|
716
|
-
|
|
717
|
-
# FIXME: To work around some bugs in logprob computation, we need to ensure each
|
|
718
|
-
# request has at least one token. Later, we can relax this requirement and use `input_len`.
|
|
719
|
-
max_prefix_len = input_len - 1
|
|
720
|
-
|
|
721
|
-
if self.sampling_params.max_new_tokens > 0:
|
|
722
|
-
# Need at least one token to compute logits
|
|
723
|
-
max_prefix_len = min(max_prefix_len, input_len - 1)
|
|
724
|
-
|
|
725
|
-
if self.return_logprob:
|
|
726
|
-
max_prefix_len = min(max_prefix_len, self.logprob_start_len)
|
|
727
|
-
|
|
728
|
-
max_prefix_len = max(max_prefix_len, 0)
|
|
729
|
-
return self.fill_ids[:max_prefix_len]
|
|
730
|
-
|
|
731
715
|
# Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
|
|
732
716
|
def init_incremental_detokenize(self):
|
|
733
717
|
first_iter = self.surr_offset is None or self.read_offset is None
|
|
734
718
|
|
|
719
|
+
output_ids = self.output_ids_through_stop
|
|
720
|
+
|
|
735
721
|
if first_iter:
|
|
736
722
|
self.read_offset = len(self.origin_input_ids_unpadded)
|
|
737
723
|
self.surr_offset = max(
|
|
738
724
|
self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
|
|
739
725
|
)
|
|
740
726
|
self.surr_and_decode_ids = (
|
|
741
|
-
self.origin_input_ids_unpadded[self.surr_offset :] +
|
|
727
|
+
self.origin_input_ids_unpadded[self.surr_offset :] + output_ids
|
|
742
728
|
)
|
|
743
|
-
self.cur_decode_ids_len = len(
|
|
729
|
+
self.cur_decode_ids_len = len(output_ids)
|
|
744
730
|
else:
|
|
745
|
-
self.surr_and_decode_ids.extend(
|
|
746
|
-
self.cur_decode_ids_len = len(
|
|
731
|
+
self.surr_and_decode_ids.extend(output_ids[self.cur_decode_ids_len :])
|
|
732
|
+
self.cur_decode_ids_len = len(output_ids)
|
|
747
733
|
|
|
748
734
|
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
|
|
749
735
|
|
|
750
|
-
def
|
|
736
|
+
def tail_str(self) -> str:
|
|
737
|
+
# Check stop strings and stop regex patterns together
|
|
738
|
+
if (
|
|
739
|
+
len(self.sampling_params.stop_strs) > 0
|
|
740
|
+
or len(self.sampling_params.stop_regex_strs) > 0
|
|
741
|
+
):
|
|
742
|
+
max_len_tail_str = max(
|
|
743
|
+
self.sampling_params.stop_str_max_len + 1,
|
|
744
|
+
self.sampling_params.stop_regex_max_len + 1,
|
|
745
|
+
)
|
|
746
|
+
|
|
747
|
+
tail_len = min((max_len_tail_str + 1), len(self.output_ids))
|
|
748
|
+
return self.tokenizer.decode(self.output_ids[-tail_len:])
|
|
749
|
+
|
|
750
|
+
def check_match_stop_str_prefix(self) -> bool:
|
|
751
|
+
"""
|
|
752
|
+
Check if the suffix of tail_str overlaps with any stop_str prefix
|
|
753
|
+
"""
|
|
754
|
+
if not self.sampling_params.stop_strs:
|
|
755
|
+
return False
|
|
756
|
+
|
|
757
|
+
tail_str = self.tail_str()
|
|
758
|
+
|
|
759
|
+
# Early return if tail_str is empty
|
|
760
|
+
if not tail_str:
|
|
761
|
+
return False
|
|
762
|
+
|
|
763
|
+
for stop_str in self.sampling_params.stop_strs:
|
|
764
|
+
if not stop_str:
|
|
765
|
+
continue
|
|
766
|
+
# Check if stop_str is contained in tail_str (fastest check first)
|
|
767
|
+
if stop_str in tail_str:
|
|
768
|
+
return True
|
|
769
|
+
|
|
770
|
+
# Check if tail_str suffix matches stop_str prefix
|
|
771
|
+
# Only check if stop_str is not empty, it's for stream output
|
|
772
|
+
min_len = min(len(tail_str), len(stop_str))
|
|
773
|
+
for i in range(1, min_len + 1):
|
|
774
|
+
if tail_str[-i:] == stop_str[:i]:
|
|
775
|
+
return True
|
|
776
|
+
|
|
777
|
+
return False
|
|
778
|
+
|
|
779
|
+
def _check_token_based_finish(self, new_accepted_tokens: List[int]) -> bool:
|
|
780
|
+
if self.sampling_params.ignore_eos:
|
|
781
|
+
return False
|
|
782
|
+
|
|
783
|
+
# Check stop token ids
|
|
784
|
+
matched_eos = False
|
|
785
|
+
|
|
786
|
+
for i, token_id in enumerate(new_accepted_tokens):
|
|
787
|
+
if self.sampling_params.stop_token_ids:
|
|
788
|
+
matched_eos |= token_id in self.sampling_params.stop_token_ids
|
|
789
|
+
if self.eos_token_ids:
|
|
790
|
+
matched_eos |= token_id in self.eos_token_ids
|
|
791
|
+
if self.tokenizer is not None:
|
|
792
|
+
matched_eos |= token_id == self.tokenizer.eos_token_id
|
|
793
|
+
if self.tokenizer.additional_stop_token_ids:
|
|
794
|
+
matched_eos |= token_id in self.tokenizer.additional_stop_token_ids
|
|
795
|
+
if matched_eos:
|
|
796
|
+
self.finished_reason = FINISH_MATCHED_TOKEN(matched=token_id)
|
|
797
|
+
matched_pos = len(self.output_ids) - len(new_accepted_tokens) + i
|
|
798
|
+
self.finished_len = matched_pos + 1
|
|
799
|
+
return True
|
|
800
|
+
|
|
801
|
+
return False
|
|
802
|
+
|
|
803
|
+
def _check_str_based_finish(self):
|
|
804
|
+
if (
|
|
805
|
+
len(self.sampling_params.stop_strs) > 0
|
|
806
|
+
or len(self.sampling_params.stop_regex_strs) > 0
|
|
807
|
+
):
|
|
808
|
+
tail_str = self.tail_str()
|
|
809
|
+
|
|
810
|
+
# Check stop strings
|
|
811
|
+
if len(self.sampling_params.stop_strs) > 0:
|
|
812
|
+
for stop_str in self.sampling_params.stop_strs:
|
|
813
|
+
if stop_str in tail_str or stop_str in self.decoded_text:
|
|
814
|
+
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
|
815
|
+
return True
|
|
816
|
+
|
|
817
|
+
# Check stop regex
|
|
818
|
+
if len(self.sampling_params.stop_regex_strs) > 0:
|
|
819
|
+
for stop_regex_str in self.sampling_params.stop_regex_strs:
|
|
820
|
+
if re.search(stop_regex_str, tail_str):
|
|
821
|
+
self.finished_reason = FINISHED_MATCHED_REGEX(
|
|
822
|
+
matched=stop_regex_str
|
|
823
|
+
)
|
|
824
|
+
return True
|
|
825
|
+
|
|
826
|
+
return False
|
|
827
|
+
|
|
828
|
+
def _check_vocab_boundary_finish(self, new_accepted_tokens: List[int] = None):
|
|
829
|
+
for i, token_id in enumerate(new_accepted_tokens):
|
|
830
|
+
if token_id > self.vocab_size or token_id < 0:
|
|
831
|
+
offset = len(self.output_ids) - len(new_accepted_tokens) + i
|
|
832
|
+
if self.sampling_params.stop_token_ids:
|
|
833
|
+
self.output_ids[offset] = next(
|
|
834
|
+
iter(self.sampling_params.stop_token_ids)
|
|
835
|
+
)
|
|
836
|
+
if self.eos_token_ids:
|
|
837
|
+
self.output_ids[offset] = next(iter(self.eos_token_ids))
|
|
838
|
+
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
|
|
839
|
+
self.finished_len = offset + 1
|
|
840
|
+
return True
|
|
841
|
+
|
|
842
|
+
return False
|
|
843
|
+
|
|
844
|
+
def check_finished(self, new_accepted_len: int = 1):
|
|
751
845
|
if self.finished():
|
|
752
846
|
return
|
|
753
847
|
|
|
@@ -761,6 +855,7 @@ class Req:
|
|
|
761
855
|
self.finished_reason = FINISH_LENGTH(
|
|
762
856
|
length=self.sampling_params.max_new_tokens
|
|
763
857
|
)
|
|
858
|
+
self.finished_len = self.sampling_params.max_new_tokens
|
|
764
859
|
return
|
|
765
860
|
|
|
766
861
|
if self.grammar is not None:
|
|
@@ -768,47 +863,19 @@ class Req:
|
|
|
768
863
|
self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
|
|
769
864
|
return
|
|
770
865
|
|
|
771
|
-
|
|
866
|
+
new_accepted_tokens = self.output_ids[-new_accepted_len:]
|
|
772
867
|
|
|
773
|
-
if
|
|
774
|
-
matched_eos = False
|
|
775
|
-
|
|
776
|
-
# Check stop token ids
|
|
777
|
-
if self.sampling_params.stop_token_ids:
|
|
778
|
-
matched_eos = last_token_id in self.sampling_params.stop_token_ids
|
|
779
|
-
if self.eos_token_ids:
|
|
780
|
-
matched_eos |= last_token_id in self.eos_token_ids
|
|
781
|
-
if self.tokenizer is not None:
|
|
782
|
-
matched_eos |= last_token_id == self.tokenizer.eos_token_id
|
|
783
|
-
if self.tokenizer.additional_stop_token_ids:
|
|
784
|
-
matched_eos |= (
|
|
785
|
-
last_token_id in self.tokenizer.additional_stop_token_ids
|
|
786
|
-
)
|
|
787
|
-
if matched_eos:
|
|
788
|
-
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
|
|
789
|
-
return
|
|
790
|
-
|
|
791
|
-
if last_token_id > self.vocab_size or last_token_id < 0:
|
|
792
|
-
if self.sampling_params.stop_token_ids:
|
|
793
|
-
self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
|
|
794
|
-
if self.eos_token_ids:
|
|
795
|
-
self.output_ids[-1] = next(iter(self.eos_token_ids))
|
|
796
|
-
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
|
|
868
|
+
if self._check_token_based_finish(new_accepted_tokens):
|
|
797
869
|
return
|
|
798
870
|
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
tail_str = self.tokenizer.decode(
|
|
802
|
-
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
|
|
803
|
-
)
|
|
871
|
+
if self._check_vocab_boundary_finish(new_accepted_tokens):
|
|
872
|
+
return
|
|
804
873
|
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
|
808
|
-
return
|
|
874
|
+
if self._check_str_based_finish():
|
|
875
|
+
return
|
|
809
876
|
|
|
810
877
|
def reset_for_retract(self):
|
|
811
|
-
self.prefix_indices =
|
|
878
|
+
self.prefix_indices = torch.empty((0,), dtype=torch.int64)
|
|
812
879
|
self.last_node = None
|
|
813
880
|
self.swa_uuid_for_lock = None
|
|
814
881
|
self.extend_input_len = 0
|
|
@@ -818,7 +885,7 @@ class Req:
|
|
|
818
885
|
self.temp_input_top_logprobs_idx = None
|
|
819
886
|
self.extend_logprob_start_len = 0
|
|
820
887
|
self.is_chunked = 0
|
|
821
|
-
self.
|
|
888
|
+
self.mamba_pool_idx = None
|
|
822
889
|
self.already_computed = 0
|
|
823
890
|
|
|
824
891
|
def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
|
|
@@ -886,15 +953,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
|
886
953
|
# This is an optimization to reduce the overhead of the prefill check.
|
|
887
954
|
batch_is_full: bool = False
|
|
888
955
|
|
|
889
|
-
# Events
|
|
890
|
-
launch_done: Optional[threading.Event] = None
|
|
891
|
-
|
|
892
956
|
# For chunked prefill in PP
|
|
893
957
|
chunked_req: Optional[Req] = None
|
|
894
958
|
|
|
895
959
|
# Sampling info
|
|
896
960
|
sampling_info: SamplingBatchInfo = None
|
|
897
|
-
next_batch_sampling_info: SamplingBatchInfo = None
|
|
898
961
|
|
|
899
962
|
# Batched arguments to model runner
|
|
900
963
|
input_ids: torch.Tensor = None # shape: [b], int64
|
|
@@ -1017,117 +1080,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
|
1017
1080
|
def is_empty(self):
|
|
1018
1081
|
return len(self.reqs) == 0
|
|
1019
1082
|
|
|
1020
|
-
def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
|
|
1021
|
-
if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
|
|
1022
|
-
req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
|
|
1023
|
-
else:
|
|
1024
|
-
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
|
1025
|
-
if req_pool_indices is None:
|
|
1026
|
-
raise RuntimeError(
|
|
1027
|
-
"alloc_req_slots runs out of memory. "
|
|
1028
|
-
"Please set a smaller number for `--max-running-requests`. "
|
|
1029
|
-
f"{self.req_to_token_pool.available_size()=}, "
|
|
1030
|
-
f"{num_reqs=}, "
|
|
1031
|
-
)
|
|
1032
|
-
return req_pool_indices
|
|
1033
|
-
|
|
1034
|
-
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
|
|
1035
|
-
self._evict_tree_cache_if_needed(num_tokens)
|
|
1036
|
-
|
|
1037
|
-
if backup_state:
|
|
1038
|
-
state = self.token_to_kv_pool_allocator.backup_state()
|
|
1039
|
-
|
|
1040
|
-
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
|
1041
|
-
if out_cache_loc is None:
|
|
1042
|
-
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
|
1043
|
-
error_msg = (
|
|
1044
|
-
f"{phase_str} out of memory. Try to lower your batch size.\n"
|
|
1045
|
-
f"Try to allocate {num_tokens} tokens.\n"
|
|
1046
|
-
f"{self._available_and_evictable_str()}"
|
|
1047
|
-
)
|
|
1048
|
-
logger.error(error_msg)
|
|
1049
|
-
if self.tree_cache is not None:
|
|
1050
|
-
self.tree_cache.pretty_print()
|
|
1051
|
-
raise RuntimeError(error_msg)
|
|
1052
|
-
|
|
1053
|
-
if backup_state:
|
|
1054
|
-
return out_cache_loc, state
|
|
1055
|
-
else:
|
|
1056
|
-
return out_cache_loc
|
|
1057
|
-
|
|
1058
|
-
def alloc_paged_token_slots_extend(
|
|
1059
|
-
self,
|
|
1060
|
-
prefix_lens: torch.Tensor,
|
|
1061
|
-
prefix_lens_cpu: torch.Tensor,
|
|
1062
|
-
seq_lens: torch.Tensor,
|
|
1063
|
-
seq_lens_cpu: torch.Tensor,
|
|
1064
|
-
last_loc: torch.Tensor,
|
|
1065
|
-
extend_num_tokens: int,
|
|
1066
|
-
backup_state: bool = False,
|
|
1067
|
-
):
|
|
1068
|
-
# Over estimate the number of tokens: assume each request needs a new page.
|
|
1069
|
-
num_tokens = (
|
|
1070
|
-
extend_num_tokens
|
|
1071
|
-
+ len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
|
|
1072
|
-
)
|
|
1073
|
-
self._evict_tree_cache_if_needed(num_tokens)
|
|
1074
|
-
|
|
1075
|
-
if backup_state:
|
|
1076
|
-
state = self.token_to_kv_pool_allocator.backup_state()
|
|
1077
|
-
|
|
1078
|
-
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
|
1079
|
-
prefix_lens,
|
|
1080
|
-
prefix_lens_cpu,
|
|
1081
|
-
seq_lens,
|
|
1082
|
-
seq_lens_cpu,
|
|
1083
|
-
last_loc,
|
|
1084
|
-
extend_num_tokens,
|
|
1085
|
-
)
|
|
1086
|
-
if out_cache_loc is None:
|
|
1087
|
-
error_msg = (
|
|
1088
|
-
f"Prefill out of memory. Try to lower your batch size.\n"
|
|
1089
|
-
f"Try to allocate {extend_num_tokens} tokens.\n"
|
|
1090
|
-
f"{self._available_and_evictable_str()}"
|
|
1091
|
-
)
|
|
1092
|
-
logger.error(error_msg)
|
|
1093
|
-
raise RuntimeError(error_msg)
|
|
1094
|
-
|
|
1095
|
-
if backup_state:
|
|
1096
|
-
return out_cache_loc, state
|
|
1097
|
-
else:
|
|
1098
|
-
return out_cache_loc
|
|
1099
|
-
|
|
1100
|
-
def alloc_paged_token_slots_decode(
|
|
1101
|
-
self,
|
|
1102
|
-
seq_lens: torch.Tensor,
|
|
1103
|
-
seq_lens_cpu: torch.Tensor,
|
|
1104
|
-
last_loc: torch.Tensor,
|
|
1105
|
-
backup_state: bool = False,
|
|
1106
|
-
):
|
|
1107
|
-
# Over estimate the number of tokens: assume each request needs a new page.
|
|
1108
|
-
num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
|
1109
|
-
self._evict_tree_cache_if_needed(num_tokens)
|
|
1110
|
-
|
|
1111
|
-
if backup_state:
|
|
1112
|
-
state = self.token_to_kv_pool_allocator.backup_state()
|
|
1113
|
-
|
|
1114
|
-
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
|
|
1115
|
-
seq_lens, seq_lens_cpu, last_loc
|
|
1116
|
-
)
|
|
1117
|
-
if out_cache_loc is None:
|
|
1118
|
-
error_msg = (
|
|
1119
|
-
f"Decode out of memory. Try to lower your batch size.\n"
|
|
1120
|
-
f"Try to allocate {len(seq_lens)} tokens.\n"
|
|
1121
|
-
f"{self._available_and_evictable_str()}"
|
|
1122
|
-
)
|
|
1123
|
-
logger.error(error_msg)
|
|
1124
|
-
raise RuntimeError(error_msg)
|
|
1125
|
-
|
|
1126
|
-
if backup_state:
|
|
1127
|
-
return out_cache_loc, state
|
|
1128
|
-
else:
|
|
1129
|
-
return out_cache_loc
|
|
1130
|
-
|
|
1131
1083
|
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
|
1132
1084
|
self.encoder_lens_cpu = []
|
|
1133
1085
|
self.encoder_cached = []
|
|
@@ -1205,10 +1157,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
|
1205
1157
|
def prepare_for_extend(self):
|
|
1206
1158
|
self.forward_mode = ForwardMode.EXTEND
|
|
1207
1159
|
|
|
1208
|
-
# Allocate req slots
|
|
1209
|
-
bs = len(self.reqs)
|
|
1210
|
-
req_pool_indices = self.alloc_req_slots(bs, self.reqs)
|
|
1211
|
-
|
|
1212
1160
|
# Init tensors
|
|
1213
1161
|
reqs = self.reqs
|
|
1214
1162
|
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
|
@@ -1222,9 +1170,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
|
1222
1170
|
r.token_type_ids for r in reqs if r.token_type_ids is not None
|
|
1223
1171
|
]
|
|
1224
1172
|
|
|
1225
|
-
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
|
|
1226
|
-
self.device, non_blocking=True
|
|
1227
|
-
)
|
|
1228
1173
|
input_ids_tensor = torch.tensor(
|
|
1229
1174
|
list(chain.from_iterable(input_ids)), dtype=torch.int64
|
|
1230
1175
|
).to(self.device, non_blocking=True)
|
|
@@ -1235,10 +1180,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
|
1235
1180
|
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
|
|
1236
1181
|
self.device, non_blocking=True
|
|
1237
1182
|
)
|
|
1238
|
-
prefix_lens_tensor = torch.tensor(
|
|
1239
|
-
prefix_lens, dtype=torch.int64, device=self.device
|
|
1240
|
-
)
|
|
1241
|
-
prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
|
|
1242
1183
|
|
|
1243
1184
|
token_type_ids_tensor = None
|
|
1244
1185
|
if len(token_type_ids) > 0:
|
|
@@ -1246,9 +1187,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
|
1246
1187
|
sum(token_type_ids, []), dtype=torch.int64
|
|
1247
1188
|
).to(self.device, non_blocking=True)
|
|
1248
1189
|
|
|
1249
|
-
|
|
1190
|
+
# Set batch fields needed by alloc_for_extend
|
|
1191
|
+
self.prefix_lens = prefix_lens
|
|
1192
|
+
self.extend_lens = extend_lens
|
|
1193
|
+
self.seq_lens = seq_lens_tensor
|
|
1194
|
+
self.seq_lens_cpu = seq_lens_cpu
|
|
1195
|
+
self.extend_num_tokens = extend_num_tokens
|
|
1196
|
+
|
|
1197
|
+
# Allocate memory
|
|
1198
|
+
out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
|
|
1199
|
+
self
|
|
1200
|
+
)
|
|
1250
1201
|
|
|
1251
|
-
#
|
|
1202
|
+
# Set fields
|
|
1252
1203
|
input_embeds = []
|
|
1253
1204
|
extend_input_logprob_token_ids = []
|
|
1254
1205
|
multimodal_inputs = []
|
|
@@ -1257,15 +1208,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
|
1257
1208
|
req.req_pool_idx = req_pool_indices[i]
|
|
1258
1209
|
assert seq_len - pre_len == req.extend_input_len
|
|
1259
1210
|
|
|
1260
|
-
if pre_len > 0:
|
|
1261
|
-
self.req_to_token_pool.write(
|
|
1262
|
-
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
|
|
1263
|
-
)
|
|
1264
|
-
if isinstance(self.tree_cache, SWAChunkCache):
|
|
1265
|
-
self.tree_cache.evict_swa(
|
|
1266
|
-
req, pre_len, self.model_config.attention_chunk_size
|
|
1267
|
-
)
|
|
1268
|
-
|
|
1269
1211
|
# If input_embeds are available, store them
|
|
1270
1212
|
if req.input_embeds is not None:
|
|
1271
1213
|
# If req.input_embeds is already a list, append its content directly
|
|
@@ -1355,29 +1297,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
|
1355
1297
|
else:
|
|
1356
1298
|
extend_input_logprob_token_ids = None
|
|
1357
1299
|
|
|
1358
|
-
# Allocate memory
|
|
1359
|
-
if self.token_to_kv_pool_allocator.page_size == 1:
|
|
1360
|
-
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
|
1361
|
-
else:
|
|
1362
|
-
last_loc = get_last_loc(
|
|
1363
|
-
self.req_to_token_pool.req_to_token,
|
|
1364
|
-
req_pool_indices_tensor,
|
|
1365
|
-
prefix_lens_tensor,
|
|
1366
|
-
)
|
|
1367
|
-
out_cache_loc = self.alloc_paged_token_slots_extend(
|
|
1368
|
-
prefix_lens_tensor,
|
|
1369
|
-
prefix_lens_cpu_tensor,
|
|
1370
|
-
seq_lens_tensor,
|
|
1371
|
-
seq_lens_cpu,
|
|
1372
|
-
last_loc,
|
|
1373
|
-
extend_num_tokens,
|
|
1374
|
-
)
|
|
1375
|
-
|
|
1376
|
-
# Set fields
|
|
1377
1300
|
self.input_ids = input_ids_tensor
|
|
1378
1301
|
self.req_pool_indices = req_pool_indices_tensor
|
|
1379
|
-
self.seq_lens = seq_lens_tensor
|
|
1380
|
-
self.seq_lens_cpu = seq_lens_cpu
|
|
1381
1302
|
self.orig_seq_lens = orig_seq_lens_tensor
|
|
1382
1303
|
self.out_cache_loc = out_cache_loc
|
|
1383
1304
|
self.input_embeds = (
|
|
@@ -1401,33 +1322,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
|
1401
1322
|
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
|
|
1402
1323
|
|
|
1403
1324
|
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
|
1404
|
-
self.extend_num_tokens = extend_num_tokens
|
|
1405
|
-
self.prefix_lens = prefix_lens
|
|
1406
|
-
self.extend_lens = extend_lens
|
|
1407
1325
|
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
|
1408
1326
|
|
|
1409
|
-
# Write to req_to_token_pool
|
|
1410
|
-
if support_triton(global_server_args_dict.get("attention_backend")):
|
|
1411
|
-
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
|
|
1412
|
-
|
|
1413
|
-
write_req_to_token_pool_triton[(bs,)](
|
|
1414
|
-
self.req_to_token_pool.req_to_token,
|
|
1415
|
-
req_pool_indices_tensor,
|
|
1416
|
-
prefix_lens_tensor,
|
|
1417
|
-
seq_lens_tensor,
|
|
1418
|
-
extend_lens_tensor,
|
|
1419
|
-
out_cache_loc,
|
|
1420
|
-
self.req_to_token_pool.req_to_token.shape[1],
|
|
1421
|
-
)
|
|
1422
|
-
else:
|
|
1423
|
-
pt = 0
|
|
1424
|
-
for i in range(bs):
|
|
1425
|
-
self.req_to_token_pool.write(
|
|
1426
|
-
(req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
|
|
1427
|
-
out_cache_loc[pt : pt + extend_lens[i]],
|
|
1428
|
-
)
|
|
1429
|
-
pt += extend_lens[i]
|
|
1430
|
-
|
|
1431
1327
|
if self.model_config.is_encoder_decoder:
|
|
1432
1328
|
self.prepare_encoder_info_extend(input_ids, seq_lens)
|
|
1433
1329
|
|
|
@@ -1498,7 +1394,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
|
1498
1394
|
* self.token_to_kv_pool_allocator.page_size
|
|
1499
1395
|
)
|
|
1500
1396
|
|
|
1501
|
-
self.
|
|
1397
|
+
evict_from_tree_cache(self.tree_cache, num_tokens)
|
|
1502
1398
|
return self._is_available_size_sufficient(num_tokens)
|
|
1503
1399
|
|
|
1504
1400
|
def retract_decode(self, server_args: ServerArgs):
|
|
@@ -1546,6 +1442,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
|
1546
1442
|
idx = sorted_indices.pop()
|
|
1547
1443
|
req = self.reqs[idx]
|
|
1548
1444
|
retracted_reqs.append(req)
|
|
1445
|
+
# release memory and don't insert into the tree because we need the space instantly
|
|
1549
1446
|
self.release_req(idx, len(sorted_indices), server_args)
|
|
1550
1447
|
|
|
1551
1448
|
if len(retracted_reqs) == 0:
|
|
@@ -1561,47 +1458,27 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
|
1561
1458
|
total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
|
|
1562
1459
|
|
|
1563
1460
|
new_estimate_ratio = (
|
|
1564
|
-
total_decoded_tokens
|
|
1565
|
-
|
|
1461
|
+
total_decoded_tokens
|
|
1462
|
+
+ envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
|
|
1463
|
+
) / (
|
|
1464
|
+
total_max_new_tokens + 1
|
|
1465
|
+
) # avoid zero division
|
|
1566
1466
|
new_estimate_ratio = min(1.0, new_estimate_ratio)
|
|
1567
1467
|
|
|
1568
1468
|
return retracted_reqs, new_estimate_ratio, []
|
|
1569
1469
|
|
|
1570
1470
|
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
|
|
1571
1471
|
req = self.reqs[idx]
|
|
1572
|
-
seq_lens_cpu = self.seq_lens_cpu.numpy()
|
|
1573
1472
|
|
|
1574
1473
|
if server_args.disaggregation_mode == "decode":
|
|
1575
1474
|
req.offload_kv_cache(
|
|
1576
1475
|
self.req_to_token_pool, self.token_to_kv_pool_allocator
|
|
1577
1476
|
)
|
|
1578
|
-
|
|
1579
|
-
|
|
1580
|
-
|
|
1581
|
-
|
|
1582
|
-
|
|
1583
|
-
self.token_to_kv_pool_allocator.free(token_indices)
|
|
1584
|
-
self.req_to_token_pool.free(req.req_pool_idx)
|
|
1585
|
-
else:
|
|
1586
|
-
# TODO: apply more fine-grained retraction
|
|
1587
|
-
last_uncached_pos = (
|
|
1588
|
-
len(req.prefix_indices) // server_args.page_size
|
|
1589
|
-
) * server_args.page_size
|
|
1590
|
-
token_indices = self.req_to_token_pool.req_to_token[
|
|
1591
|
-
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
|
1592
|
-
]
|
|
1593
|
-
self.token_to_kv_pool_allocator.free(token_indices)
|
|
1594
|
-
self.req_to_token_pool.free(req.req_pool_idx)
|
|
1595
|
-
|
|
1596
|
-
# release the last node
|
|
1597
|
-
if self.is_hybrid:
|
|
1598
|
-
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
|
1599
|
-
else:
|
|
1600
|
-
self.tree_cache.dec_lock_ref(req.last_node)
|
|
1601
|
-
|
|
1602
|
-
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
|
1603
|
-
num_tokens = remaing_req_count * global_config.retract_decode_steps
|
|
1604
|
-
self._evict_tree_cache_if_needed(num_tokens)
|
|
1477
|
+
# TODO (csy): for preempted requests, we may want to insert into the tree
|
|
1478
|
+
self.tree_cache.cache_finished_req(req, is_insert=False)
|
|
1479
|
+
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
|
1480
|
+
num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
|
|
1481
|
+
evict_from_tree_cache(self.tree_cache, num_tokens)
|
|
1605
1482
|
|
|
1606
1483
|
req.reset_for_retract()
|
|
1607
1484
|
|
|
@@ -1624,15 +1501,21 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.model_config.vocab_size,
         )

+    @property
+    def is_v2_eagle(self):
+        # FIXME: finally deprecate is_v2_eagle
+        return self.enable_overlap and self.spec_algorithm.is_eagle()
+
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
         bs = len(self.reqs)

-        if
-
-
-
-
+        if self.is_v2_eagle:
+            # TODO(spec-v2): all v2 spec should go through this path
+            draft_input: EagleDraftInput = self.spec_info
+            draft_input.prepare_for_decode(self)
+
+        if not self.spec_algorithm.is_none():
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
             return
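For readers skimming the hunk above: `is_v2_eagle` is a convenience predicate combining the overlap-scheduler flag with the EAGLE speculative algorithm, and `prepare_for_decode` now dispatches on it before the generic speculative early-return. A hedged sketch of the resulting control flow, using stand-in names that mirror (but are not) the package's objects:

    # Sketch of the new decode-preparation gating; "batch" stands in for a
    # ScheduleBatch-like object with the attributes named in the diff.
    def prepare_for_decode_sketch(batch):
        if batch.enable_overlap and batch.spec_algorithm.is_eagle():  # is_v2_eagle
            batch.spec_info.prepare_for_decode(batch)  # v2 EAGLE path

        if not batch.spec_algorithm.is_none():
            # Other speculative modes prepare the decode batch inside
            # forward_batch_speculative_generation, so stop here.
            return
        # ... non-speculative decode preparation continues below ...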
@@ -1665,11 +1548,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.output_ids = None

         if self.model_config.is_encoder_decoder:
-            locs = self.encoder_lens + self.seq_lens
             self.prepare_encoder_info_decode()
-        else:
-            locs = self.seq_lens.clone()

+        # Allocate memory
+        self.out_cache_loc = alloc_for_decode(self, token_per_req=1)
+
+        # Update seq_lens after allocation
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
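The inline page-size branching that used to live in `prepare_for_decode` (see the removals in the next hunk) is now funneled through `alloc_for_decode(self, token_per_req=1)`. Its implementation is not shown in this diff; the sketch below only illustrates the kind of logic such a helper would centralize, with hypothetical names and details that the real function may handle differently:

    # Hypothetical sketch of a centralized decode allocator, loosely based on
    # the branching removed from prepare_for_decode; not the real alloc_for_decode.
    def alloc_for_decode_sketch(batch, token_per_req=1):
        bs = len(batch.reqs)
        if batch.token_to_kv_pool_allocator.page_size == 1:
            # Page size 1: any free KV slot works for the new decode tokens.
            return batch.alloc_token_slots(bs * token_per_req)
        # Paged KV cache: allocation needs each request's current last slot so
        # new tokens can extend the same page when possible.
        last_loc = batch.req_to_token_pool.req_to_token[
            batch.req_pool_indices, batch.seq_lens - 1
        ]
        return batch.alloc_paged_token_slots_decode(
            batch.seq_lens + token_per_req,
            batch.seq_lens_cpu + token_per_req,
            last_loc,
        )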
@@ -1682,33 +1566,21 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs

-
-        if
-
-
-
-        )
-
-        # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            self.out_cache_loc = self.alloc_token_slots(bs)
-        else:
-            last_loc = self.req_to_token_pool.req_to_token[
-                self.req_pool_indices, self.seq_lens - 2
-            ]
-            self.out_cache_loc = self.alloc_paged_token_slots_decode(
-                self.seq_lens, self.seq_lens_cpu, last_loc
-            )
-
-        self.req_to_token_pool.write(
-            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
-        )
+    def maybe_wait_verify_done(self):
+        if self.is_v2_eagle:
+            draft_input: EagleDraftInput = self.spec_info
+            if draft_input.verify_done is not None:
+                draft_input.verify_done.synchronize()

     def filter_batch(
         self,
         chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
         keep_indices: Optional[List[int]] = None,
     ):
+        # FIXME(lsyin): used here to get the correct seq_lens
+        # The batch has been launched but we need it verified to get correct next batch info
+        self.maybe_wait_verify_done()
+
         if keep_indices is None:
             if isinstance(chunked_req_to_exclude, Req):
                 chunked_req_to_exclude = [chunked_req_to_exclude]
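The `verify_done` handle synchronized in `maybe_wait_verify_done` is presumably a CUDA event recorded when the EAGLE verification step finishes, so that `filter_batch` only reads sequence lengths after they are final. The sketch below shows the usual record/synchronize pairing with `torch.cuda.Event`; the event type and where it is recorded are assumptions, not confirmed by this diff:

    import torch

    # Assumed pattern: the verify step records a CUDA event, and filter_batch
    # later blocks on it so host-side seq_lens are no longer "futures".
    if torch.cuda.is_available():
        verify_done = torch.cuda.Event()
        # ... verification kernels would be launched on the current stream here ...
        verify_done.record()

        # Later, before filtering the batch on the CPU side:
        verify_done.synchronize()  # host waits until verification has finished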
@@ -1771,6 +1643,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         )

     def merge_batch(self, other: "ScheduleBatch"):
+        # NOTE: in v2 eagle mode, we do not need wait verify here because
+        # 1) current batch is always prefill, whose seq_lens and allocate_lens are not a future
+        # 2) other batch is always decode, which is finished in previous step
+
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
@@ -1877,7 +1753,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
-            launch_done=self.launch_done,
             is_prefill_only=self.is_prefill_only,
         )

@@ -1885,6 +1760,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
+            req_to_token_pool=self.req_to_token_pool,
+            req_pool_indices=self.req_pool_indices,
             model_config=self.model_config,
             forward_mode=self.forward_mode,
             out_cache_loc=self.out_cache_loc,
@@ -1896,26 +1773,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             is_extend_in_batch=self.is_extend_in_batch,
             is_prefill_only=self.is_prefill_only,
+            seq_lens_cpu=self.seq_lens_cpu,
+            enable_overlap=self.enable_overlap,
         )

-    def _evict_tree_cache_if_needed(self, num_tokens: int):
-        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
-            return
-
-        if self.is_hybrid:
-            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
-            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
-
-            if full_available_size < num_tokens or swa_available_size < num_tokens:
-                if self.tree_cache is not None:
-                    full_num_tokens = max(0, num_tokens - full_available_size)
-                    swa_num_tokens = max(0, num_tokens - swa_available_size)
-                    self.tree_cache.evict(full_num_tokens, swa_num_tokens)
-        else:
-            if self.token_to_kv_pool_allocator.available_size() < num_tokens:
-                if self.tree_cache is not None:
-                    self.tree_cache.evict(num_tokens)
-
     def _is_available_size_sufficient(self, num_tokens: int) -> bool:
         if self.is_hybrid:
             return (
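With `_evict_tree_cache_if_needed` removed from `ScheduleBatch`, eviction now goes through the module-level `evict_from_tree_cache(self.tree_cache, num_tokens)` call seen in `release_req` above. That helper's body is not part of this diff; the sketch below only mirrors the intent of the deleted method, under the assumption that the free-standing version keeps a similar shape:

    # Hypothetical sketch; the real evict_from_tree_cache lives outside this
    # diff and may check allocator headroom and hybrid/SWA caches differently.
    def evict_from_tree_cache_sketch(tree_cache, num_tokens):
        if tree_cache is None:
            return
        # Ask the radix tree to release up to num_tokens of evictable KV slots
        # so the retracted request's space can be reused immediately.
        tree_cache.evict(num_tokens)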
@@ -1925,23 +1786,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             return self.token_to_kv_pool_allocator.available_size() >= num_tokens

-    def _available_and_evictable_str(self) -> str:
-        if self.is_hybrid:
-            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
-            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
-            full_evictable_size = self.tree_cache.full_evictable_size()
-            swa_evictable_size = self.tree_cache.swa_evictable_size()
-            return (
-                f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
-                f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
-                f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
-                f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
-            )
-        else:
-            available_size = self.token_to_kv_pool_allocator.available_size()
-            evictable_size = self.tree_cache.evictable_size()
-            return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
-
     def __str__(self):
         return (
             f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
@@ -2018,119 +1862,5 @@ class ModelWorkerBatch:
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1

-    # Overlap event
-    launch_done: Optional[threading.Event] = None
-
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
-
-
-@triton.jit
-def write_req_to_token_pool_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices,
-    pre_lens,
-    seq_lens,
-    extend_lens,
-    out_cache_loc,
-    req_to_token_ptr_stride: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(0)
-
-    req_pool_index = tl.load(req_pool_indices + pid)
-    pre_len = tl.load(pre_lens + pid)
-    seq_len = tl.load(seq_lens + pid)
-
-    # NOTE: This can be slow for large bs
-    cumsum_start = tl.cast(0, tl.int64)
-    for i in range(pid):
-        cumsum_start += tl.load(extend_lens + i)
-
-    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
-    for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = offset < (seq_len - pre_len)
-        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
-        tl.store(
-            req_to_token_ptr
-            + req_pool_index * req_to_token_ptr_stride
-            + offset
-            + pre_len,
-            value,
-            mask=mask,
-        )
-
-
-def get_last_loc(
-    req_to_token: torch.Tensor,
-    req_pool_indices_tensor: torch.Tensor,
-    prefix_lens_tensor: torch.Tensor,
-) -> torch.Tensor:
-    if (
-        global_server_args_dict["attention_backend"] != "ascend"
-        and global_server_args_dict["attention_backend"] != "torch_native"
-    ):
-        impl = get_last_loc_triton
-    else:
-        impl = get_last_loc_torch
-
-    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
-
-
-def get_last_loc_torch(
-    req_to_token: torch.Tensor,
-    req_pool_indices_tensor: torch.Tensor,
-    prefix_lens_tensor: torch.Tensor,
-) -> torch.Tensor:
-    return torch.where(
-        prefix_lens_tensor > 0,
-        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
-        torch.full_like(prefix_lens_tensor, -1),
-    )
-
-
-@triton.jit
-def get_last_loc_kernel(
-    req_to_token,
-    req_pool_indices_tensor,
-    prefix_lens_tensor,
-    result,
-    num_tokens,
-    req_to_token_stride,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
-    mask = offset < num_tokens
-
-    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
-    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
-
-    token_mask = prefix_lens > 0
-    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
-    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
-
-    tl.store(result + offset, tokens, mask=mask)
-
-
-def get_last_loc_triton(
-    req_to_token: torch.Tensor,
-    req_pool_indices_tensor: torch.Tensor,
-    prefix_lens_tensor: torch.Tensor,
-) -> torch.Tensor:
-    BLOCK_SIZE = 256
-    num_tokens = prefix_lens_tensor.shape[0]
-    result = torch.empty_like(prefix_lens_tensor)
-    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
-
-    get_last_loc_kernel[grid](
-        req_to_token,
-        req_pool_indices_tensor,
-        prefix_lens_tensor,
-        result,
-        num_tokens,
-        req_to_token.stride(0),
-        BLOCK_SIZE,
-    )
-    return result