sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -5,9 +5,15 @@ import threading
 import time
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
+import torch
+
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.io_struct import …
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
+)
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
 
 if TYPE_CHECKING:
@@ -33,7 +39,6 @@ class SchedulerOutputProcessorMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
-        launch_done: Optional[threading.Event] = None,
     ):
         skip_stream_req = None
 
@@ -43,34 +48,35 @@ class SchedulerOutputProcessorMixin:
                 next_token_ids,
                 extend_input_len_per_req,
                 extend_logprob_start_len_per_req,
+                copy_done,
             ) = (
                 result.logits_output,
                 result.next_token_ids,
                 result.extend_input_len_per_req,
                 result.extend_logprob_start_len_per_req,
+                result.copy_done,
             )
 
-            if …
-            …
-                    logits_output.input_token_logprobs.tolist()
-                )
+            if copy_done is not None:
+                copy_done.synchronize()
+
+            # Move next_token_ids and logprobs to cpu
+            next_token_ids = next_token_ids.tolist()
+            if batch.return_logprob:
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs.tolist()
+                    )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = tuple(
+                        logits_output.input_token_logprobs.tolist()
+                    )
 
             hidden_state_offset = 0
 
             # Check finish conditions
             logprob_pt = 0
+
             for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                 if req.is_retracted:
                     continue
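The `copy_done` handshake above replaces the old `launch_done` threading event: the forward pass records a CUDA event after the asynchronous device-to-host copy of sampled tokens and logprobs, and the output processor blocks on that event only at the point it actually needs the host-side data. A minimal standalone sketch of that pattern (illustrative only, not sglang code; tensor names and shapes are made up):

```python
import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    default_stream = torch.cuda.current_stream()
    next_token_ids_gpu = torch.randint(0, 32000, (8,), device="cuda")

    copy_done = torch.cuda.Event()
    copy_stream.wait_stream(default_stream)  # order the copy after its producers
    with torch.cuda.stream(copy_stream):
        # Async D2H copy; it truly overlaps only when the destination is pinned.
        next_token_ids_cpu = next_token_ids_gpu.to("cpu", non_blocking=True)
        copy_done.record()  # recorded on copy_stream

    # ... other CPU-side scheduler work can proceed here ...

    copy_done.synchronize()  # block only where the host data is consumed
    print(next_token_ids_cpu.tolist())
```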
@@ -88,7 +94,7 @@ class SchedulerOutputProcessorMixin:
 
                     if req.finished():
                         self.tree_cache.cache_finished_req(req)
-                        req.time_stats.completion_time = time.…
+                        req.time_stats.completion_time = time.perf_counter()
                     elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                         # This updates radix so others can match
                         self.tree_cache.cache_unfinished_req(req)
@@ -98,7 +104,11 @@ class SchedulerOutputProcessorMixin:
                     assert extend_input_len_per_req is not None
                     extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                     extend_input_len = extend_input_len_per_req[i]
-                    …
+
+                    num_input_logprobs = self._calculate_num_input_logprobs(
+                        req, extend_input_len, extend_logprob_start_len
+                    )
+
                     if req.return_logprob:
                         self.add_logprob_return_values(
                             i,
@@ -136,7 +146,7 @@ class SchedulerOutputProcessorMixin:
                             logger.error(
                                 f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                             )
-                            self.abort_request(AbortReq(req.rid))
+                            self.abort_request(AbortReq(rid=req.rid))
                         req.grammar.finished = req.finished()
                 else:
                     # being chunked reqs' prefill is not finished
@@ -152,8 +162,8 @@ class SchedulerOutputProcessorMixin:
                     extend_input_len = extend_input_len_per_req[i]
                     if extend_logprob_start_len < extend_input_len:
                         # Update input logprobs.
-                        num_input_logprobs = (
-                            extend_input_len …
+                        num_input_logprobs = self._calculate_num_input_logprobs(
+                            req, extend_input_len, extend_logprob_start_len
                         )
                         if req.return_logprob:
                             self.add_input_logprob_return_values(
@@ -166,11 +176,8 @@ class SchedulerOutputProcessorMixin:
                             )
                         logprob_pt += num_input_logprobs
 
-            self.set_next_batch_sampling_info_done(batch)
-
         else:  # embedding or reward model
-            embeddings …
-            embeddings = embeddings.tolist()
+            embeddings = result.embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -197,22 +204,19 @@ class SchedulerOutputProcessorMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
-        launch_done: Optional[threading.Event] = None,
     ):
-        logits_output, next_token_ids, can_run_cuda_graph = (
+        logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
             result.logits_output,
             result.next_token_ids,
             result.can_run_cuda_graph,
+            result.copy_done,
         )
         self.num_generated_tokens += len(batch.reqs)
 
-        if …
-            …
-            next_token_logprobs = logits_output.next_token_logprobs
-        elif batch.spec_algorithm.is_none():
-            # spec decoding handles output logprobs inside verify process.
+        if copy_done is not None:
+            copy_done.synchronize()
+
+        if batch.spec_algorithm.is_none():
             next_token_ids = next_token_ids.tolist()
             if batch.return_logprob:
                 next_token_logprobs = logits_output.next_token_logprobs.tolist()
@@ -246,8 +250,14 @@ class SchedulerOutputProcessorMixin:
 
             req.check_finished()
             if req.finished():
-                self.…
-                …
+                if self.server_args.disaggregation_decode_enable_offload_kvcache:
+                    # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
+                    if not self.decode_offload_manager.offload_kv_cache(req):
+                        self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_finished_req(req)
+
+                req.time_stats.completion_time = time.perf_counter()
 
             if req.return_logprob and batch.spec_algorithm.is_none():
                 # speculative worker handles logprob in speculative decoding
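The new branch above frees a finished request's KV cache synchronously only when the asynchronous device-to-host offload cannot be started; otherwise `cache_finished_req` is deferred until the transfer completes. A hypothetical sketch of that control flow (stand-in classes, not sglang's real API):

```python
# Hypothetical offload-or-free pattern; all names here are illustrative.
class FakeOffloadManager:
    def offload_kv_cache(self, req) -> bool:
        # True would mean an async D2H transfer was scheduled and the manager
        # takes over responsibility for freeing the KV pages later.
        return False  # pretend the transfer could not be scheduled

class FakeTreeCache:
    def cache_finished_req(self, req) -> None:
        print(f"freed KV cache for {req!r} synchronously")

def on_request_finished(req, offload_enabled, offload_manager, tree_cache):
    if offload_enabled and offload_manager.offload_kv_cache(req):
        return  # KV pages are released after the transfer callback fires
    tree_cache.cache_finished_req(req)

on_request_finished("req-0", True, FakeOffloadManager(), FakeTreeCache())
```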
@@ -283,10 +293,9 @@ class SchedulerOutputProcessorMixin:
                     logger.error(
                         f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                     )
-                    self.abort_request(AbortReq(req.rid))
+                    self.abort_request(AbortReq(rid=req.rid))
                 req.grammar.finished = req.finished()
 
-        self.set_next_batch_sampling_info_done(batch)
         self.stream_output(batch.reqs, batch.return_logprob)
         self.token_to_kv_pool_allocator.free_group_end()
 
@@ -297,6 +306,153 @@ class SchedulerOutputProcessorMixin:
     ):
         self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
 
+    def _process_input_token_logprobs(
+        self, req: Req, input_token_logprobs: List
+    ) -> None:
+        """Process input token logprobs values and indices."""
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        # Process logprob values - handle multi-item scoring vs regular requests
+        if is_multi_item_scoring:
+            # Multi-item scoring: use all logprobs as-is
+            req.input_token_logprobs_val = input_token_logprobs
+        else:
+            # Regular request: add None at start, remove last (sampling token)
+            req.input_token_logprobs_val = [None] + input_token_logprobs[:-1]
+
+        # Process logprob indices based on scoring type
+        if is_multi_item_scoring:
+            # Multi-item scoring: only include delimiter token positions
+            relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
+            input_token_logprobs_idx = [
+                token_id
+                for token_id in relevant_tokens
+                if token_id == self.server_args.multi_item_scoring_delimiter
+            ]
+        else:
+            # Regular request: include all tokens from logprob_start_len onwards
+            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
+
+        # Clip padded hash values from image tokens to prevent detokenization errors
+        req.input_token_logprobs_idx = [
+            x if x < self.model_config.vocab_size - 1 else 0
+            for x in input_token_logprobs_idx
+        ]
+
+    def _process_input_top_logprobs(self, req: Req) -> None:
+        """Process input top logprobs."""
+        if req.top_logprobs_num <= 0:
+            return
+
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        # Initialize arrays - multi-item scoring starts empty, others start with None
+        req.input_top_logprobs_val = [] if is_multi_item_scoring else [None]
+        req.input_top_logprobs_idx = [] if is_multi_item_scoring else [None]
+
+        # Extend arrays with temp values
+        for val, idx in zip(
+            req.temp_input_top_logprobs_val,
+            req.temp_input_top_logprobs_idx,
+            strict=True,
+        ):
+            req.input_top_logprobs_val.extend(val)
+            req.input_top_logprobs_idx.extend(idx)
+
+        # Remove last token (sampling token) for non multi-item scoring requests
+        if not is_multi_item_scoring:
+            req.input_top_logprobs_val.pop()
+            req.input_top_logprobs_idx.pop()
+
+        # Clean up temp storage
+        req.temp_input_top_logprobs_idx = None
+        req.temp_input_top_logprobs_val = None
+
+    def _process_input_token_ids_logprobs(self, req: Req) -> None:
+        """Process input token IDs logprobs."""
+        if req.token_ids_logprob is None:
+            return
+
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        # Initialize arrays - multi-item scoring starts empty, others start with None
+        req.input_token_ids_logprobs_val = [] if is_multi_item_scoring else [None]
+        req.input_token_ids_logprobs_idx = [] if is_multi_item_scoring else [None]
+
+        # Process temp values - convert tensors to lists and extend arrays
+        for val, idx in zip(
+            req.temp_input_token_ids_logprobs_val,
+            req.temp_input_token_ids_logprobs_idx,
+            strict=True,
+        ):
+            val_list = val.tolist() if isinstance(val, torch.Tensor) else val
+            req.input_token_ids_logprobs_val.extend(
+                val_list if isinstance(val_list, list) else [val_list]
+            )
+            req.input_token_ids_logprobs_idx.extend(idx)
+
+        # Remove last token (sampling token) for non multi-item scoring requests
+        if not is_multi_item_scoring:
+            req.input_token_ids_logprobs_val.pop()
+            req.input_token_ids_logprobs_idx.pop()
+
+        # Clean up temp storage
+        req.temp_input_token_ids_logprobs_idx = None
+        req.temp_input_token_ids_logprobs_val = None
+
+    def _calculate_relevant_tokens_len(self, req: Req) -> int:
+        """Calculate the expected length of logprob arrays based on whether multi-item scoring is enabled.
+
+        For multi-item scoring, only delimiter positions have logprobs.
+        For regular requests, all positions from logprob_start_len onwards have logprobs.
+        """
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        if is_multi_item_scoring:
+            # Multi-item scoring: count delimiter tokens from logprob_start_len onwards
+            relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
+            return sum(
+                1
+                for token_id in relevant_tokens
+                if token_id == self.server_args.multi_item_scoring_delimiter
+            )
+        else:
+            # Regular request: all tokens from logprob_start_len onwards
+            return len(req.origin_input_ids) - req.logprob_start_len
+
+    def _calculate_num_input_logprobs(
+        self, req: Req, extend_input_len: int, extend_logprob_start_len: int
+    ) -> int:
+        """Calculate the number of input logprobs based on whether multi-item scoring is enabled.
+
+        For multi-item scoring, only delimiter positions have logprobs.
+        For regular requests, all positions in the range have logprobs.
+        """
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        if is_multi_item_scoring:
+            # Multi-item scoring: count delimiter tokens in the relevant portion
+            relevant_tokens = req.origin_input_ids[
+                extend_logprob_start_len:extend_input_len
+            ]
+            return sum(
+                1
+                for token_id in relevant_tokens
+                if token_id == self.server_args.multi_item_scoring_delimiter
+            )
+        else:
+            # Regular request: all tokens in the range
+            return extend_input_len - extend_logprob_start_len
+
+    def _is_multi_item_scoring(self, req: Req) -> bool:
+        """Check if request uses multi-item scoring.
+
+        Multi-item scoring applies to prefill-only requests when a delimiter
+        token is configured. In this mode, only positions containing the
+        delimiter token receive logprobs.
+        """
+        return req.is_prefill_only and self.server_args.multi_item_scoring_delimiter
+
     def add_input_logprob_return_values(
         self: Scheduler,
         i: int,
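The `_is_multi_item_scoring` family above changes which input positions receive logprobs: for a prefill-only request with a configured delimiter token, only delimiter positions count; otherwise every position in the range counts. A standalone illustration of the two counting rules, with made-up token ids (the delimiter value is purely for the example):

```python
DELIM = 151643  # assumed delimiter token id, illustration only

origin_input_ids = [10, 11, DELIM, 12, 13, DELIM, 14, DELIM]
extend_logprob_start_len, extend_input_len = 1, len(origin_input_ids)

# Multi-item scoring: count delimiter tokens in the relevant slice.
multi_item = sum(
    1
    for token_id in origin_input_ids[extend_logprob_start_len:extend_input_len]
    if token_id == DELIM
)

# Regular request: every position in the range gets a logprob.
regular = extend_input_len - extend_logprob_start_len

print(multi_item, regular)  # 3 vs. 7 logprob positions
```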
@@ -365,63 +521,14 @@ class SchedulerOutputProcessorMixin:
             assert req.input_top_logprobs_val is None
             assert req.input_top_logprobs_idx is None
 
-            # …
-            …
-            req …
-            req.input_token_logprobs_val.extend(input_token_logprobs)
-            # The last input logprob is for sampling, so just pop it out.
-            req.input_token_logprobs_val.pop()
-
-            # Compute input_token_logprobs_idx
-            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
-            # Clip the padded hash values from image tokens.
-            # Otherwise, it will lead to detokenization errors.
-            input_token_logprobs_idx = [
-                x if x < self.model_config.vocab_size - 1 else 0
-                for x in input_token_logprobs_idx
-            ]
-            req.input_token_logprobs_idx = input_token_logprobs_idx
+            # Process all input logprob types using helper functions
+            self._process_input_token_logprobs(req, input_token_logprobs)
+            self._process_input_top_logprobs(req)
 
-            …
-            req.input_top_logprobs_val = [None]
-            req.input_top_logprobs_idx = [None]
-            assert len(req.temp_input_token_ids_logprobs_val) == len(
-                req.temp_input_token_ids_logprobs_idx
-            )
-            for val, idx in zip(
-                req.temp_input_top_logprobs_val,
-                req.temp_input_top_logprobs_idx,
-                strict=True,
-            ):
-                req.input_top_logprobs_val.extend(val)
-                req.input_top_logprobs_idx.extend(idx)
-
-            # Last token is a sample token.
-            req.input_top_logprobs_val.pop()
-            req.input_top_logprobs_idx.pop()
-            req.temp_input_top_logprobs_idx = None
-            req.temp_input_top_logprobs_val = None
-
-            if req.token_ids_logprob is not None:
-                req.input_token_ids_logprobs_val = [None]
-                req.input_token_ids_logprobs_idx = [None]
-
-                for val, idx in zip(
-                    req.temp_input_token_ids_logprobs_val,
-                    req.temp_input_token_ids_logprobs_idx,
-                    strict=True,
-                ):
-                    req.input_token_ids_logprobs_val.extend(val)
-                    req.input_token_ids_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_token_ids_logprobs_val.pop()
-                req.input_token_ids_logprobs_idx.pop()
-                req.temp_input_token_ids_logprobs_idx = None
-                req.temp_input_token_ids_logprobs_val = None
+            self._process_input_token_ids_logprobs(req)
 
         if req.return_logprob:
-            relevant_tokens_len = …
+            relevant_tokens_len = self._calculate_relevant_tokens_len(req)
             assert len(req.input_token_logprobs_val) == relevant_tokens_len
             assert len(req.input_token_logprobs_idx) == relevant_tokens_len
             if req.top_logprobs_num > 0:
@@ -441,27 +548,59 @@ class SchedulerOutputProcessorMixin:
         output: LogitsProcessorOutput,
     ):
         """Attach logprobs to the return values."""
-        …
+        if output.next_token_logprobs is not None:
+            req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+            req.output_token_logprobs_idx.append(next_token_ids[i])
+
+        # Only add input logprobs if there are input tokens to process
+        # Note: For prefill-only requests with default logprob_start_len, this will be 0,
+        # meaning we only compute output logprobs (which is the intended behavior)
+        if num_input_logprobs > 0:
+            self.add_input_logprob_return_values(
+                i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+            )
+        else:
+            self._initialize_empty_logprob_containers(req)
 
         if req.top_logprobs_num > 0:
             req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
             req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
 
-        if …
-        req. …
-            …
+        if (
+            req.token_ids_logprob is not None
+            and output.next_token_token_ids_logprobs_val is not None
+        ):
+            # Convert GPU tensor to list if needed
+            logprobs_val = output.next_token_token_ids_logprobs_val[i]
+            if isinstance(logprobs_val, torch.Tensor):
+                logprobs_val = logprobs_val.tolist()
+            req.output_token_ids_logprobs_val.append(logprobs_val)
             req.output_token_ids_logprobs_idx.append(
                 output.next_token_token_ids_logprobs_idx[i]
             )
 
         return num_input_logprobs
 
+    def _initialize_empty_logprob_containers(self, req: Req) -> None:
+        """
+        Initialize logprob fields to empty lists if unset.
+
+        This is needed for prefill-only requests where the normal initialization
+        flow might be bypassed, but downstream code expects these fields to be lists.
+        """
+        if req.input_token_logprobs_val is None:
+            req.input_token_logprobs_val = []
+        if req.input_token_logprobs_idx is None:
+            req.input_token_logprobs_idx = []
+        if req.input_top_logprobs_val is None:
+            req.input_top_logprobs_val = []
+        if req.input_top_logprobs_idx is None:
+            req.input_top_logprobs_idx = []
+        if req.input_token_ids_logprobs_val is None:
+            req.input_token_ids_logprobs_val = []
+        if req.input_token_ids_logprobs_idx is None:
+            req.input_token_ids_logprobs_idx = []
+
     def stream_output(
         self: Scheduler,
         reqs: List[Req],
@@ -673,8 +812,7 @@ class SchedulerOutputProcessorMixin:
             return
 
         self.send_to_detokenizer.send_pyobj(
-            …
-                rids,
+            BatchTokenIDOutput(
                 finished_reasons,
                 decoded_texts,
                 decode_ids_list,
@@ -700,6 +838,9 @@ class SchedulerOutputProcessorMixin:
                 output_token_ids_logprobs_val,
                 output_token_ids_logprobs_idx,
                 output_hidden_states,
+                rids=rids,
+                placeholder_tokens_idx=None,
+                placeholder_tokens_val=None,
             )
         )
 
@@ -718,7 +859,13 @@ class SchedulerOutputProcessorMixin:
             prompt_tokens.append(len(req.origin_input_ids))
             cached_tokens.append(req.cached_tokens)
         self.send_to_detokenizer.send_pyobj(
-            …
+            BatchEmbeddingOutput(
+                finished_reasons,
+                embeddings,
+                prompt_tokens,
+                cached_tokens,
+                rids=rids,
+                placeholder_tokens_idx=None,
+                placeholder_tokens_val=None,
             )
         )
sglang/srt/managers/scheduler_profiler_mixin.py

@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 
 class SchedulerProfilerMixin:
 
-    def …
+    def init_profiler(self):
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
         self.profiler_activities: Optional[List[str]] = None
@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
     def start_profile(
         self, stage: Optional[ForwardMode] = None
     ) -> ProfileReqOutput | None:
-        stage_str = f" for {stage.…
+        stage_str = f" for {stage.name}" if stage else ""
         logger.info(
             f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
         )
@@ -181,7 +181,7 @@ class SchedulerProfilerMixin:
         if not Path(self.torch_profiler_output_dir).exists():
             Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
 
-        stage_suffix = f"-{stage.…
+        stage_suffix = f"-{stage.name}" if stage else ""
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
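Both profiler hunks settle on `stage.name` when building the log string and trace-file suffix; `ForwardMode` is an enum, so `.name` yields a stable short label rather than the full `repr`. A stand-in sketch (the enum members here are illustrative, not sglang's actual definition):

```python
from enum import Enum, auto

class ForwardMode(Enum):  # stand-in for sglang's ForwardMode enum
    EXTEND = auto()
    DECODE = auto()

stage = ForwardMode.EXTEND
stage_str = f" for {stage.name}" if stage else ""
stage_suffix = f"-{stage.name}" if stage else ""
print(repr(stage_str), repr(stage_suffix))  # ' for EXTEND' '-EXTEND'
```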
@@ -204,7 +204,7 @@ class SchedulerProfilerMixin:
 
         torch.distributed.barrier(self.tp_cpu_group)
         if self.tp_rank == 0:
-            from sglang.srt.utils import rpd_to_chrome_trace
+            from sglang.srt.utils.rpd_utils import rpd_to_chrome_trace
 
             rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
             self.rpd_profiler = None
@@ -247,7 +247,7 @@ class SchedulerProfilerMixin:
         if self.profiler_decode_ct == 0:
             if self.profile_in_progress:
                 # force trace flush
-                self.stop_profile(ForwardMode.EXTEND)
+                self.stop_profile(stage=ForwardMode.EXTEND)
             self.start_profile(batch.forward_mode)
         self.profiler_decode_ct += 1
         if self.profiler_decode_ct > self.profiler_target_decode_ct:
@@ -294,6 +294,6 @@ class SchedulerProfilerMixin:
                 recv_req.profile_by_stage,
                 recv_req.profile_id,
             )
-            return self.start_profile(…
+            return self.start_profile()
         else:
            return self.stop_profile()
sglang/srt/managers/scheduler_update_weights_mixin.py

@@ -5,6 +5,8 @@ import torch
 
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
+    DestroyWeightsUpdateGroupReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
     InitWeightsUpdateGroupReqInput,
@@ -41,6 +43,11 @@ class SchedulerUpdateWeightsMixin:
         success, message = self.tp_worker.init_weights_update_group(recv_req)
         return InitWeightsUpdateGroupReqOutput(success, message)
 
+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        """Destroy the online model parameter update group."""
+        success, message = self.tp_worker.destroy_weights_update_group(recv_req)
+        return DestroyWeightsUpdateGroupReqOutput(success, message)
+
     def update_weights_from_distributed(
         self,
         recv_req: UpdateWeightsFromDistributedReqInput,
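The new `destroy_weights_update_group` endpoint is the teardown counterpart of `init_weights_update_group`. Such weight-update groups are typically backed by a `torch.distributed` process group, which must be explicitly destroyed before a fresh group can be created for a later weight sync. A hedged sketch of the underlying teardown, with an illustrative wrapper (not the sglang worker API):

```python
import torch.distributed as dist

def destroy_weights_update_group(group) -> tuple[bool, str]:
    """Illustrative teardown: release the communicator behind a weight-update group."""
    try:
        # destroy_process_group frees the NCCL/Gloo communicator resources.
        dist.destroy_process_group(group)
        return True, "Succeeded to destroy weights update group."
    except Exception as e:
        return False, f"Failed to destroy weights update group: {e}"
```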
|