sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -48,18 +48,22 @@ from sglang.srt.model_executor.forward_batch_info import (
     PPProxyTensors,
     enable_num_token_non_padded,
 )
-from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
 from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
+    get_bool_env_var,
     get_device_memory_capacity,
+    is_hip,
     log_info_on_rank0,
     require_attn_tp_gather,
     require_gathered_buffer,
     require_mlp_sync,
     require_mlp_tp_gather,
 )
+from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
+
+_is_hip = is_hip()
 
 logger = logging.getLogger(__name__)
 
@@ -100,6 +104,7 @@ def freeze_gc(enable_cudagraph_gc: bool):
     finally:
         if should_freeze:
             gc.unfreeze()
+            gc.collect()
 
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
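The added `gc.collect()` reclaims, once capture finishes, any garbage that accumulated while collection was effectively paused. For context, the freeze/unfreeze pattern that `freeze_gc` uses looks roughly like the sketch below; the `should_freeze = not enable_cudagraph_gc` line is an assumption about the flag's meaning, not sglang's exact code.

```python
import contextlib
import gc


@contextlib.contextmanager
def freeze_gc_sketch(enable_cudagraph_gc: bool):
    # Assumption: freezing is skipped when GC during capture is explicitly enabled.
    should_freeze = not enable_cudagraph_gc
    if should_freeze:
        # Move every object alive right now into the "permanent" generation,
        # so the collector skips them during CUDA graph capture.
        gc.freeze()
    try:
        yield
    finally:
        if should_freeze:
            gc.unfreeze()  # frozen objects become collectable again
            gc.collect()   # the newly added step: reclaim garbage deferred while frozen
```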
@@ -136,7 +141,7 @@ def patch_model(
                 mode=os.environ.get(
                     "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
                 ),
-                dynamic=False,
+                dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
             )
     else:
         yield model.forward
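With this change, dynamic-shape `torch.compile` becomes opt-in and only takes effect on ROCm. A self-contained sketch of the same gating; `compile_forward`, `is_hip_platform`, and this simplified `get_bool_env_var` are illustrative stand-ins, not sglang's exact helpers:

```python
import os

import torch


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Simplified stand-in for sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


def compile_forward(forward_fn, is_hip_platform: bool):
    # Dynamic shapes are enabled only on HIP/ROCm and only when the user
    # opts in via SGLANG_TORCH_DYNAMIC_SHAPE; everywhere else the graph
    # is compiled with static shapes, as before.
    return torch.compile(
        forward_fn,
        mode=os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"),
        dynamic=is_hip_platform and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
    )
```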
@@ -166,29 +171,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     server_args = model_runner.server_args
     capture_bs = server_args.cuda_graph_bs
 
-    if capture_bs is None:
-        if server_args.speculative_algorithm is None:
-            if server_args.disable_cuda_graph_padding:
-                capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
-            else:
-                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
-        else:
-            # Since speculative decoding requires more cuda graph memory, we
-            # capture less.
-            capture_bs = (
-                list(range(1, 9))
-                + list(range(10, 33, 2))
-                + list(range(40, 64, 8))
-                + list(range(80, 161, 16))
-            )
-
-        gpu_mem = get_device_memory_capacity()
-        if gpu_mem is not None:
-            if gpu_mem > 90 * 1024:  # H200, H20
-                capture_bs += list(range(160, 257, 8))
-            if gpu_mem > 160 * 1000:  # B200, MI300
-                capture_bs += list(range(256, 513, 16))
-
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.

@@ -204,12 +186,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
 
     capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
 
-    if server_args.cuda_graph_max_bs:
-        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
-        if max(capture_bs) < server_args.cuda_graph_max_bs:
-            capture_bs += list(
-                range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
-            )
     capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
     capture_bs = list(sorted(set(capture_bs)))
     assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
@@ -271,7 +247,11 @@ class CudaGraphRunner:
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
-        if model_runner.spec_algorithm.is_eagle():
+        if (
+            model_runner.spec_algorithm.is_eagle()
+            or model_runner.spec_algorithm.is_standalone()
+            or model_runner.spec_algorithm.is_ngram()
+        ):
             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen")
             else:
@@ -317,7 +297,9 @@ class CudaGraphRunner:
             (self.max_num_token,), dtype=self._cache_loc_dtype()
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-        self.mrope_positions = torch.zeros((3, self.max_num_token), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
         self.tbo_plugin = TboCudaGraphRunnerPlugin()
@@ -435,11 +417,21 @@ class CudaGraphRunner:
             forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
         )
 
+        is_ngram_supported = (
+            (
+                forward_batch.batch_size * self.num_tokens_per_bs
+                == forward_batch.input_ids.numel()
+            )
+            if self.model_runner.spec_algorithm.is_ngram()
+            else True
+        )
+
         return (
             is_bs_supported
             and is_encoder_lens_supported
             and is_tbo_supported
             and capture_hidden_mode_matches
+            and is_ngram_supported
         )
 
     def capture(self) -> None:
@@ -449,6 +441,7 @@ class CudaGraphRunner:
                 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 record_shapes=True,
             )
+            torch.cuda.memory._record_memory_history()
 
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
@@ -497,6 +490,8 @@ class CudaGraphRunner:
             save_gemlite_cache()
 
         if self.enable_profile_cuda_graph:
+            torch.cuda.memory._dump_snapshot(f"cuda_graph_runner_memory_usage.pickle")
+            torch.cuda.memory._record_memory_history(enabled=None)
             log_message = (
                 "Sorted by CUDA Time:\n"
                 + prof.key_averages(group_by_input_shape=True).table(
@@ -506,6 +501,7 @@ class CudaGraphRunner:
                 + prof.key_averages(group_by_input_shape=True).table(
                     sort_by="cpu_time_total", row_limit=10
                 )
+                + "\n\nMemory Usage is saved to cuda_graph_runner_memory_usage.pickle\n"
             )
             logger.info(log_message)
 
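The three `torch.cuda.memory` calls added in the last three hunks follow PyTorch's standard allocator-snapshot workflow: start recording, dump a snapshot, stop recording. A minimal standalone sketch of the same pattern; `capture_with_memory_snapshot` is a hypothetical wrapper, while the `torch.cuda.memory` calls are the real (private) PyTorch API:

```python
import torch


def capture_with_memory_snapshot(capture_fn, out_path="cuda_graph_runner_memory_usage.pickle"):
    # Start recording allocation events with stack traces
    # (private PyTorch API, available since roughly 2.1).
    torch.cuda.memory._record_memory_history()
    try:
        capture_fn()
    finally:
        # Write the recorded history to a pickle; it can be inspected
        # interactively at https://pytorch.org/memory_viz
        torch.cuda.memory._dump_snapshot(out_path)
        # Passing enabled=None stops recording.
        torch.cuda.memory._record_memory_history(enabled=None)
```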
@@ -526,13 +522,14 @@ class CudaGraphRunner:
         input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
+        seq_lens_cpu = self.seq_lens_cpu[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
         if self.is_encoder_decoder:
             encoder_lens = self.encoder_lens[:bs]
         else:
             encoder_lens = None
-        mrope_positions = self.mrope_positions[:, :bs]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens
 
@@ -596,6 +593,7 @@ class CudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             next_token_logits_buffer=next_token_logits_buffer,
             orig_seq_lens=seq_lens,
             req_to_token_pool=self.model_runner.req_to_token_pool,
@@ -751,7 +749,7 @@ class CudaGraphRunner:
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
-            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+            self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
         if self.require_gathered_buffer:
             self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
@@ -825,8 +823,11 @@ class CudaGraphRunner:
 
     def get_spec_info(self, num_tokens: int):
         spec_info = None
-        if self.model_runner.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_utils import EagleVerifyInput
+        if (
+            self.model_runner.spec_algorithm.is_eagle()
+            or self.model_runner.spec_algorithm.is_standalone()
+        ):
+            from sglang.srt.speculative.eagle_info import EagleVerifyInput
 
             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen.")
@@ -847,6 +848,20 @@ class CudaGraphRunner:
                 seq_lens_cpu=None,
             )
 
+        elif self.model_runner.spec_algorithm.is_ngram():
+            from sglang.srt.speculative.ngram_info import NgramVerifyInput
+
+            spec_info = NgramVerifyInput(
+                draft_token=None,
+                tree_mask=self.custom_mask,
+                positions=None,
+                retrive_index=None,
+                retrive_next_token=None,
+                retrive_next_sibling=None,
+                draft_token_num=self.num_tokens_per_bs,
+            )
+            spec_info.capture_hidden_mode = CaptureHiddenMode.NULL
+
         return spec_info
sglang/srt/model_executor/forward_batch_info.py

@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_size,
     set_dp_buffer_len,
 )
-from sglang.srt.
-from sglang.srt.utils import (
-    flatten_nested_list,
-    get_compiler_backend,
-    is_npu,
-    support_triton,
-)
+from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -60,8 +54,7 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-    from sglang.srt.speculative.
-    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
 
 _is_npu = is_npu()
 
@@ -82,10 +75,6 @@ class ForwardMode(IntEnum):
     # Used in speculative decoding: extend a batch in the draft model.
     DRAFT_EXTEND = auto()
 
-    # A dummy first batch to start the pipeline for overlap scheduler.
-    # It is now used for triggering the sampling_info_done event for the first prefill batch.
-    DUMMY_FIRST = auto()
-
     # Split Prefill for PD multiplexing
     SPLIT_PREFILL = auto()
 
@@ -132,8 +121,8 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )
 
-    def is_dummy_first(self):
-        return self == ForwardMode.DUMMY_FIRST
+    def is_cpu_graph(self):
+        return self == ForwardMode.DECODE
 
     def is_split_prefill(self):
         return self == ForwardMode.SPLIT_PREFILL
@@ -289,14 +278,18 @@ class ForwardBatch:
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     # Speculative decoding
-    spec_info: Optional[
+    spec_info: Optional[SpecInput] = None
     spec_algorithm: SpeculativeAlgorithm = None
     capture_hidden_mode: CaptureHiddenMode = None
 
     # For padding
     padded_static_len: int = -1  # -1 if not padded
     num_token_non_padded: Optional[torch.Tensor] = None  # scalar tensor
+    num_token_non_padded_cpu: int = None
 
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
@@ -335,6 +328,7 @@ class ForwardBatch:
             is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
+            is_prefill_only=batch.is_prefill_only,
             lora_ids=batch.lora_ids,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
@@ -358,36 +352,18 @@ class ForwardBatch:
         ret.num_token_non_padded = torch.tensor(
             len(batch.input_ids), dtype=torch.int32
         ).to(device, non_blocking=True)
+        ret.num_token_non_padded_cpu = len(batch.input_ids)
 
         # For MLP sync
         if batch.global_num_tokens is not None:
-            from sglang.srt.speculative.eagle_utils import (
-                EagleDraftInput,
-                EagleVerifyInput,
-            )
-
             assert batch.global_num_tokens_for_logprob is not None
+
             # process global_num_tokens and global_num_tokens_for_logprob
             if batch.spec_info is not None:
-                if isinstance(batch.spec_info, EagleDraftInput):
-                    global_num_tokens = [
-                        x * batch.spec_info.num_tokens_per_batch
-                        for x in batch.global_num_tokens
-                    ]
-                    global_num_tokens_for_logprob = [
-                        x * batch.spec_info.num_tokens_for_logprob_per_batch
-                        for x in batch.global_num_tokens_for_logprob
-                    ]
-                else:
-                    assert isinstance(batch.spec_info, EagleVerifyInput)
-                    global_num_tokens = [
-                        x * batch.spec_info.draft_token_num
-                        for x in batch.global_num_tokens
-                    ]
-                    global_num_tokens_for_logprob = [
-                        x * batch.spec_info.draft_token_num
-                        for x in batch.global_num_tokens_for_logprob
-                    ]
+                spec_info: SpecInput = batch.spec_info
+                global_num_tokens, global_num_tokens_for_logprob = (
+                    spec_info.get_spec_adjusted_global_num_tokens(batch)
+                )
             else:
                 global_num_tokens = batch.global_num_tokens
                 global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
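This hunk replaces per-class `isinstance` branching with a single call on the new `SpecInput` base type (added in `sglang/srt/speculative/spec_info.py`, per the file list above). A minimal sketch of the dispatch pattern; `DraftInput`, `VerifyInput`, and the multiplier values are hypothetical stand-ins for the real Eagle classes, while the method name comes from the diff itself:

```python
from typing import List, Tuple


class SpecInput:
    """Base type: each speculative input knows its own token multipliers."""

    def get_spec_adjusted_global_num_tokens(self, batch) -> Tuple[List[int], List[int]]:
        raise NotImplementedError


class DraftInput(SpecInput):
    # Hypothetical fields mirroring the removed EagleDraftInput branch.
    num_tokens_per_batch = 4
    num_tokens_for_logprob_per_batch = 1

    def get_spec_adjusted_global_num_tokens(self, batch):
        return (
            [x * self.num_tokens_per_batch for x in batch.global_num_tokens],
            [
                x * self.num_tokens_for_logprob_per_batch
                for x in batch.global_num_tokens_for_logprob
            ],
        )


class VerifyInput(SpecInput):
    # Hypothetical field mirroring the removed EagleVerifyInput branch:
    # verification scales both lists by the number of draft tokens.
    draft_token_num = 4

    def get_spec_adjusted_global_num_tokens(self, batch):
        return (
            [x * self.draft_token_num for x in batch.global_num_tokens],
            [x * self.draft_token_num for x in batch.global_num_tokens_for_logprob],
        )
```

With this shape, the caller no longer needs any Eagle-specific imports, which is what allows the `eagle_utils` import block to be deleted above.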
@@ -441,7 +417,13 @@ class ForwardBatch:
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
 
         if model_runner.model_is_mrope:
-            ret._compute_mrope_positions(model_runner, batch)
+            if (
+                ret.spec_info is not None
+                and getattr(ret.spec_info, "positions", None) is not None
+            ):
+                ret._compute_spec_mrope_positions(model_runner, batch)
+            else:
+                ret._compute_mrope_positions(model_runner, batch)
 
         # Init lora information
         if model_runner.server_args.enable_lora:
@@ -507,6 +489,52 @@ class ForwardBatch:
             or self.contains_image_inputs()
         )
 
+    def _compute_spec_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        # TODO support batched deltas
+        batch_size = self.seq_lens.shape[0]
+        device = model_runner.device
+        mm_inputs = batch.multimodal_inputs
+
+        if batch.forward_mode.is_draft_extend():  # draft_extend_after_decode
+            mrope_deltas = []
+            extend_lens = []
+            for batch_idx in range(batch_size):
+                extend_seq_len = batch.extend_seq_lens[batch_idx]
+                extend_lens.append(extend_seq_len)
+                mrope_delta = (
+                    torch.zeros(1, dtype=torch.int64)
+                    if mm_inputs[batch_idx] is None
+                    else mm_inputs[batch_idx].mrope_position_delta.squeeze(0)
+                )
+                mrope_deltas.append(mrope_delta.to(device=device))
+            position_chunks = torch.split(batch.spec_info.positions, extend_lens)
+            mrope_positions_list = [
+                pos_chunk + delta
+                for pos_chunk, delta in zip(position_chunks, mrope_deltas)
+            ]
+            next_input_positions = (
+                torch.cat(mrope_positions_list, dim=0).unsqueeze(0).repeat(3, 1)
+            )
+
+        else:  # target_verify or draft_decode
+            seq_positions = batch.spec_info.positions.view(batch_size, -1)
+            mrope_deltas = [
+                (
+                    torch.tensor([0], dtype=torch.int64)
+                    if mm_inputs[i] is None
+                    else mm_inputs[i].mrope_position_delta.squeeze(0)
+                )
+                for i in range(batch_size)
+            ]
+            mrope_delta_tensor = torch.stack(mrope_deltas, dim=0).to(device=device)
+            next_input_positions = (
+                (seq_positions + mrope_delta_tensor).flatten().unsqueeze(0).repeat(3, 1)
+            )
+
+        self.mrope_positions = next_input_positions
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
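In the verify/decode branch of the new method, each request's flat draft positions are shifted by its cached `mrope_position_delta` and then tiled across the three M-RoPE channels. A small worked example of that tensor arithmetic (all names and values here are illustrative):

```python
import torch

batch_size, draft_token_num = 2, 3
# Flat next-token positions for 2 requests, 3 draft tokens each.
positions = torch.tensor([10, 11, 12, 20, 21, 22])
# One delta per request (e.g., from each request's multimodal inputs).
deltas = torch.tensor([[5], [0]])

seq_positions = positions.view(batch_size, -1)       # (2, 3)
shifted = (seq_positions + deltas).flatten()         # [15, 16, 17, 20, 21, 22]
mrope_positions = shifted.unsqueeze(0).repeat(3, 1)  # (3, 6): same positions on all 3 channels
print(mrope_positions.shape)                         # torch.Size([3, 6])
```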
@@ -614,9 +642,6 @@ class ForwardBatch:
         )
 
     def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
-
-        from sglang.srt.speculative.eagle_utils import EagleDraftInput
-
         assert self.global_num_tokens_cpu is not None
         assert self.global_num_tokens_for_logprob_cpu is not None
 
@@ -631,7 +656,9 @@ class ForwardBatch:
                 (global_num_tokens[i] - 1) // attn_tp_size + 1
             ) * attn_tp_size
 
-        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
+            self.is_extend_in_batch, global_num_tokens
+        )
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():
@@ -711,7 +738,8 @@ class ForwardBatch:
         if self.extend_seq_lens is not None:
             self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
 
-        if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
+        if self.spec_info is not None and self.spec_info.is_draft_input():
+            # FIXME(lsyin): remove this isinstance logic
             spec_info = self.spec_info
             self.output_cache_loc_backup = self.out_cache_loc
             self.hidden_states_backup = spec_info.hidden_states