sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +192 -113
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +132 -57
- sglang/srt/entrypoints/openai/protocol.py +115 -7
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +207 -58
- sglang/srt/entrypoints/openai/serving_completions.py +17 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +49 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +106 -82
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +53 -7
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +225 -57
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +78 -49
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +215 -314
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +147 -19
- sglang/srt/managers/scheduler.py +501 -304
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +321 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +15 -21
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +58 -34
- sglang/srt/mem_cache/hiradix_cache.py +227 -80
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -223
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +519 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +55 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +98 -57
- sglang/srt/model_executor/model_runner.py +433 -158
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +833 -152
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +14 -5
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +124 -14
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +26 -5
- sglang/srt/models/qwen3_moe.py +71 -12
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +10 -3
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1030 -254
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +253 -136
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +445 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +22 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0

sglang/srt/speculative/eagle_worker.py

@@ -23,6 +23,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
+    ForwardBatchOutput,
     ForwardMode,
 )
 from sglang.srt.server_args import ServerArgs
@@ -33,20 +34,23 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
 from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
     EAGLEDraftExtendCudaGraphRunner,
 )
-from sglang.srt.speculative.eagle_utils import (
+from sglang.srt.speculative.eagle_info import (
     EagleDraftInput,
     EagleVerifyInput,
     EagleVerifyOutput,
+)
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.speculative.spec_utils import (
     assign_draft_cache_locs,
     fast_topk,
     generate_token_bitmask,
     select_top_k_tokens,
 )
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
     get_bool_env_var,
+    is_blackwell,
     is_cuda,
     next_power_of_2,
 )
@@ -187,137 +191,204 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None
 
-
-
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                    FlashInferMultiStepDraftBackend,
-                )
+        # Initialize decode attention backend
+        self.draft_attn_backend = self._create_decode_backend()
 
-
-
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                    FlashInferMLAMultiStepDraftBackend,
-                )
+        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
+        self.draft_extend_attn_backend = self._create_draft_extend_backend()
 
-
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "triton":
-            from sglang.srt.layers.attention.triton_backend import (
-                TritonAttnBackend,
-                TritonMultiStepDraftBackend,
-            )
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
 
-
-
-
-
-
-
-
-
-            )
-
-
-
-
-
+    def _create_backend(
+        self, backend_name: str, backend_map: dict, error_template: str
+    ):
+        backend_type = getattr(self.server_args, backend_name)
+        if backend_type is None:
+            backend_type = self.server_args.attention_backend
+
+        if backend_type not in backend_map:
+            raise ValueError(error_template.format(backend_type=backend_type))
+
+        return backend_map[backend_type]()
+
+    def _create_decode_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_decode_backend,
+            "triton": self._create_triton_decode_backend,
+            "aiter": self._create_aiter_decode_backend,
+            "fa3": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_decode_backend
+                if not is_blackwell()
+                else self._create_triton_decode_backend
+            ),
+            "flashmla": self._create_flashmla_decode_backend,
+            "trtllm_mha": self._create_trtllm_mha_decode_backend,
+            "trtllm_mla": self._create_trtllm_mla_decode_backend,
+        }
+
+        return self._create_backend(
+            "decode_attention_backend",
+            backend_map,
+            "EAGLE is not supported in decode attention backend {backend_type}",
+        )
 
-
-
-
-
-
-            self.
-
-
-
-
-
-
-
-
+    def _create_draft_extend_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_prefill_backend,
+            "triton": self._create_triton_prefill_backend,
+            "aiter": self._create_aiter_prefill_backend,
+            "fa3": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_prefill_backend
+                if not is_blackwell()
+                else self._create_triton_prefill_backend
+            ),
+            "flashmla": self._create_flashmla_prefill_backend,
+            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
+            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+        }
+        backend_name = (
+            "decode_attention_backend"
+            if self.server_args.speculative_attention_mode == "decode"
+            else "prefill_attention_backend"
+        )
+        return self._create_backend(
+            backend_name,
+            backend_map,
+            "EAGLE is not supported in attention backend {backend_type}",
+        )
 
-
-
-
-
-            )
-            self.draft_extend_attn_backend = FlashAttentionBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import (
-                FlashMLAMultiStepDraftBackend,
+    def _create_flashinfer_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
             )
 
-            self.
-
-                self.topk,
-                self.speculative_num_steps,
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
             )
-
-            from sglang.srt.layers.attention.
-
-                TRTLLMHAAttnMultiStepDraftBackend,
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
             )
 
-            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
             self.has_prefill_wrapper_verify = True
-
-
-                raise ValueError(
-                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-                )
-
-            from sglang.srt.layers.attention.trtllm_mla_backend import (
-                TRTLLMMLABackend,
-                TRTLLMMLAMultiStepDraftBackend,
+            return FlashInferMLAMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
             )
 
-
-
-
-
+    def _create_triton_decode_backend(self):
+        from sglang.srt.layers.attention.triton_backend import (
+            TritonMultiStepDraftBackend,
+        )
+
+        return TritonMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_aiter_decode_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
+
+        return AiterMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_fa3_decode_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionMultiStepBackend,
+        )
+
+        return FlashAttentionMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashmla_decode_backend(self):
+        from sglang.srt.layers.attention.flashmla_backend import (
+            FlashMLAMultiStepDraftBackend,
+        )
+
+        return FlashMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mha_decode_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import (
+            TRTLLMHAAttnMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMHAAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mla_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
             )
-
-
-
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import (
+            TRTLLMMLAMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashinfer_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
             )
-
+
+            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
         else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
