sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode.py:

@@ -21,10 +21,11 @@ Life cycle of a request in the decode server
 from __future__ import annotations

 import logging
+import time
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union

 import torch
 from torch.distributed import ProcessGroup
@@ -45,7 +46,7 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
@@ -218,8 +219,10 @@ class DecodePreallocQueue:

         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
-        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
-        kv_manager = kv_manager_class(
+        kv_manager_class: Type[BaseKVManager] = get_kv_class(
+            self.transfer_backend, KVClassType.MANAGER
+        )
+        kv_manager: BaseKVManager = kv_manager_class(
             kv_args,
             DisaggregationMode.DECODE,
             self.scheduler.server_args,
@@ -248,9 +251,10 @@ class DecodePreallocQueue:
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
-            data_parallel_rank=req.data_parallel_rank,
+            prefill_dp_rank=req.data_parallel_rank,
         )

+        req.add_latency(RequestStage.DECODE_PREPARE)
         self.queue.append(
             DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
         )
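The `req.add_latency(RequestStage.DECODE_PREPARE)` call above, and its siblings in the hunks that follow, record how long a request spent in each decode stage. A minimal standalone sketch of the pattern, assuming a simple enum and a per-request timestamp; the class and attribute names here are illustrative, not sglang's exact implementation:

import time
from enum import Enum


class RequestStage(Enum):
    # Stage names mirror the ones used in the diff above.
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_TRANSFERRED = "decode_transferred"
    DECODE_WAITING = "decode_waiting"


class Req:
    def __init__(self) -> None:
        # Timestamp of the previous stage boundary.
        self._last_tick = time.perf_counter()
        self.stage_latencies: dict[RequestStage, float] = {}

    def add_latency(self, stage: RequestStage) -> None:
        # Attribute the time elapsed since the last boundary to `stage`.
        now = time.perf_counter()
        self.stage_latencies[stage] = now - self._last_tick
        self._last_tick = now


req = Req()
# ... bootstrap the KV receiver, enqueue the request, etc. ...
req.add_latency(RequestStage.DECODE_PREPARE)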
@@ -419,8 +423,13 @@ class DecodePreallocQueue:
                 kv_indices, self.token_to_kv_pool_allocator.page_size
             )
             decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
+
             preallocated_reqs.append(decode_req)
             indices_to_remove.add(i)
+            decode_req.req.time_stats.decode_transfer_queue_entry_time = (
+                time.perf_counter()
+            )
+            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)

         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -514,11 +523,19 @@ class DecodePreallocQueue:
                 dtype=torch.int64,
                 device=self.token_to_kv_pool_allocator.device,
             ),
+            prefix_lens_cpu=torch.tensor(
+                [0],
+                dtype=torch.int64,
+            ),
             seq_lens=torch.tensor(
                 [num_tokens],
                 dtype=torch.int64,
                 device=self.token_to_kv_pool_allocator.device,
             ),
+            seq_lens_cpu=torch.tensor(
+                [num_tokens],
+                dtype=torch.int64,
+            ),
             last_loc=torch.tensor(
                 [-1],
                 dtype=torch.int64,
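The new `prefix_lens_cpu` and `seq_lens_cpu` fields keep host-side copies of small control tensors next to their device counterparts. A short sketch of why that helps, under the assumption that the CPU twin is read by scheduling code that would otherwise trigger a blocking device-to-host copy:

import torch

num_tokens = 128
device = "cuda" if torch.cuda.is_available() else "cpu"

# The device tensor feeds the attention kernels ...
seq_lens = torch.tensor([num_tokens], dtype=torch.int64, device=device)
# ... while the CPU twin serves host-side bookkeeping.
seq_lens_cpu = torch.tensor([num_tokens], dtype=torch.int64)

# Reading lengths from the CPU copy avoids the device synchronization
# that `seq_lens.sum().item()` would incur when `seq_lens` lives on GPU.
total = int(seq_lens_cpu.sum())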
@@ -605,16 +622,23 @@ class DecodeTransferQueue:
                 idx = decode_req.metadata_buffer_index
                 (
                     output_id,
+                    cached_tokens,
                     output_token_logprobs_val,
                     output_token_logprobs_idx,
                     output_top_logprobs_val,
                     output_top_logprobs_idx,
+                    output_topk_p,
+                    output_topk_index,
                     output_hidden_states,
                 ) = self.metadata_buffers.get_buf(idx)

                 decode_req.req.output_ids.append(output_id[0].item())
+                decode_req.req.cached_tokens = cached_tokens[0].item()
                 if not self.spec_algorithm.is_none():
+                    decode_req.req.output_topk_p = output_topk_p
+                    decode_req.req.output_topk_index = output_topk_index
                     decode_req.req.hidden_states_tensor = output_hidden_states
+
                 if decode_req.req.return_logprob:
                     decode_req.req.output_token_logprobs_val.append(
                         output_token_logprobs_val[0].item()
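Here the transfer queue unpacks every per-request output (token id, cached-token count, logprobs, top-k draft data, hidden states) from one preallocated metadata-buffer slot. A toy sketch of the slot-indexed buffer idea; the field set and shapes below are simplified assumptions, not the real `MetadataBuffers` layout:

import torch


class MetadataBuffers:
    """Sketch: fixed slots of preallocated tensors, fetched by slot index."""

    def __init__(self, num_slots: int, topk: int, hidden: int) -> None:
        self.output_id = torch.zeros(num_slots, 1, dtype=torch.int64)
        self.cached_tokens = torch.zeros(num_slots, 1, dtype=torch.int64)
        self.topk_p = torch.zeros(num_slots, topk, dtype=torch.float32)
        self.topk_index = torch.zeros(num_slots, topk, dtype=torch.int64)
        self.hidden_states = torch.zeros(num_slots, hidden, dtype=torch.float32)

    def get_buf(self, idx: int):
        # Return views into slot `idx`; the real sglang buffer carries more fields.
        return (
            self.output_id[idx],
            self.cached_tokens[idx],
            self.topk_p[idx],
            self.topk_index[idx],
            self.hidden_states[idx],
        )


bufs = MetadataBuffers(num_slots=8, topk=4, hidden=16)
output_id, cached_tokens, topk_p, topk_index, hidden = bufs.get_buf(3)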
@@ -635,10 +659,17 @@ class DecodeTransferQueue:

                 if hasattr(decode_req.kv_receiver, "clear"):
                     decode_req.kv_receiver.clear()
+                decode_req.kv_receiver = None
+
+                indices_to_remove.add(i)
+                decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()

                 # special handling for sampling_params.max_new_tokens == 1
                 if decode_req.req.sampling_params.max_new_tokens == 1:
                     # finish immediately
+                    decode_req.req.time_stats.forward_entry_time = (
+                        decode_req.req.time_stats.completion_time
+                    ) = time.perf_counter()
                     decode_req.req.check_finished()
                     self.scheduler.stream_output(
                         [decode_req.req], decode_req.req.return_logprob
@@ -646,8 +677,6 @@ class DecodeTransferQueue:
                     self.tree_cache.cache_finished_req(decode_req.req)
                 else:
                     transferred_reqs.append(decode_req.req)
-
-                indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
                 KVPoll.WaitingForInput,
@@ -660,6 +689,7 @@ class DecodeTransferQueue:
         for i in indices_to_remove:
             idx = self.queue[i].metadata_buffer_index
             assert idx != -1
+            self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
             self.req_to_metadata_buffer_idx_allocator.free(idx)

         self.queue = [
@@ -702,12 +732,15 @@ class SchedulerDisaggregationDecodeMixin:
         elif prepare_mlp_sync_flag:
             batch, _ = self._prepare_idle_batch_and_run(None)

-        if batch is None and (
+        queue_size = (
             len(self.waiting_queue)
             + len(self.disagg_decode_transfer_queue.queue)
             + len(self.disagg_decode_prealloc_queue.queue)
-            == 0
-        ):
+        )
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            queue_size += len(self.decode_offload_manager.ongoing_offload)
+
+        if batch is None and queue_size == 0:
             self.self_check_during_idle()

         self.last_batch = batch
@@ -776,12 +809,15 @@ class SchedulerDisaggregationDecodeMixin:
             )
             self.process_batch_result(tmp_batch, tmp_result)

-        if batch is None and (
+        queue_size = (
             len(self.waiting_queue)
             + len(self.disagg_decode_transfer_queue.queue)
             + len(self.disagg_decode_prealloc_queue.queue)
-            == 0
-        ):
+        )
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            queue_size += len(self.decode_offload_manager.ongoing_offload)
+
+        if batch is None and queue_size == 0:
             self.self_check_during_idle()

         self.last_batch = batch
@@ -851,6 +887,7 @@ class SchedulerDisaggregationDecodeMixin:
             # we can only add at least `num_not_used_batch` new batch to the running queue
             if i < num_not_used_batch:
                 can_run_list.append(req)
+                req.add_latency(RequestStage.DECODE_WAITING)
                 req.init_next_round_input(self.tree_cache)
             else:
                 waiting_queue.append(req)
@@ -859,6 +896,9 @@ class SchedulerDisaggregationDecodeMixin:
         if len(can_run_list) == 0:
             return None

+        for req in can_run_list:
+            req.time_stats.forward_entry_time = time.perf_counter()
+
         # construct a schedule batch with those requests and mark as decode
         new_batch = ScheduleBatch.init_new(
             can_run_list,
@@ -884,9 +924,21 @@ class SchedulerDisaggregationDecodeMixin:
             # if there are still retracted requests, we do not allocate new requests
             return

-        req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
-        self.disagg_decode_transfer_queue.extend(req_conns)
-        alloc_reqs = (
-            self.disagg_decode_transfer_queue.pop_transferred()
-        )  # the requests which kv has arrived
-        self.waiting_queue.extend(alloc_reqs)
+        if not hasattr(self, "polling_count"):
+            self.polling_count = 0
+            self.polling_interval = (
+                self.server_args.disaggregation_decode_polling_interval
+            )
+
+        self.polling_count = (self.polling_count + 1) % self.polling_interval
+
+        if self.polling_count % self.polling_interval == 0:
+            req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
+            self.disagg_decode_transfer_queue.extend(req_conns)
+            alloc_reqs = (
+                self.disagg_decode_transfer_queue.pop_transferred()
+            )  # the requests which kv has arrived
+            self.waiting_queue.extend(alloc_reqs)
+
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            self.decode_offload_manager.check_offload_progress()
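The `polling_count` logic above drains the prealloc and transfer queues only on every `disaggregation_decode_polling_interval`-th event-loop iteration, trading a little queueing latency for lower per-tick overhead. A standalone sketch of the counter-modulo pattern, with `drain_queues` standing in for the `pop_preallocated`/`pop_transferred` calls:

class Poller:
    """Run an expensive check only on every `interval`-th call."""

    def __init__(self, interval: int) -> None:
        self.polling_count = 0
        self.polling_interval = interval

    def maybe_poll(self, drain_queues) -> None:
        # Mirrors the diff: increment modulo interval, fire when it wraps to 0.
        self.polling_count = (self.polling_count + 1) % self.polling_interval
        if self.polling_count % self.polling_interval == 0:
            drain_queues()


poller = Poller(interval=4)
for step in range(8):
    poller.maybe_poll(lambda s=step: print(f"drained at step {s}"))
# Fires on the 4th and 8th calls, i.e. at steps 3 and 7.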
sglang/srt/disaggregation/decode_kvcache_offload_manager.py (new file):

@@ -0,0 +1,185 @@
+import logging
+import threading
+import time
+
+import torch
+
+from sglang import ServerArgs
+from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
+)
+from sglang.srt.mem_cache.memory_pool_host import (
+    MHATokenToKVPoolHost,
+    MLATokenToKVPoolHost,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class DecodeKVCacheOffloadManager:
+    """Manage decode-side KV cache offloading lifecycle and operations."""
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
+        tp_group: torch.distributed.ProcessGroup,
+        tree_cache: BasePrefixCache,
+        server_args: ServerArgs,
+    ) -> None:
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.page_size = server_args.page_size
+        self.server_args = server_args
+        self.request_counter = 0
+        self.tree_cache = tree_cache
+        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
+        if isinstance(kv_cache, MHATokenToKVPool):
+            self.decode_host_mem_pool = MHATokenToKVPoolHost(
+                kv_cache,
+                server_args.hicache_ratio,
+                server_args.hicache_size,
+                self.page_size,
+                server_args.hicache_mem_layout,
+            )
+        elif isinstance(kv_cache, MLATokenToKVPool):
+            self.decode_host_mem_pool = MLATokenToKVPoolHost(
+                kv_cache,
+                server_args.hicache_ratio,
+                server_args.hicache_size,
+                self.page_size,
+                server_args.hicache_mem_layout,
+            )
+        else:
+            raise ValueError("Unsupported KV cache type for decode offload")
+
+        self.tp_group = tp_group
+        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
+        self.cache_controller = HiCacheController(
+            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            mem_pool_host=self.decode_host_mem_pool,
+            page_size=self.page_size,
+            tp_group=tp_group,
+            io_backend=server_args.hicache_io_backend,
+            load_cache_event=threading.Event(),
+            storage_backend=server_args.hicache_storage_backend,
+            model_name=server_args.served_model_name,
+            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
+        )
+
+        self.ongoing_offload = {}
+        self.ongoing_backup = {}
+        logger.info("Enable offload kv cache for decode side")
+
+    def offload_kv_cache(self, req) -> bool:
+        """Offload a finished request's KV cache to storage."""
+
+        if self.cache_controller is None or self.decode_host_mem_pool is None:
+            return False
+
+        if req.req_pool_idx == -1:
+            return False
+
+        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
+        if token_indices.dim() == 0 or token_indices.numel() == 0:
+            logger.debug(
+                f"Request {req.rid} has invalid token_indices: {token_indices}"
+            )
+            return False
+
+        tokens = req.origin_input_ids + req.output_ids
+        aligned_len = (len(tokens) // self.page_size) * self.page_size
+        if aligned_len == 0:
+            return False
+
+        token_indices = token_indices[:aligned_len]
+        tokens = tokens[:aligned_len]
+
+        # Asynchronously offload KV cache from device to host by cache controller
+        self.request_counter += 1
+        ack_id = self.request_counter
+        host_indices = self.cache_controller.write(
+            device_indices=token_indices.long(),
+            node_id=ack_id,
+        )
+        if host_indices is None:
+            logger.error(f"Not enough host memory for request {req.rid}")
+            return False
+
+        self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
+        return True
+
+    def check_offload_progress(self):
+        """Check the progress of offload from device to host and backup from host to storage."""
+        cc = self.cache_controller
+
+        qsizes = torch.tensor(
+            [
+                len(cc.ack_write_queue),
+                cc.ack_backup_queue.qsize(),
+            ],
+            dtype=torch.int,
+        )
+        if self.tp_world_size > 1:
+            torch.distributed.all_reduce(
+                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
+            )
+
+        n_write, n_backup = map(int, qsizes.tolist())
+        self._check_offload_progress(n_write)
+        self._check_backup_progress(n_backup)
+
+    def _check_offload_progress(self, finish_count):
+        """Check the progress of offload from device to host."""
+        while finish_count > 0:
+            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
+            finish_event.synchronize()
+            for ack_id in ack_list:
+                req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)
+
+                # Release device
+                self.tree_cache.cache_finished_req(req)
+
+                # Trigger async backup from host to storage by cache controller
+                self._trigger_backup(req.rid, host_indices, tokens, start_time)
+            finish_count -= 1
+
+    def _check_backup_progress(self, finish_count):
+        """Check the progress of backup from host to storage."""
+        for _ in range(finish_count):
+            storage_operation = self.cache_controller.ack_backup_queue.get()
+            ack_id = storage_operation.id
+            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)
+
+            # Release host memory
+            self.decode_host_mem_pool.free(host_indices)
+
+            logger.debug(
+                f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
+            )
+
+    def _trigger_backup(self, req_id, host_indices, tokens, start_time):
+        """Trigger async backup from host to storage by cache controller."""
+
+        # Generate page hashes and write to storage
+        page_hashes = self._compute_prefix_hash(tokens)
+        ack_id = self.cache_controller.write_storage(
+            host_indices,
+            tokens,
+            hash_value=page_hashes,
+        )
+        self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)
+
+    def _compute_prefix_hash(self, tokens):
+        last_hash = ""
+        page_hashes = []
+        for offset in range(0, len(tokens), self.page_size):
+            page_tokens = tokens[offset : offset + self.page_size]
+            last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
+            page_hashes.append(last_hash)
+        return page_hashes
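For orientation, a hedged sketch of how a decode scheduler might drive this manager: offload on request completion, then advance progress once per event-loop tick. The function names and any scheduler attributes not referenced in the hunks above are assumptions for illustration, not the exact sglang wiring:

# Hypothetical glue code; names outside the new file above are illustrative.
def on_request_finished(scheduler, req) -> None:
    if scheduler.server_args.disaggregation_decode_enable_offload_kvcache:
        # Kick off async device->host offload; the device KV pages are only
        # released once the copy completes (see _check_offload_progress).
        if scheduler.decode_offload_manager.offload_kv_cache(req):
            return
    # Offload disabled or rejected: release the device KV cache immediately.
    scheduler.tree_cache.cache_finished_req(req)


def on_event_loop_tick(scheduler) -> None:
    if scheduler.server_args.disaggregation_decode_enable_offload_kvcache:
        # Advance both state machines: device->host writes and host->storage
        # backups; host memory is freed as backups complete.
        scheduler.decode_offload_manager.check_offload_progress()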
sglang/srt/disaggregation/decode_schedule_batch_mixin.py:

@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req_pool_indices, dtype=torch.int64, device=self.device
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
         self.orig_seq_lens = torch.tensor(
             seq_lens, dtype=torch.int32, device=self.device
         )
@@ -110,7 +111,10 @@ class ScheduleBatchDisaggregationDecodeMixin:
             if req.grammar is not None:
                 # FIXME: this try-except block is for handling unexpected xgrammar issue.
                 try:
-                    req.grammar.accept_token(req.output_ids[-1])
+                    # if it is not None, then the grammar is from a retracted request, and we should not
+                    # accept the token as it's already accepted
+                    if req.grammar.current_token is None:
+                        req.grammar.accept_token(req.output_ids[-1])
                 except ValueError as e:
                     # Grammar accept_token can raise ValueError if the token is not in the grammar.
                     # This can happen if the grammar is not set correctly or the token is invalid.
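The guard above skips re-accepting a token when the grammar object was carried over from a retracted request that already consumed it. A standalone sketch of the idempotence guard with a toy grammar class; the names here are illustrative, not xgrammar's API:

class ToyGrammar:
    """Toy stand-in for a grammar matcher with a one-shot accept guard."""

    def __init__(self) -> None:
        self.current_token = None  # set once a token has been accepted

    def accept_token(self, token: int) -> None:
        if token < 0:
            raise ValueError(f"token {token} not in grammar")
        self.current_token = token


def maybe_accept(grammar: ToyGrammar, token: int) -> None:
    # Only accept if this grammar hasn't already consumed a token
    # (e.g. it came from a retracted request).
    if grammar.current_token is None:
        grammar.accept_token(token)


g = ToyGrammar()
maybe_accept(g, 7)  # accepted
maybe_accept(g, 9)  # skipped: a token was already accepted once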
@@ -122,31 +126,39 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req.grammar.finished = req.finished()
         self.output_ids = torch.tensor(self.output_ids, device=self.device)

-        # Simulate the eagle run.
-
-        # of 0.
-        if not self.spec_algorithm.is_none():
+        # Simulate the eagle run.
+        if self.spec_algorithm.is_eagle():

             b = len(self.reqs)
-
-
-
-
-
-
+            topk = server_args.speculative_eagle_topk
+            topk_p = torch.stack(
+                [
+                    torch.as_tensor(
+                        req.output_topk_p[:topk],
+                        device=self.device,
+                        dtype=torch.float32,
+                    )
+                    for req in self.reqs
+                ],
+                dim=0,
             )
-
-
-
-
+            topk_index = torch.stack(
+                [
+                    torch.as_tensor(
+                        req.output_topk_index[:topk],
+                        device=self.device,
+                        dtype=torch.int64,
+                    )
+                    for req in self.reqs
+                ],
+                dim=0,
             )
-            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)

             hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
             hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)

             # local import to avoid circular import
-            from sglang.srt.speculative.eagle_utils import EagleDraftInput
+            from sglang.srt.speculative.eagle_info import EagleDraftInput

             spec_info = EagleDraftInput(
                 topk_p=topk_p,
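The rewritten eagle path builds real `(batch, topk)` tensors from the per-request top-k values transferred from prefill, instead of the earlier mock data. A standalone sketch of that batching step with made-up values:

import torch

topk = 2
# Per-request top-k data as plain lists (may be longer than `topk`).
per_req_topk_p = [[0.9, 0.1], [0.7, 0.2, 0.1]]
per_req_topk_index = [[5, 17], [42, 8, 3]]

topk_p = torch.stack(
    [torch.as_tensor(p[:topk], dtype=torch.float32) for p in per_req_topk_p],
    dim=0,
)  # shape: (batch, topk) = (2, 2)
topk_index = torch.stack(
    [torch.as_tensor(i[:topk], dtype=torch.int64) for i in per_req_topk_index],
    dim=0,
)
print(topk_p.shape, topk_index.shape)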