sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -12,25 +12,24 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations

 import logging
 import threading
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import torch

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_pp_group, get_world_group
-from sglang.srt.hf_transformers_utils import (
-    get_processor,
-    get_tokenizer,
-    get_tokenizer_from_processor,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
+    InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
+    SendWeightsToRemoteInstanceReqInput,
     UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
@@ -39,11 +38,23 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardBatchOutput,
+    PPProxyTensors,
+)
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
+from sglang.srt.utils.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
+from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter

 logger = logging.getLogger(__name__)

@@ -78,6 +89,11 @@ class TpModelWorker:
                 if not is_draft_worker
                 else server_args.speculative_draft_model_path
             ),
+            model_revision=(
+                server_args.revision
+                if not is_draft_worker
+                else server_args.speculative_draft_model_revision
+            ),
             is_draft_model=is_draft_worker,
         )

@@ -137,8 +153,8 @@ class TpModelWorker:
         assert self.max_running_requests > 0, "max_running_request is zero"
         self.max_queued_requests = server_args.max_queued_requests
         assert (
-            self.
-        ), "
+            self.max_queued_requests is None or self.max_queued_requests >= 1
+        ), "If configured, max_queued_requests must be at least 1 for any work to be scheduled."
         self.max_req_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
@@ -162,10 +178,10 @@ class TpModelWorker:

         self.hicache_layer_transfer_counter = None

-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter

-    def set_hicache_consumer(self, consumer_index):
+    def set_hicache_consumer(self, consumer_index: int):
         if self.hicache_layer_transfer_counter is not None:
             self.hicache_layer_transfer_counter.set_consumer(consumer_index)

@@ -221,10 +237,11 @@ class TpModelWorker:
         self,
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
-        skip_sample: bool = False,
-    ) -> Tuple[
-        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
-    ]:
+        is_verify: bool = False,
+    ) -> ForwardBatchOutput:
+        # update the consumer index of hicache to the running batch
+        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

         pp_proxy_tensors = None
@@ -242,20 +259,31 @@
             if launch_done is not None:
                 launch_done.set()

-            if skip_sample:
-                next_token_ids = None
-            else:
-                next_token_ids = self.model_runner.sample(
+            skip_sample = is_verify or model_worker_batch.is_prefill_only
+            next_token_ids = None
+
+            if not skip_sample:
+                next_token_ids = self.model_runner.sample(logits_output, forward_batch)
+            elif model_worker_batch.return_logprob and not is_verify:
+                # NOTE: Compute logprobs without full sampling
+                self.model_runner.compute_logprobs_only(
                     logits_output, model_worker_batch
                 )

-            return logits_output, next_token_ids, can_run_cuda_graph
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
         else:
             pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return
+            return ForwardBatchOutput(
+                pp_proxy_tensors=pp_proxy_tensors,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
@@ -280,6 +308,37 @@ class TpModelWorker:
         )
         return success, message

+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.destroy_weights_update_group(
+            recv_req.group_name,
+        )
+        return success, message
+
+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        success, message = (
+            self.model_runner.init_weights_send_group_for_remote_instance(
+                recv_req.master_address,
+                recv_req.ports,
+                recv_req.group_rank,
+                recv_req.world_size,
+                recv_req.group_name,
+                recv_req.backend,
+            )
+        )
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        success, message = self.model_runner.send_weights_to_remote_instance(
+            recv_req.master_address,
+            recv_req.ports,
+            recv_req.group_name,
+        )
+        return success, message
+
     def update_weights_from_distributed(
         self, recv_req: UpdateWeightsFromDistributedReqInput
     ):
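Note on the pattern above: forward_batch_generation previously returned a bare (logits_output, next_token_ids, can_run_cuda_graph) tuple; in 0.5.3 both the last-stage path and the pipeline-parallel proxy path return a single ForwardBatchOutput object. The real definition lives in sglang/srt/model_executor/forward_batch_info.py (changed in this release, +82 -40); the sketch below is inferred only from the fields this diff constructs and is not the actual class body:

    # Inferred sketch of the ForwardBatchOutput container; the real class is in
    # sglang/srt/model_executor/forward_batch_info.py and may carry more fields.
    # Field names come from the diff above; the defaults are assumptions.
    from dataclasses import dataclass
    from typing import Any, Optional

    import torch

    @dataclass
    class ForwardBatchOutput:
        logits_output: Optional[Any] = None            # last pipeline stage only
        next_token_ids: Optional[torch.Tensor] = None  # None when sampling is skipped
        pp_proxy_tensors: Optional[Any] = None         # intermediate pp stages only
        can_run_cuda_graph: bool = False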
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -12,42 +12,42 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations

 import dataclasses
 import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 import psutil
 import torch

 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
+    InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
+    SendWeightsToRemoteInstanceReqInput,
     UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
+from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import DynamicGradMode
+from sglang.srt.utils import DynamicGradMode
 from sglang.utils import get_exception_traceback

-
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter

-
-def resolve_future_token_ids(input_ids, future_token_ids_map):
-    input_ids[:] = torch.where(
-        input_ids < 0,
-        future_token_ids_map[torch.clamp(-input_ids, min=0)],
-        input_ids,
-    )
+logger = logging.getLogger(__name__)


 class TpModelWorkerClient:
@@ -72,14 +72,10 @@ class TpModelWorkerClient:
         self.gpu_id = gpu_id

         # Init future mappings
-        self.future_token_ids_ct = 0
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_map = torch.empty(
-            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
-        )
+        self.future_map = FutureMap(self.max_running_requests, self.device)

         # Launch threads
-        self.input_queue = Queue()
+        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
         self.output_queue = Queue()
         self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
@@ -93,13 +89,9 @@ class TpModelWorkerClient:

         self.hicache_layer_transfer_counter = None

-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter

-    def set_hicache_consumer(self, consumer_index):
-        if self.hicache_layer_transfer_counter is not None:
-            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
-
     def get_worker_info(self):
         return self.worker.get_worker_info()

@@ -147,10 +139,10 @@ class TpModelWorkerClient:
     @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
-        batch_lists = [None] * 2
+        batch_lists: List = [None] * 2

         while True:
-            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
+            model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
             if not model_worker_batch:
                 break

@@ -166,29 +158,35 @@ class TpModelWorkerClient:
             copy_done = torch.get_device_module(self.device).Event()

             # Resolve future tokens in the input
-            input_ids = model_worker_batch.input_ids
-            resolve_future_token_ids(input_ids, self.future_token_ids_map)
+            self.future_map.resolve_future(model_worker_batch)

-            # update the consumer index of hicache to the running batch
-            self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
             # Run forward
+            forward_batch_output = self.worker.forward_batch_generation(
+                model_worker_batch,
+                model_worker_batch.launch_done,
+            )
+
             logits_output, next_token_ids, can_run_cuda_graph = (
-                self.worker.forward_batch_generation(
-                    model_worker_batch, model_worker_batch.launch_done
-                )
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
             )

             # Update the future token ids map
             bs = len(model_worker_batch.seq_lens)
-            self.future_token_ids_map[
-                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
-            ] = next_token_ids
+            if model_worker_batch.is_prefill_only:
+                # For prefill-only requests, create dummy token IDs on CPU
+                next_token_ids = torch.zeros(bs, dtype=torch.long)
+
+            # store the future indices into future map
+            self.future_map.store_to_map(future_map_ct, bs, next_token_ids)

             # Copy results to the CPU
             if model_worker_batch.return_logprob:
-                logits_output.next_token_logprobs = (
-                    logits_output.next_token_logprobs.to("cpu", non_blocking=True)
-                )
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs.to("cpu", non_blocking=True)
+                    )
                 if logits_output.input_token_logprobs is not None:
                     logits_output.input_token_logprobs = (
                         logits_output.input_token_logprobs.to("cpu", non_blocking=True)
@@ -197,7 +195,9 @@ class TpModelWorkerClient:
                 logits_output.hidden_states = logits_output.hidden_states.to(
                     "cpu", non_blocking=True
                 )
-            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
+            # Only copy to CPU if not already on CPU
+            if next_token_ids.device.type != "cpu":
+                next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()

             self.output_queue.put(
@@ -221,16 +221,16 @@ class TpModelWorkerClient:
                 logits_output.next_token_logprobs = (
                     logits_output.next_token_logprobs.tolist()
                 )
-
-
-
-
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = tuple(
+                    logits_output.input_token_logprobs.tolist()
+                )
             next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids, can_run_cuda_graph

     def forward_batch_generation(
         self, model_worker_batch: ModelWorkerBatch
-    ) ->
+    ) -> ForwardBatchOutput:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -245,21 +245,18 @@ class TpModelWorkerClient:
         sync_event.record(self.scheduler_stream)

         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
-
-        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = torch.arange(
-            -(self.future_token_ids_ct + 1),
-            -(self.future_token_ids_ct + 1 + bs),
-            -1,
-            dtype=torch.int64,
-            device=self.device,
+        cur_future_map_ct = self.future_map.update_ct(bs)
+        self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
+
+        # get this forward batch's future token ids
+        future_next_token_ids = self.future_map.update_next_future(
+            cur_future_map_ct, bs
+        )
+        return ForwardBatchOutput(
+            next_token_ids=future_next_token_ids,
+            can_run_cuda_graph=False,
         )
-        self.future_token_ids_ct = (
-            self.future_token_ids_ct + bs
-        ) % self.future_token_ids_limit
-        return None, future_next_token_ids, False

     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
@@ -269,6 +266,24 @@ class TpModelWorkerClient:
         success, message = self.worker.init_weights_update_group(recv_req)
         return success, message

+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        success, message = self.worker.destroy_weights_update_group(recv_req)
+        return success, message
+
+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        success, message = self.worker.init_weights_send_group_for_remote_instance(
+            recv_req
+        )
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        success, message = self.worker.send_weights_to_remote_instance(recv_req)
+        return success, message
+
     def update_weights_from_distributed(
         self, recv_req: UpdateWeightsFromDistributedReqInput
     ):
sglang/srt/managers/utils.py
CHANGED
@@ -2,11 +2,10 @@ from __future__ import annotations

 import logging
 import multiprocessing as mp
-from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, List, Optional

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

 if TYPE_CHECKING:
@@ -97,46 +96,3 @@ def get_logprob_from_pp_outputs(
     ]

     return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
-
-
-class DPBalanceMeta:
-    """
-    This class will be use in scheduler and dp controller
-    """
-
-    def __init__(self, num_workers: int):
-        self.num_workers = num_workers
-        self._manager = mp.Manager()
-        self.mutex = self._manager.Lock()
-
-        init_local_tokens = [0] * self.num_workers
-        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
-
-        self.shared_state = self._manager.Namespace()
-        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
-        self.shared_state.onfly_info = self._manager.list(init_onfly_info)
-
-    def destructor(self):
-        # we must destructor this class manually
-        self._manager.shutdown()
-
-    def get_shared_onfly(self) -> List[Dict[int, int]]:
-        return [dict(d) for d in self.shared_state.onfly_info]
-
-    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
-        self.shared_state.onfly_info = data
-
-    def get_shared_local_tokens(self) -> List[int]:
-        return list(self.shared_state.local_tokens)
-
-    def set_shared_local_tokens(self, data: List[int]):
-        self.shared_state.local_tokens = data
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        del state["_manager"]
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-        self._manager = None
sglang/srt/mem_cache/allocator.py
CHANGED
@@ -27,7 +27,7 @@ import triton
 import triton.language as tl

 from sglang.srt.mem_cache.memory_pool import SWAKVPool
-from sglang.srt.utils import get_bool_env_var, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2

 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import KVCache
@@ -294,7 +294,6 @@ def alloc_extend_kernel(
     last_loc_ptr,
     free_page_ptr,
     out_indices,
-    ret_values,
     bs_upper: tl.constexpr,
     page_size: tl.constexpr,
     max_num_extend_tokens: tl.constexpr,
@@ -323,13 +322,6 @@ def alloc_extend_kernel(
     sum_num_new_pages = tl.sum(num_new_pages)
     new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

-    # Return value
-    if pid == tl.num_programs(0) - 1:
-        merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
-            tl.int64
-        )
-        tl.store(ret_values, merged_value)
-
     # Part 1: fill the old partial page
     last_loc = tl.load(last_loc_ptr + pid)
     num_part1 = (
@@ -381,7 +373,6 @@ def alloc_decode_kernel(
     last_loc_ptr,
     free_page_ptr,
     out_indices,
-    ret_values,
     bs_upper: tl.constexpr,
     page_size: tl.constexpr,
 ):
@@ -404,10 +395,6 @@ def alloc_decode_kernel(
     sum_num_new_pages = tl.sum(num_new_pages)
     new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

-    # Return value
-    if pid == tl.num_programs(0) - 1:
-        tl.store(ret_values, sum_num_new_pages)
-
     if num_page_start_loc_self == 0:
         last_loc = tl.load(last_loc_ptr + pid)
         tl.store(out_indices + pid, last_loc + 1)
@@ -438,7 +425,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
-        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
         self.seen_max_num_extend_tokens_next_power_of_2 = 1
         self.clear()

@@ -468,7 +454,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
+        prefix_lens_cpu: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
     ):
@@ -497,7 +485,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             last_loc,
             self.free_pages,
             out_indices,
-            self.ret_values,
             next_power_of_2(bs),
             self.page_size,
             self.seen_max_num_extend_tokens_next_power_of_2,
@@ -506,8 +493,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        merged_value = self.ret_values.item()
-        num_new_pages = merged_value >> 32
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            prefix_lens=prefix_lens_cpu,
+        )
         if num_new_pages > len(self.free_pages):
             return None

@@ -517,6 +507,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     def alloc_decode(
         self,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
     ):
         if self.debug_mode:
@@ -534,7 +525,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             last_loc,
             self.free_pages,
             out_indices,
-            self.ret_values,
             next_power_of_2(bs),
             self.page_size,
         )
@@ -542,7 +532,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        num_new_pages = self.ret_values.item()
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            decode=True,
+        )
         if num_new_pages > len(self.free_pages):
             return None
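Note on the allocator change above: it removes a device-to-host synchronization. The Triton kernels used to pack their page counts into the ret_values device tensor, which the host then read back; now the host receives CPU copies of the length tensors (prefix_lens_cpu, seq_lens_cpu) and computes the count directly with get_num_new_pages, newly imported from sglang.srt.utils. The arithmetic is plain page rounding: with page_size=4, extending a request from a 5-token prefix (2 pages) to 9 tokens (3 pages) needs 1 new page. Below is a sketch of what the helper plausibly computes, with the signature inferred from the two call sites; the real helper may differ:

    import torch

    def get_num_new_pages_sketch(
        seq_lens: torch.Tensor,            # per-request target lengths (CPU tensor)
        page_size: int,
        prefix_lens: torch.Tensor = None,  # per-request cached prefix lengths (extend only)
        decode: bool = False,
    ) -> int:
        # Hypothetical re-derivation of sglang.srt.utils.get_num_new_pages; the
        # real helper may differ in signature and edge-case handling.
        if decode:
            # One token is appended per request; a fresh page is needed exactly
            # when the new token lands on the first slot of a page. Assumes
            # seq_lens already counts the token being decoded this step.
            return int(((seq_lens - 1) % page_size == 0).sum().item())
        # Extend: pages covering seq_len minus pages already covering the prefix.
        pages_after = (seq_lens + page_size - 1) // page_size
        pages_before = (prefix_lens + page_size - 1) // page_size
        return int((pages_after - pages_before).sum().item())

    # page_size=4: prefix 5 tokens -> 2 pages, target 9 tokens -> 3 pages => 1 new page
    assert get_num_new_pages_sketch(torch.tensor([9]), 4, torch.tensor([5])) == 1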
|