+            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_triton_prefill_backend(self):
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_aiter_prefill_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_fa3_prefill_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mha_prefill_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mla_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
             raise ValueError(
-
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
             )
 
-
+        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_flashmla_prefill_backend(self):
+        logger.warning(
+            "flashmla prefill backend is not yet supported for draft extend."
+        )
+        return None
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
@@ -358,9 +429,7 @@ class EAGLEWorker(TpModelWorker):
     def draft_model_runner(self):
        return self.model_runner
 
-    def forward_batch_generation(
-        self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
+    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
         """Run speculative decoding forward.
 
         NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -373,14 +442,19 @@ class EAGLEWorker(TpModelWorker):
             the batch id (used for overlap schedule), and number of accepted tokens.
         """
         if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
-            logits_output, next_token_ids,
-
+            logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
+                batch
             )
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                 )
-            return
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                num_accepted_tokens=0,
+                can_run_cuda_graph=False,
+            )
         else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
@@ -398,12 +472,11 @@ class EAGLEWorker(TpModelWorker):
             # decode is not finished
             self.forward_draft_extend_after_decode(batch)
 
-            return (
-                logits_output,
-                verify_output.verified_id,
-
-
-                can_run_cuda_graph,
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=verify_output.verified_id,
+                num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph=can_run_cuda_graph,
             )
 
     def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
@@ -435,19 +508,21 @@ class EAGLEWorker(TpModelWorker):
         Returns:
             logits_output: The output of logits. It will contain the full hidden states.
             next_token_ids: Next token ids generated.
-            bid: The model batch ID. Used for overlap schedule.
         """
         # Forward with the target model and get hidden states.
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-
+        forward_batch_output = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
+        logits_output, next_token_ids = (
+            forward_batch_output.logits_output,
+            forward_batch_output.next_token_ids,
+        )
         return (
             logits_output,
             next_token_ids,
-            model_worker_batch.bid,
             model_worker_batch.seq_lens_cpu,
         )
 
@@ -479,6 +554,8 @@ class EAGLEWorker(TpModelWorker):
                 batch.seq_lens,
                 self.speculative_num_steps,
             )
+            prefix_lens_cpu = batch.seq_lens_cpu
+            seq_lens_cpu = batch.seq_lens_cpu + self.speculative_num_steps
             extend_num_tokens = num_seqs * self.speculative_num_steps
         else:
             # In this case, the last partial page needs to be duplicated.
@@ -514,14 +591,23 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.page_size,
             )
-
-
-
+            prefix_lens_cpu = batch.seq_lens_cpu
+            last_page_lens = prefix_lens_cpu % self.page_size
+            num_new_pages_per_topk = (
+                last_page_lens + self.speculative_num_steps + self.page_size - 1
+            ) // self.page_size
+            seq_lens_cpu = (
+                prefix_lens_cpu // self.page_size * self.page_size
+                + num_new_pages_per_topk * (self.page_size * self.topk)
+            )
+            extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
 
         out_cache_loc, token_to_kv_pool_state_backup = (
             batch.alloc_paged_token_slots_extend(
                 prefix_lens,
+                prefix_lens_cpu,
                 seq_lens,
+                seq_lens_cpu,
                 last_loc,
                 extend_num_tokens,
                 backup_state=True,
@@ -683,6 +769,14 @@ class EAGLEWorker(TpModelWorker):
 
             # Set inputs
             forward_batch.input_ids = input_ids
+            # This is a temporary fix for the case that the user is using standalone
+            # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
+            # rope kernel needs cache_loc to be contiguous.
+            if (
+                self.server_args.speculative_algorithm == "STANDALONE"
+                and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
+            ):
+                out_cache_loc = out_cache_loc.contiguous()
             forward_batch.out_cache_loc = out_cache_loc[i]
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
@@ -701,6 +795,10 @@ class EAGLEWorker(TpModelWorker):
 
         return score_list, token_list, parents_list
 
+    def clear_cache_pool(self):
+        self.model_runner.req_to_token_pool.clear()
+        self.model_runner.token_to_kv_pool_allocator.clear()
+
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
         batch.return_hidden_states = False
@@ -724,10 +822,12 @@ class EAGLEWorker(TpModelWorker):
         ).cpu()
 
         # Forward
-
-
-
-
+        forward_batch_output = self.target_worker.forward_batch_generation(
+            model_worker_batch, is_verify=True
+        )
+        logits_output, can_run_cuda_graph = (
+            forward_batch_output.logits_output,
+            forward_batch_output.can_run_cuda_graph,
         )
 
         vocab_mask = None
@@ -767,6 +867,21 @@ class EAGLEWorker(TpModelWorker):
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
+        # QQ: can be optimized
+        if self.target_worker.model_runner.is_hybrid_gdn:
+            # res.draft_input.accept_length is on GPU but may be empty for last verify?
+            accepted_length = (
+                torch.tensor(
+                    res.accept_length_per_req_cpu,
+                    device=logits_output.hidden_states.device,
+                    dtype=torch.int32,
+                )
+                + 1
+            )
+            self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
+                accepted_length, self.target_worker.model_runner.model
+            )
+
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)
 
@@ -912,6 +1027,7 @@ class EAGLEWorker(TpModelWorker):
         assert isinstance(batch.spec_info, EagleDraftInput)
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
+        seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
@@ -990,6 +1106,7 @@ class EAGLEWorker(TpModelWorker):
             ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
         )
         batch.seq_lens = seq_lens_backup
+        batch.seq_lens_cpu = seq_lens_cpu_backup
         batch.req_pool_indices = req_pool_indices_backup
         batch.spec_info.accept_length = accept_length_backup
         batch.return_logprob = return_logprob_backup