sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -48,18 +48,22 @@ from sglang.srt.model_executor.forward_batch_info import (
     PPProxyTensors,
     enable_num_token_non_padded,
 )
-from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
 from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
+    get_bool_env_var,
     get_device_memory_capacity,
+    is_hip,
     log_info_on_rank0,
     require_attn_tp_gather,
     require_gathered_buffer,
     require_mlp_sync,
     require_mlp_tp_gather,
 )
+from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
+
+_is_hip = is_hip()

 logger = logging.getLogger(__name__)

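The hunk above keeps importing from sglang.srt.utils even though the file list shows utils.py moving to utils/common.py alongside a new utils/__init__.py. A minimal sketch of how such a move can stay backward compatible, assuming a star re-export (the actual __init__.py contents are not shown in this diff):

# utils/__init__.py -- hypothetical sketch, not the actual file contents:
# re-export everything from the relocated module so that existing
# `from sglang.srt.utils import get_bool_env_var` call sites keep working.
from sglang.srt.utils.common import *  # noqa: F401,F403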
@@ -100,6 +104,7 @@ def freeze_gc(enable_cudagraph_gc: bool):
     finally:
         if should_freeze:
             gc.unfreeze()
+            gc.collect()


 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
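The gc.collect() added after gc.unfreeze() matters because objects allocated while the GC is frozen are exempt from collection until unfrozen. A runnable sketch of the freeze/unfreeze pattern this hunk extends (the body is a simplification, not a verbatim copy of the function):

import gc
from contextlib import contextmanager

@contextmanager
def freeze_gc(enable_cudagraph_gc: bool):
    """Freeze existing objects so GC passes triggered during CUDA graph
    capture do not scan them mid-capture."""
    gc.collect()
    should_freeze = not enable_cudagraph_gc
    if should_freeze:
        gc.freeze()  # move all live objects to the permanent generation
    try:
        yield
    finally:
        if should_freeze:
            gc.unfreeze()
            gc.collect()  # reclaim garbage that accumulated while frozen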
@@ -136,7 +141,7 @@ def patch_model(
                 mode=os.environ.get(
                     "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
                 ),
-                dynamic=
+                dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
             )
         else:
             yield model.forward
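The replacement line gates torch.compile's dynamic-shape mode on both running on ROCm (_is_hip) and an explicit opt-in env var. A self-contained sketch of the same gating, with a simplified stand-in for sglang's get_bool_env_var helper:

import os
import torch

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # simplified stand-in for sglang.srt.utils.get_bool_env_var
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

_is_hip = torch.version.hip is not None

def compile_forward(forward_fn):
    return torch.compile(
        forward_fn,
        mode=os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"),
        # dynamic shapes only on ROCm, and only when explicitly requested
        dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"),
    )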
@@ -166,29 +171,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     server_args = model_runner.server_args
     capture_bs = server_args.cuda_graph_bs

-    if capture_bs is None:
-        if server_args.speculative_algorithm is None:
-            if server_args.disable_cuda_graph_padding:
-                capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
-            else:
-                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
-        else:
-            # Since speculative decoding requires more cuda graph memory, we
-            # capture less.
-            capture_bs = (
-                list(range(1, 9))
-                + list(range(10, 33, 2))
-                + list(range(40, 64, 8))
-                + list(range(80, 161, 16))
-            )
-
-        gpu_mem = get_device_memory_capacity()
-        if gpu_mem is not None:
-            if gpu_mem > 90 * 1024:  # H200, H20
-                capture_bs += list(range(160, 257, 8))
-            if gpu_mem > 160 * 1000:  # B200, MI300
-                capture_bs += list(range(256, 513, 16))
-
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.

@@ -204,12 +186,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):

     capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]

-    if server_args.cuda_graph_max_bs:
-        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
-        if max(capture_bs) < server_args.cuda_graph_max_bs:
-            capture_bs += list(
-                range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
-            )
     capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
     capture_bs = list(sorted(set(capture_bs)))
     assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
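Both removed blocks computed the default CUDA-graph batch-size ladder inline; after this change capture_bs is expected to arrive pre-populated via server_args.cuda_graph_bs (server_args.py grows by roughly 700 lines in this release, which is consistent with the defaults moving there, though this diff does not show it). For reference, a function reproducing exactly the removed defaults:

def legacy_default_capture_bs(
    speculative: bool, disable_padding: bool, gpu_mem: int | None
) -> list[int]:
    """Reproduces the batch-size ladder deleted from get_batch_sizes_to_capture.
    gpu_mem is the device memory capacity as returned by
    get_device_memory_capacity()."""
    if speculative:
        # speculative decoding needs more cuda graph memory, so capture less
        bs = (
            list(range(1, 9))
            + list(range(10, 33, 2))
            + list(range(40, 64, 8))
            + list(range(80, 161, 16))
        )
    elif disable_padding:
        bs = list(range(1, 33)) + list(range(48, 161, 16))
    else:
        bs = [1, 2, 4, 8] + list(range(16, 161, 8))
    if gpu_mem is not None:
        if gpu_mem > 90 * 1024:  # H200, H20
            bs += list(range(160, 257, 8))
        if gpu_mem > 160 * 1000:  # B200, MI300
            bs += list(range(256, 513, 16))
    return sorted(set(bs))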
@@ -271,7 +247,11 @@ class CudaGraphRunner:
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
-        if
+        if (
+            model_runner.spec_algorithm.is_eagle()
+            or model_runner.spec_algorithm.is_standalone()
+            or model_runner.spec_algorithm.is_ngram()
+        ):
             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen")
             else:
@@ -317,7 +297,9 @@ class CudaGraphRunner:
             (self.max_num_token,), dtype=self._cache_loc_dtype()
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-        self.mrope_positions = torch.zeros(
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
         self.tbo_plugin = TboCudaGraphRunnerPlugin()

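The mrope_positions buffer is now explicitly preallocated as (3, max_num_token): M-RoPE, used by the Qwen VL family, tracks three position components per token (temporal, height, width) instead of a single index. A toy illustration of the layout:

import torch

max_num_token = 8
mrope_positions = torch.zeros((3, max_num_token), dtype=torch.int64)

# For plain text tokens all three components collapse to the ordinary
# 1-D position, so the same value is written to every row:
text_positions = torch.arange(4)
mrope_positions[:, :4] = text_positions.unsqueeze(0).expand(3, -1)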
@@ -435,11 +417,21 @@ class CudaGraphRunner:
             forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
         )

+        is_ngram_supported = (
+            (
+                forward_batch.batch_size * self.num_tokens_per_bs
+                == forward_batch.input_ids.numel()
+            )
+            if self.model_runner.spec_algorithm.is_ngram()
+            else True
+        )
+
         return (
             is_bs_supported
             and is_encoder_lens_supported
             and is_tbo_supported
             and capture_hidden_mode_matches
+            and is_ngram_supported
         )

     def capture(self) -> None:
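is_ngram_supported rejects replay whenever the batch is padded: a captured graph assumes exactly num_tokens_per_bs draft tokens per request, so the flat input_ids length must equal batch_size * num_tokens_per_bs. A toy check with concrete numbers:

import torch

num_tokens_per_bs = 4          # draft tokens verified per request
batch_size = 2
input_ids = torch.zeros(batch_size * num_tokens_per_bs, dtype=torch.long)

# unpadded batch: 2 * 4 == 8 elements, so the captured graph can replay it
assert batch_size * num_tokens_per_bs == input_ids.numel()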
@@ -449,6 +441,7 @@ class CudaGraphRunner:
                 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 record_shapes=True,
             )
+            torch.cuda.memory._record_memory_history()

         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes

@@ -497,6 +490,8 @@ class CudaGraphRunner:
                 save_gemlite_cache()

         if self.enable_profile_cuda_graph:
+            torch.cuda.memory._dump_snapshot(f"cuda_graph_runner_memory_usage.pickle")
+            torch.cuda.memory._record_memory_history(enabled=None)
             log_message = (
                 "Sorted by CUDA Time:\n"
                 + prof.key_averages(group_by_input_shape=True).table(

@@ -506,6 +501,7 @@ class CudaGraphRunner:
                 + prof.key_averages(group_by_input_shape=True).table(
                     sort_by="cpu_time_total", row_limit=10
                 )
+                + "\n\nMemory Usage is saved to cuda_graph_runner_memory_usage.pickle\n"
             )
             logger.info(log_message)

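The three added lines wrap graph capture with PyTorch's allocator-history recorder. These are real (if private) torch APIs; a standalone sketch of the same profiling pattern:

import torch

if torch.cuda.is_available():
    # start recording every CUDA allocation with stack traces
    torch.cuda.memory._record_memory_history()

    buf = torch.empty(1 << 20, device="cuda")  # ... capture CUDA graphs here ...

    # write a snapshot viewable at https://pytorch.org/memory_viz
    torch.cuda.memory._dump_snapshot("cuda_graph_runner_memory_usage.pickle")
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording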
@@ -526,13 +522,14 @@ class CudaGraphRunner:
         input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
+        seq_lens_cpu = self.seq_lens_cpu[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
         if self.is_encoder_decoder:
             encoder_lens = self.encoder_lens[:bs]
         else:
             encoder_lens = None
-        mrope_positions = self.mrope_positions[:, :
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens

@@ -596,6 +593,7 @@ class CudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             next_token_logits_buffer=next_token_logits_buffer,
             orig_seq_lens=seq_lens,
             req_to_token_pool=self.model_runner.req_to_token_pool,

@@ -751,7 +749,7 @@ class CudaGraphRunner:
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
-            self.mrope_positions[:, :
+            self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
         if self.require_gathered_buffer:
             self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)

@@ -825,8 +823,11 @@ class CudaGraphRunner:

     def get_spec_info(self, num_tokens: int):
         spec_info = None
-        if
-
+        if (
+            self.model_runner.spec_algorithm.is_eagle()
+            or self.model_runner.spec_algorithm.is_standalone()
+        ):
+            from sglang.srt.speculative.eagle_info import EagleVerifyInput

             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen.")

@@ -847,6 +848,20 @@ class CudaGraphRunner:
                 seq_lens_cpu=None,
             )

+        elif self.model_runner.spec_algorithm.is_ngram():
+            from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+
+            spec_info = NgramVerifyInput(
+                draft_token=None,
+                tree_mask=self.custom_mask,
+                positions=None,
+                retrive_index=None,
+                retrive_next_token=None,
+                retrive_next_sibling=None,
+                draft_token_num=self.num_tokens_per_bs,
+            )
+            spec_info.capture_hidden_mode = CaptureHiddenMode.NULL
+
         return spec_info

sglang/srt/model_executor/forward_batch_info.py

@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_size,
     set_dp_buffer_len,
 )
-from sglang.srt.
-from sglang.srt.utils import (
-    flatten_nested_list,
-    get_compiler_backend,
-    is_npu,
-    support_triton,
-)
+from sglang.srt.utils import get_compiler_backend, is_npu, support_triton

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend

@@ -60,8 +54,7 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-    from sglang.srt.speculative.
-    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm

 _is_npu = is_npu()

@@ -132,6 +125,9 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )

+    def is_cpu_graph(self):
+        return self == ForwardMode.DECODE
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST

@@ -290,13 +286,14 @@ class ForwardBatch:
     global_forward_mode: Optional[ForwardMode] = None

     # Speculative decoding
-    spec_info: Optional[
+    spec_info: Optional[SpecInput] = None
     spec_algorithm: SpeculativeAlgorithm = None
     capture_hidden_mode: CaptureHiddenMode = None

     # For padding
     padded_static_len: int = -1  # -1 if not padded
     num_token_non_padded: Optional[torch.Tensor] = None  # scalar tensor
+    num_token_non_padded_cpu: int = None

     # For Qwen2-VL
     mrope_positions: torch.Tensor = None

@@ -358,36 +355,18 @@ class ForwardBatch:
         ret.num_token_non_padded = torch.tensor(
             len(batch.input_ids), dtype=torch.int32
         ).to(device, non_blocking=True)
+        ret.num_token_non_padded_cpu = len(batch.input_ids)

         # For MLP sync
         if batch.global_num_tokens is not None:
-            from sglang.srt.speculative.eagle_utils import (
-                EagleDraftInput,
-                EagleVerifyInput,
-            )
-
             assert batch.global_num_tokens_for_logprob is not None
+
             # process global_num_tokens and global_num_tokens_for_logprob
             if batch.spec_info is not None:
-
-
-
-
-                ]
-                global_num_tokens_for_logprob = [
-                    x * batch.spec_info.num_tokens_for_logprob_per_batch
-                    for x in batch.global_num_tokens_for_logprob
-                ]
-            else:
-                assert isinstance(batch.spec_info, EagleVerifyInput)
-                global_num_tokens = [
-                    x * batch.spec_info.draft_token_num
-                    for x in batch.global_num_tokens
-                ]
-                global_num_tokens_for_logprob = [
-                    x * batch.spec_info.draft_token_num
-                    for x in batch.global_num_tokens_for_logprob
-                ]
+                spec_info: SpecInput = batch.spec_info
+                global_num_tokens, global_num_tokens_for_logprob = (
+                    spec_info.get_spec_adjusted_global_num_tokens(batch)
+                )
             else:
                 global_num_tokens = batch.global_num_tokens
                 global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
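The removed branches multiplied each rank's token count by an algorithm-specific factor (num_tokens_for_logprob_per_batch for draft inputs, draft_token_num for verify inputs); the new code delegates that to the SpecInput interface. A hedged sketch of what get_spec_adjusted_global_num_tokens plausibly does, inferred from the deleted branches (the real method lives in sglang/srt/speculative/spec_info.py, which this diff does not show):

class SpecInput:
    # hypothetical sketch of the interface inferred from the deleted code
    num_tokens_per_batch: int              # e.g. draft_token_num for verify inputs
    num_tokens_for_logprob_per_batch: int

    def get_spec_adjusted_global_num_tokens(self, batch) -> tuple[list[int], list[int]]:
        return (
            [x * self.num_tokens_per_batch for x in batch.global_num_tokens],
            [
                x * self.num_tokens_for_logprob_per_batch
                for x in batch.global_num_tokens_for_logprob
            ],
        )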
@@ -441,7 +420,13 @@ class ForwardBatch:
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

         if model_runner.model_is_mrope:
-
+            if (
+                ret.spec_info is not None
+                and getattr(ret.spec_info, "positions", None) is not None
+            ):
+                ret._compute_spec_mrope_positions(model_runner, batch)
+            else:
+                ret._compute_mrope_positions(model_runner, batch)

         # Init lora information
         if model_runner.server_args.enable_lora:

@@ -507,6 +492,52 @@ class ForwardBatch:
             or self.contains_image_inputs()
         )

+    def _compute_spec_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        # TODO support batched deltas
+        batch_size = self.seq_lens.shape[0]
+        device = model_runner.device
+        mm_inputs = batch.multimodal_inputs
+
+        if batch.forward_mode.is_draft_extend():  # draft_extend_after_decode
+            mrope_deltas = []
+            extend_lens = []
+            for batch_idx in range(batch_size):
+                extend_seq_len = batch.extend_seq_lens[batch_idx]
+                extend_lens.append(extend_seq_len)
+                mrope_delta = (
+                    torch.zeros(1, dtype=torch.int64)
+                    if mm_inputs[batch_idx] is None
+                    else mm_inputs[batch_idx].mrope_position_delta.squeeze(0)
+                )
+                mrope_deltas.append(mrope_delta.to(device=device))
+            position_chunks = torch.split(batch.spec_info.positions, extend_lens)
+            mrope_positions_list = [
+                pos_chunk + delta
+                for pos_chunk, delta in zip(position_chunks, mrope_deltas)
+            ]
+            next_input_positions = (
+                torch.cat(mrope_positions_list, dim=0).unsqueeze(0).repeat(3, 1)
+            )
+
+        else:  # target_verify or draft_decode
+            seq_positions = batch.spec_info.positions.view(batch_size, -1)
+            mrope_deltas = [
+                (
+                    torch.tensor([0], dtype=torch.int64)
+                    if mm_inputs[i] is None
+                    else mm_inputs[i].mrope_position_delta.squeeze(0)
+                )
+                for i in range(batch_size)
+            ]
+            mrope_delta_tensor = torch.stack(mrope_deltas, dim=0).to(device=device)
+            next_input_positions = (
+                (seq_positions + mrope_delta_tensor).flatten().unsqueeze(0).repeat(3, 1)
+            )
+
+        self.mrope_positions = next_input_positions
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
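In the draft-extend branch above, the flat speculative positions are split back into per-request chunks, each chunk is shifted by that request's M-RoPE delta, and the result is broadcast to M-RoPE's three rows. The same tensor manipulation with toy values:

import torch

# two requests with extend lengths 4 and 2; request 0 has a multimodal
# position delta of -10, request 1 is text-only (delta 0)
positions = torch.arange(6)
extend_lens = [4, 2]
deltas = [torch.tensor([-10]), torch.tensor([0])]

chunks = torch.split(positions, extend_lens)
shifted = [chunk + delta for chunk, delta in zip(chunks, deltas)]
mrope = torch.cat(shifted).unsqueeze(0).repeat(3, 1)
print(mrope.shape)  # torch.Size([3, 6])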
@@ -614,9 +645,6 @@ class ForwardBatch:
         )

     def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
-
-        from sglang.srt.speculative.eagle_utils import EagleDraftInput
-
         assert self.global_num_tokens_cpu is not None
         assert self.global_num_tokens_for_logprob_cpu is not None

@@ -631,7 +659,9 @@ class ForwardBatch:
                 (global_num_tokens[i] - 1) // attn_tp_size + 1
             ) * attn_tp_size

-        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
+            self.is_extend_in_batch, global_num_tokens
+        )
         self.dp_padding_mode = dp_padding_mode

         if dp_padding_mode.is_max_len():

@@ -711,7 +741,8 @@ class ForwardBatch:
         if self.extend_seq_lens is not None:
             self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)

-        if self.spec_info is not None and
+        if self.spec_info is not None and self.spec_info.is_draft_input():
+            # FIXME(lsyin): remove this isinstance logic
             spec_info = self.spec_info
             self.output_cache_loc_backup = self.out_cache_loc
             self.hidden_states_backup = spec_info.hidden_states

@@ -871,6 +902,17 @@ class ForwardBatch:
         return self.tbo_split_seq_index is not None


+@dataclass
+class ForwardBatchOutput:
+    # FIXME(lsyin): unify the forward batch output between different spec and parallelism
+    # need to be more organized
+    logits_output: Optional[torch.Tensor] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    pp_proxy_tensors: Optional[PPProxyTensors] = None
+    can_run_cuda_graph: bool = False
+
+
 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1

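ForwardBatchOutput gives model-runner callers one container for forward results instead of ad-hoc tuples. A hedged usage sketch (only the dataclass itself appears in this diff; the construction site is an assumption, and pp_proxy_tensors is omitted to keep the snippet self-contained):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class ForwardBatchOutput:
    logits_output: Optional[torch.Tensor] = None
    next_token_ids: Optional[torch.Tensor] = None
    num_accepted_tokens: Optional[int] = None
    can_run_cuda_graph: bool = False

# hypothetical consumer: fields default to None, so callers probe what exists
out = ForwardBatchOutput(next_token_ids=torch.tensor([7, 11]), can_run_cuda_graph=True)
if out.num_accepted_tokens is not None:
    print("speculative step accepted", out.num_accepted_tokens, "tokens")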