sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@ from enum import Enum, auto
 from typing import Any, List, Optional
 
 from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
-from sglang.srt.poll_based_barrier import PollBasedBarrier
+from sglang.srt.utils.poll_based_barrier import PollBasedBarrier
 
 logger = logging.getLogger(__name__)
 
@@ -12,7 +12,6 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
-from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.utils import get_bool_env_var
 
@@ -47,8 +46,11 @@ class SchedulerMetricsMixin:
         self.spec_num_total_forward_ct = 0
         self.cum_spec_accept_length = 0
         self.cum_spec_accept_count = 0
-        self.
+        self.kv_transfer_speed_gb_s: float = 0.0
+        self.kv_transfer_latency_ms: float = 0.0
+
         self.stats = SchedulerStats()
+
         if self.enable_metrics:
             engine_type = "unified"
             labels = {
@@ -61,33 +63,30 @@ class SchedulerMetricsMixin:
                 labels["dp_rank"] = dp_rank
             self.metrics_collector = SchedulerMetricsCollector(labels=labels)
 
-    def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
-        self.balance_meta = dp_balance_meta
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-
-            self.recv_dp_balance_id_this_term = []
-
     def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
             self.kv_event_publisher = EventPublisherFactory.create(
                 kv_events_config, self.attn_dp_rank
             )
 
+    def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
+        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_total_forward_ct += bs
+        self.num_generated_tokens += num_accepted_tokens
+
     def log_prefill_stats(
         self: Scheduler,
         adder: PrefillAdder,
         can_run_list: List[Req],
         running_bs: int,
+        running_bs_offline_batch: int,
     ):
         gap_latency = time.perf_counter() - self.last_prefill_stats_tic
         self.last_prefill_stats_tic = time.perf_counter()
         self.last_input_throughput = self.last_prefill_tokens / gap_latency
         self.last_prefill_tokens = adder.log_input_tokens
 
+        # TODO: generalize this for various memory pools
         if self.is_hybrid:
             (
                 full_num_used,
@@ -101,51 +100,53 @@ class SchedulerMetricsMixin:
             ) = self._get_swa_token_info()
             num_used = max(full_num_used, swa_num_used)
             token_usage = max(full_token_usage, swa_token_usage)
-
+            token_usage_msg = (
                 f"full token usage: {full_token_usage:.2f}, "
                 f"swa token usage: {swa_token_usage:.2f}, "
             )
         else:
             num_used, token_usage, _, _ = self._get_token_info()
-
+            token_usage_msg = f"token usage: {token_usage:.2f}, "
 
-        num_new_seq = len(can_run_list)
         f = (
             f"Prefill batch. "
-            f"#new-seq: {
+            f"#new-seq: {len(can_run_list)}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
-            f"{
+            f"{token_usage_msg}"
+            f"#running-req: {running_bs}, "
+            f"#queue-req: {len(self.waiting_queue)}, "
         )
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            f += f"#
-            f += f"#
-            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
-            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
-        else:
-            f += f"#running-req: {running_bs}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
+            f += f"#prealloc-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+            f += f"#inflight-req: {len(self.disagg_prefill_inflight_queue)}, "
 
         logger.info(f)
 
         if self.enable_metrics:
+            # Basics
             total_tokens = adder.log_input_tokens + adder.log_hit_tokens
-
             cache_hit_rate = (
                 adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
             )
+
             self.stats.num_running_reqs = running_bs
+            self.stats.num_running_reqs_offline_batch = running_bs_offline_batch
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage =
+            self.stats.token_usage = token_usage
+            if self.is_hybrid:
+                self.stats.swa_token_usage = swa_token_usage
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.cache_hit_rate = cache_hit_rate
 
-
-
-
-            self.
+            # Retract
+            self.stats.num_retracted_reqs = self.num_retracted_reqs
+            self.stats.num_paused_reqs = self.num_paused_reqs
+            self.num_retracted_reqs = self.num_paused_reqs = 0
 
+            # PD disaggregation
             if self.disaggregation_mode == DisaggregationMode.PREFILL:
                 self.stats.num_prefill_prealloc_queue_reqs = len(
                     self.disagg_prefill_bootstrap_queue.queue
@@ -153,7 +154,18 @@ class SchedulerMetricsMixin:
                 self.stats.num_prefill_inflight_queue_reqs = len(
                     self.disagg_prefill_inflight_queue
                 )
+                self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s
+                self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms
+            elif self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.stats.num_decode_prealloc_queue_reqs = len(
+                    self.disagg_decode_prealloc_queue.queue
+                )
+                self.stats.num_decode_transfer_queue_reqs = len(
+                    self.disagg_decode_transfer_queue.queue
+                )
 
+            # Others
+            self.calculate_utilization()
             self.metrics_collector.log_stats(self.stats)
             self._emit_kv_metrics()
             self._publish_kv_events()
@@ -166,8 +178,12 @@ class SchedulerMetricsMixin:
         gap_latency = time.perf_counter() - self.last_decode_stats_tic
         self.last_decode_stats_tic = time.perf_counter()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
+
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
+        num_running_reqs_offline_batch = 0
+
+        # TODO: generalize this for various memory pools
         if self.is_hybrid:
             (
                 full_num_used,
@@ -181,7 +197,7 @@ class SchedulerMetricsMixin:
             ) = self._get_swa_token_info()
             num_used = max(full_num_used, swa_num_used)
             token_usage = max(full_token_usage, swa_token_usage)
-
+            token_usage_msg = (
                 f"#full token: {full_num_used}, "
                 f"full token usage: {full_token_usage:.2f}, "
                 f"#swa token: {swa_num_used}, "
@@ -189,14 +205,14 @@ class SchedulerMetricsMixin:
             )
         else:
             num_used, token_usage, _, _ = self._get_token_info()
-
+            token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "
 
         if RECORD_STEP_TIME:
             self.step_time_dict[num_running_reqs].append(
                 gap_latency / self.server_args.decode_log_interval
             )
 
-        msg = f"Decode batch. #running-req: {num_running_reqs}, {
+        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}"
 
         if self.spec_algorithm.is_none():
             spec_accept_length = 0
@@ -208,40 +224,66 @@ class SchedulerMetricsMixin:
             self.cum_spec_accept_count += self.spec_num_total_forward_ct
             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
             msg += f"accept len: {spec_accept_length:.2f}, "
+        cache_hit_rate = 0.0
 
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#prealloc-req: {len(self.disagg_decode_prealloc_queue.queue)}, "
+            msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, "
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
 
         msg += (
-            f"cuda graph: {can_run_cuda_graph}, "
+            f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )
 
         logger.info(msg)
         if self.enable_metrics:
+            # Basics
             self.stats.num_running_reqs = num_running_reqs
+            self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage =
-            self.
+            self.stats.token_usage = token_usage
+            if self.is_hybrid:
+                self.stats.swa_token_usage = swa_token_usage
             self.stats.gen_throughput = self.last_gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+            self.stats.cache_hit_rate = cache_hit_rate
             self.stats.spec_accept_length = spec_accept_length
-
-
-
+
+            # Retract
+            self.stats.num_retracted_reqs = self.num_retracted_reqs
+            self.stats.num_paused_reqs = self.num_paused_reqs
+            self.num_retracted_reqs = self.num_paused_reqs = 0
+
+            # PD disaggregation
+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                self.stats.num_prefill_prealloc_queue_reqs = len(
+                    self.disagg_prefill_bootstrap_queue.queue
+                )
+                self.stats.num_prefill_inflight_queue_reqs = len(
+                    self.disagg_prefill_inflight_queue
+                )
+            elif self.disaggregation_mode == DisaggregationMode.DECODE:
                 self.stats.num_decode_prealloc_queue_reqs = len(
                     self.disagg_decode_prealloc_queue.queue
                 )
                 self.stats.num_decode_transfer_queue_reqs = len(
                     self.disagg_decode_transfer_queue.queue
                 )
+
+            # Others
+            self.calculate_utilization()
+            self.metrics_collector.log_stats(self.stats)
             self._emit_kv_metrics()
             self._publish_kv_events()
 
     def _emit_kv_metrics(self: Scheduler):
+        if not self.enable_kv_cache_events:
+            return
+
         kv_metrics = KvMetrics()
         kv_metrics.request_active_slots = self.stats.num_running_reqs
         kv_metrics.request_total_slots = self.max_running_requests
@@ -258,93 +300,24 @@ class SchedulerMetricsMixin:
         self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
 
     def _publish_kv_events(self: Scheduler):
-        if self.enable_kv_cache_events:
-
-            if events:
-                batch = KVEventBatch(ts=time.time(), events=events)
-                self.kv_event_publisher.publish(batch)
-
-    def maybe_update_dp_balance_data(
-        self: Scheduler, recv_req: TokenizedGenerateReqInput
-    ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
-
-    def maybe_handle_dp_balance_data(self: Scheduler):
-        if (
-            self.server_args.load_balance_method == "minimum_tokens"
-            and self.forward_ct % 40 == 0
-        ):
-            holding_tokens = self.get_load()
-
-            new_recv_dp_balance_id_list, holding_token_list = (
-                self.gather_dp_balance_info(holding_tokens)
-            )
+        if not self.enable_kv_cache_events:
+            return
 
-
-
-
-
-            )
+        events = self.tree_cache.take_events()
+        if events:
+            batch = KVEventBatch(ts=time.time(), events=events)
+            self.kv_event_publisher.publish(batch)
 
-    def
-        self
-
-        """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-        recv_list = self.recv_dp_balance_id_this_term
-        assert len(recv_list) <= 511, (
-            "The number of requests received this round is too large. "
-            "Please increase gather_tensor_size and onfly_info_size."
-        )
-        # The maximum size of the tensor used for gathering data from all workers.
-        gather_tensor_size = 512
-
-        # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-        recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-        recv_tensor[0] = holding_tokens_list
-        recv_tensor[1] = len(recv_list)  # The first element is the length of the list.
-        recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
-
-        if self.tp_rank == 0:
-            gathered_list = [
-                torch.zeros(gather_tensor_size, dtype=torch.int32)
-                for _ in range(self.balance_meta.num_workers)
-            ]
+    def calculate_utilization(self):
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.stats.utilization = -1
         else:
-
-
-
-
-
-
-
-
-
-            holding_tokens_list.append(tensor[0].item())
-            list_length = tensor[1].item()
-            gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
-
-        return gathered_id_list_per_worker, holding_tokens_list
-
-    def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
-        meta = self.balance_meta
-
-        with meta.mutex:
-            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-            assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
-            # 1. Check if the rid received by each worker this round is present in onfly.
-            # If it is, remove the corresponding onfly item.
-            worker_id = 0
-            for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                for new_recv_rid in new_recv_rids:
-                    assert (
-                        new_recv_rid in on_fly_reqs
-                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                    del on_fly_reqs[new_recv_rid]
-                worker_id += 1
-            # 2. Atomically write local_tokens and onfly into shm under the mutex
-            meta.set_shared_onfly_info(onfly_list)
-            meta.set_shared_local_tokens(local_tokens)
+            if (
+                self.stats.max_running_requests_under_SLO is not None
+                and self.stats.max_running_requests_under_SLO > 0
+            ):
+                self.stats.utilization = max(
+                    self.stats.num_running_reqs
+                    / self.stats.max_running_requests_under_SLO,
+                    self.stats.token_usage / 0.9,
+                )
@@ -5,9 +5,15 @@ import threading
 import time
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
+import torch
+
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.io_struct import
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
+)
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
 
 if TYPE_CHECKING:
@@ -71,6 +77,7 @@ class SchedulerOutputProcessorMixin:
 
         # Check finish conditions
         logprob_pt = 0
+
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if req.is_retracted:
                 continue
@@ -88,7 +95,7 @@ class SchedulerOutputProcessorMixin:
 
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-                    req.time_stats.completion_time = time.
+                    req.time_stats.completion_time = time.perf_counter()
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     # This updates radix so others can match
                     self.tree_cache.cache_unfinished_req(req)
@@ -99,6 +106,7 @@ class SchedulerOutputProcessorMixin:
                     extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                     extend_input_len = extend_input_len_per_req[i]
                     num_input_logprobs = extend_input_len - extend_logprob_start_len
+
                     if req.return_logprob:
                         self.add_logprob_return_values(
                             i,
@@ -136,7 +144,7 @@ class SchedulerOutputProcessorMixin:
                         logger.error(
                             f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                         )
-                        self.abort_request(AbortReq(req.rid))
+                        self.abort_request(AbortReq(rid=req.rid))
                     req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
@@ -169,8 +177,7 @@ class SchedulerOutputProcessorMixin:
             self.set_next_batch_sampling_info_done(batch)
 
         else:  # embedding or reward model
-            embeddings
-            embeddings = embeddings.tolist()
+            embeddings = result.embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -246,8 +253,14 @@ class SchedulerOutputProcessorMixin:
 
             req.check_finished()
             if req.finished():
-                self.
-
+                if self.server_args.disaggregation_decode_enable_offload_kvcache:
+                    # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
+                    if not self.decode_offload_manager.offload_kv_cache(req):
+                        self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_finished_req(req)
+
+                req.time_stats.completion_time = time.perf_counter()
 
             if req.return_logprob and batch.spec_algorithm.is_none():
                 # speculative worker handles logprob in speculative decoding
@@ -283,7 +296,7 @@ class SchedulerOutputProcessorMixin:
                     logger.error(
                         f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                     )
-                    self.abort_request(AbortReq(req.rid))
+                    self.abort_request(AbortReq(rid=req.rid))
                 req.grammar.finished = req.finished()
 
         self.set_next_batch_sampling_info_done(batch)
@@ -441,27 +454,59 @@ class SchedulerOutputProcessorMixin:
         output: LogitsProcessorOutput,
     ):
         """Attach logprobs to the return values."""
-
-
-
-
-
-
+        if output.next_token_logprobs is not None:
+            req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+            req.output_token_logprobs_idx.append(next_token_ids[i])
+
+        # Only add input logprobs if there are input tokens to process
+        # Note: For prefill-only requests with default logprob_start_len, this will be 0,
+        # meaning we only compute output logprobs (which is the intended behavior)
+        if num_input_logprobs > 0:
+            self.add_input_logprob_return_values(
+                i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+            )
+        else:
+            self._initialize_empty_logprob_containers(req)
 
         if req.top_logprobs_num > 0:
             req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
             req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
 
-        if
-        req.
-
-
+        if (
+            req.token_ids_logprob is not None
+            and output.next_token_token_ids_logprobs_val is not None
+        ):
+            # Convert GPU tensor to list if needed
+            logprobs_val = output.next_token_token_ids_logprobs_val[i]
+            if isinstance(logprobs_val, torch.Tensor):
+                logprobs_val = logprobs_val.tolist()
+            req.output_token_ids_logprobs_val.append(logprobs_val)
             req.output_token_ids_logprobs_idx.append(
                 output.next_token_token_ids_logprobs_idx[i]
             )
 
         return num_input_logprobs
 
+    def _initialize_empty_logprob_containers(self, req: Req) -> None:
+        """
+        Initialize logprob fields to empty lists if unset.
+
+        This is needed for prefill-only requests where the normal initialization
+        flow might be bypassed, but downstream code expects these fields to be lists.
+        """
+        if req.input_token_logprobs_val is None:
+            req.input_token_logprobs_val = []
+        if req.input_token_logprobs_idx is None:
+            req.input_token_logprobs_idx = []
+        if req.input_top_logprobs_val is None:
+            req.input_top_logprobs_val = []
+        if req.input_top_logprobs_idx is None:
+            req.input_top_logprobs_idx = []
+        if req.input_token_ids_logprobs_val is None:
+            req.input_token_ids_logprobs_val = []
+        if req.input_token_ids_logprobs_idx is None:
+            req.input_token_ids_logprobs_idx = []
+
     def stream_output(
         self: Scheduler,
         reqs: List[Req],
@@ -673,8 +718,7 @@ class SchedulerOutputProcessorMixin:
             return
 
         self.send_to_detokenizer.send_pyobj(
-
-                rids,
+            BatchTokenIDOutput(
                 finished_reasons,
                 decoded_texts,
                 decode_ids_list,
@@ -700,6 +744,9 @@ class SchedulerOutputProcessorMixin:
                 output_token_ids_logprobs_val,
                 output_token_ids_logprobs_idx,
                 output_hidden_states,
+                rids=rids,
+                placeholder_tokens_idx=None,
+                placeholder_tokens_val=None,
             )
         )
 
@@ -718,7 +765,13 @@ class SchedulerOutputProcessorMixin:
             prompt_tokens.append(len(req.origin_input_ids))
             cached_tokens.append(req.cached_tokens)
             self.send_to_detokenizer.send_pyobj(
-
-
+                BatchEmbeddingOutput(
+                    finished_reasons,
+                    embeddings,
+                    prompt_tokens,
+                    cached_tokens,
+                    rids=rids,
+                    placeholder_tokens_idx=None,
+                    placeholder_tokens_val=None,
                 )
             )
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 
 class SchedulerProfilerMixin:
 
-    def
+    def init_profiler(self):
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
         self.profiler_activities: Optional[List[str]] = None
@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
     def start_profile(
         self, stage: Optional[ForwardMode] = None
     ) -> ProfileReqOutput | None:
-        stage_str = f" for {stage.
+        stage_str = f" for {stage.name}" if stage else ""
         logger.info(
             f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
         )
@@ -181,7 +181,7 @@ class SchedulerProfilerMixin:
         if not Path(self.torch_profiler_output_dir).exists():
             Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
 
-        stage_suffix = f"-{stage.
+        stage_suffix = f"-{stage.name}" if stage else ""
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
@@ -204,7 +204,7 @@ class SchedulerProfilerMixin:
 
             torch.distributed.barrier(self.tp_cpu_group)
             if self.tp_rank == 0:
-                from sglang.srt.utils import rpd_to_chrome_trace
+                from sglang.srt.utils.rpd_utils import rpd_to_chrome_trace
 
                 rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
             self.rpd_profiler = None
@@ -247,7 +247,7 @@ class SchedulerProfilerMixin:
         if self.profiler_decode_ct == 0:
             if self.profile_in_progress:
                 # force trace flush
-                self.stop_profile(ForwardMode.EXTEND)
+                self.stop_profile(stage=ForwardMode.EXTEND)
             self.start_profile(batch.forward_mode)
         self.profiler_decode_ct += 1
         if self.profiler_decode_ct > self.profiler_target_decode_ct:
@@ -294,6 +294,6 @@ class SchedulerProfilerMixin:
                 recv_req.profile_by_stage,
                 recv_req.profile_id,
             )
-            return self.start_profile(
+            return self.start_profile()
         else:
             return self.stop_profile()
@@ -5,6 +5,8 @@ import torch
 
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
+    DestroyWeightsUpdateGroupReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
     InitWeightsUpdateGroupReqInput,
@@ -41,6 +43,11 @@ class SchedulerUpdateWeightsMixin:
         success, message = self.tp_worker.init_weights_update_group(recv_req)
         return InitWeightsUpdateGroupReqOutput(success, message)
 
+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        """Destroy the online model parameter update group."""
+        success, message = self.tp_worker.destroy_weights_update_group(recv_req)
+        return DestroyWeightsUpdateGroupReqOutput(success, message)
+
     def update_weights_from_distributed(
         self,
         recv_req: UpdateWeightsFromDistributedReqInput,