sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -12,38 +12,44 @@
|
|
12
12
|
# limitations under the License.
|
13
13
|
# ==============================================================================
|
14
14
|
"""A tensor parallel worker."""
|
15
|
+
from __future__ import annotations
|
15
16
|
|
16
17
|
import logging
|
17
|
-
import
|
18
|
-
from typing import Optional, Tuple, Union
|
18
|
+
from typing import TYPE_CHECKING, Optional
|
19
19
|
|
20
20
|
import torch
|
21
21
|
|
22
22
|
from sglang.srt.configs.model_config import ModelConfig
|
23
23
|
from sglang.srt.distributed import get_pp_group, get_world_group
|
24
|
-
from sglang.srt.hf_transformers_utils import (
|
25
|
-
get_processor,
|
26
|
-
get_tokenizer,
|
27
|
-
get_tokenizer_from_processor,
|
28
|
-
)
|
29
|
-
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
30
24
|
from sglang.srt.managers.io_struct import (
|
25
|
+
DestroyWeightsUpdateGroupReqInput,
|
31
26
|
GetWeightsByNameReqInput,
|
27
|
+
InitWeightsSendGroupForRemoteInstanceReqInput,
|
32
28
|
InitWeightsUpdateGroupReqInput,
|
33
29
|
LoadLoRAAdapterReqInput,
|
30
|
+
SendWeightsToRemoteInstanceReqInput,
|
34
31
|
UnloadLoRAAdapterReqInput,
|
35
32
|
UpdateWeightFromDiskReqInput,
|
36
33
|
UpdateWeightsFromDistributedReqInput,
|
37
34
|
UpdateWeightsFromTensorReqInput,
|
38
35
|
)
|
39
36
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
37
|
+
from sglang.srt.managers.scheduler import GenerationBatchResult
|
40
38
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
41
39
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
42
40
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
43
41
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
44
|
-
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
45
42
|
from sglang.srt.server_args import ServerArgs
|
46
43
|
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
|
44
|
+
from sglang.srt.utils.hf_transformers_utils import (
|
45
|
+
get_processor,
|
46
|
+
get_tokenizer,
|
47
|
+
get_tokenizer_from_processor,
|
48
|
+
)
|
49
|
+
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
|
50
|
+
|
51
|
+
if TYPE_CHECKING:
|
52
|
+
from sglang.srt.managers.cache_controller import LayerDoneCounter
|
47
53
|
|
48
54
|
logger = logging.getLogger(__name__)
|
49
55
|
|
@@ -78,6 +84,11 @@ class TpModelWorker:
|
|
78
84
|
if not is_draft_worker
|
79
85
|
else server_args.speculative_draft_model_path
|
80
86
|
),
|
87
|
+
model_revision=(
|
88
|
+
server_args.revision
|
89
|
+
if not is_draft_worker
|
90
|
+
else server_args.speculative_draft_model_revision
|
91
|
+
),
|
81
92
|
is_draft_model=is_draft_worker,
|
82
93
|
)
|
83
94
|
|
@@ -137,8 +148,8 @@ class TpModelWorker:
|
|
137
148
|
assert self.max_running_requests > 0, "max_running_request is zero"
|
138
149
|
self.max_queued_requests = server_args.max_queued_requests
|
139
150
|
assert (
|
140
|
-
self.
|
141
|
-
), "
|
151
|
+
self.max_queued_requests is None or self.max_queued_requests >= 1
|
152
|
+
), "If configured, max_queued_requests must be at least 1 for any work to be scheduled."
|
142
153
|
self.max_req_len = min(
|
143
154
|
self.model_config.context_len - 1,
|
144
155
|
self.max_total_num_tokens - 1,
|
@@ -162,10 +173,10 @@ class TpModelWorker:
|
|
162
173
|
|
163
174
|
self.hicache_layer_transfer_counter = None
|
164
175
|
|
165
|
-
def register_hicache_layer_transfer_counter(self, counter):
|
176
|
+
def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
|
166
177
|
self.hicache_layer_transfer_counter = counter
|
167
178
|
|
168
|
-
def set_hicache_consumer(self, consumer_index):
|
179
|
+
def set_hicache_consumer(self, consumer_index: int):
|
169
180
|
if self.hicache_layer_transfer_counter is not None:
|
170
181
|
self.hicache_layer_transfer_counter.set_consumer(consumer_index)
|
171
182
|
|
@@ -220,11 +231,11 @@ class TpModelWorker:
|
|
220
231
|
def forward_batch_generation(
|
221
232
|
self,
|
222
233
|
model_worker_batch: ModelWorkerBatch,
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
234
|
+
is_verify: bool = False,
|
235
|
+
) -> GenerationBatchResult:
|
236
|
+
# update the consumer index of hicache to the running batch
|
237
|
+
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
|
238
|
+
|
228
239
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
229
240
|
|
230
241
|
pp_proxy_tensors = None
|
@@ -239,23 +250,51 @@ class TpModelWorker:
|
|
239
250
|
logits_output, can_run_cuda_graph = self.model_runner.forward(
|
240
251
|
forward_batch, pp_proxy_tensors=pp_proxy_tensors
|
241
252
|
)
|
242
|
-
|
243
|
-
|
253
|
+
batch_result = GenerationBatchResult(
|
254
|
+
logits_output=logits_output,
|
255
|
+
can_run_cuda_graph=can_run_cuda_graph,
|
256
|
+
)
|
244
257
|
|
245
|
-
if
|
246
|
-
|
258
|
+
if is_verify:
|
259
|
+
# Skip sampling and return logits for target forward
|
260
|
+
return batch_result
|
261
|
+
|
262
|
+
if model_worker_batch.delay_sample_launch:
|
263
|
+
batch_result.delay_sample_launch = True
|
264
|
+
batch_result.forward_batch = forward_batch
|
265
|
+
return batch_result
|
266
|
+
|
267
|
+
if model_worker_batch.is_prefill_only:
|
268
|
+
# For prefill-only requests, create dummy token IDs on CPU
|
269
|
+
# The size should match the batch size (number of sequences), not total tokens
|
270
|
+
batch_result.next_token_ids = torch.zeros(
|
271
|
+
len(model_worker_batch.seq_lens),
|
272
|
+
dtype=torch.long,
|
273
|
+
device=model_worker_batch.input_ids.device,
|
274
|
+
)
|
275
|
+
if (
|
276
|
+
model_worker_batch.return_logprob
|
277
|
+
and logits_output.next_token_logits is not None
|
278
|
+
):
|
279
|
+
# NOTE: Compute logprobs without full sampling
|
280
|
+
self.model_runner.compute_logprobs_only(
|
281
|
+
logits_output, model_worker_batch
|
282
|
+
)
|
247
283
|
else:
|
248
|
-
next_token_ids = self.model_runner.sample(
|
249
|
-
logits_output,
|
284
|
+
batch_result.next_token_ids = self.model_runner.sample(
|
285
|
+
logits_output, forward_batch
|
250
286
|
)
|
251
287
|
|
252
|
-
return
|
288
|
+
return batch_result
|
253
289
|
else:
|
254
290
|
pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
|
255
291
|
forward_batch,
|
256
292
|
pp_proxy_tensors=pp_proxy_tensors,
|
257
293
|
)
|
258
|
-
return
|
294
|
+
return GenerationBatchResult(
|
295
|
+
pp_hidden_states_proxy_tensors=pp_proxy_tensors,
|
296
|
+
can_run_cuda_graph=can_run_cuda_graph,
|
297
|
+
)
|
259
298
|
|
260
299
|
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
261
300
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
@@ -280,6 +319,37 @@ class TpModelWorker:
|
|
280
319
|
)
|
281
320
|
return success, message
|
282
321
|
|
322
|
+
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
|
323
|
+
success, message = self.model_runner.destroy_weights_update_group(
|
324
|
+
recv_req.group_name,
|
325
|
+
)
|
326
|
+
return success, message
|
327
|
+
|
328
|
+
def init_weights_send_group_for_remote_instance(
|
329
|
+
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
330
|
+
):
|
331
|
+
success, message = (
|
332
|
+
self.model_runner.init_weights_send_group_for_remote_instance(
|
333
|
+
recv_req.master_address,
|
334
|
+
recv_req.ports,
|
335
|
+
recv_req.group_rank,
|
336
|
+
recv_req.world_size,
|
337
|
+
recv_req.group_name,
|
338
|
+
recv_req.backend,
|
339
|
+
)
|
340
|
+
)
|
341
|
+
return success, message
|
342
|
+
|
343
|
+
def send_weights_to_remote_instance(
|
344
|
+
self, recv_req: SendWeightsToRemoteInstanceReqInput
|
345
|
+
):
|
346
|
+
success, message = self.model_runner.send_weights_to_remote_instance(
|
347
|
+
recv_req.master_address,
|
348
|
+
recv_req.ports,
|
349
|
+
recv_req.group_name,
|
350
|
+
)
|
351
|
+
return success, message
|
352
|
+
|
283
353
|
def update_weights_from_distributed(
|
284
354
|
self, recv_req: UpdateWeightsFromDistributedReqInput
|
285
355
|
):
|
sglang/srt/managers/utils.py
CHANGED
@@ -2,11 +2,10 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
import multiprocessing as mp
|
5
|
-
from http import HTTPStatus
|
6
5
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
7
6
|
|
8
7
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
9
|
-
from sglang.srt.managers.schedule_batch import
|
8
|
+
from sglang.srt.managers.schedule_batch import Req
|
10
9
|
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
|
11
10
|
|
12
11
|
if TYPE_CHECKING:
|
@@ -97,46 +96,3 @@ def get_logprob_from_pp_outputs(
|
|
97
96
|
]
|
98
97
|
|
99
98
|
return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
|
100
|
-
|
101
|
-
|
102
|
-
class DPBalanceMeta:
|
103
|
-
"""
|
104
|
-
This class will be use in scheduler and dp controller
|
105
|
-
"""
|
106
|
-
|
107
|
-
def __init__(self, num_workers: int):
|
108
|
-
self.num_workers = num_workers
|
109
|
-
self._manager = mp.Manager()
|
110
|
-
self.mutex = self._manager.Lock()
|
111
|
-
|
112
|
-
init_local_tokens = [0] * self.num_workers
|
113
|
-
init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
|
114
|
-
|
115
|
-
self.shared_state = self._manager.Namespace()
|
116
|
-
self.shared_state.local_tokens = self._manager.list(init_local_tokens)
|
117
|
-
self.shared_state.onfly_info = self._manager.list(init_onfly_info)
|
118
|
-
|
119
|
-
def destructor(self):
|
120
|
-
# we must destructor this class manually
|
121
|
-
self._manager.shutdown()
|
122
|
-
|
123
|
-
def get_shared_onfly(self) -> List[Dict[int, int]]:
|
124
|
-
return [dict(d) for d in self.shared_state.onfly_info]
|
125
|
-
|
126
|
-
def set_shared_onfly_info(self, data: List[Dict[int, int]]):
|
127
|
-
self.shared_state.onfly_info = data
|
128
|
-
|
129
|
-
def get_shared_local_tokens(self) -> List[int]:
|
130
|
-
return list(self.shared_state.local_tokens)
|
131
|
-
|
132
|
-
def set_shared_local_tokens(self, data: List[int]):
|
133
|
-
self.shared_state.local_tokens = data
|
134
|
-
|
135
|
-
def __getstate__(self):
|
136
|
-
state = self.__dict__.copy()
|
137
|
-
del state["_manager"]
|
138
|
-
return state
|
139
|
-
|
140
|
-
def __setstate__(self, state):
|
141
|
-
self.__dict__.update(state)
|
142
|
-
self._manager = None
|
@@ -27,7 +27,7 @@ import triton
|
|
27
27
|
import triton.language as tl
|
28
28
|
|
29
29
|
from sglang.srt.mem_cache.memory_pool import SWAKVPool
|
30
|
-
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
30
|
+
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
|
31
31
|
|
32
32
|
if TYPE_CHECKING:
|
33
33
|
from sglang.srt.mem_cache.memory_pool import KVCache
|
@@ -274,10 +274,15 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
274
274
|
self.full_to_swa_index_mapping[free_index] = 0
|
275
275
|
|
276
276
|
def backup_state(self):
|
277
|
-
|
277
|
+
return [
|
278
|
+
self.full_attn_allocator.backup_state(),
|
279
|
+
self.swa_attn_allocator.backup_state(),
|
280
|
+
]
|
278
281
|
|
279
282
|
def restore_state(self, state):
|
280
|
-
|
283
|
+
assert len(state) == 2
|
284
|
+
self.full_attn_allocator.restore_state(state[0])
|
285
|
+
self.swa_attn_allocator.restore_state(state[1])
|
281
286
|
|
282
287
|
def clear(self):
|
283
288
|
self.swa_attn_allocator.clear()
|
@@ -294,7 +299,6 @@ def alloc_extend_kernel(
|
|
294
299
|
last_loc_ptr,
|
295
300
|
free_page_ptr,
|
296
301
|
out_indices,
|
297
|
-
ret_values,
|
298
302
|
bs_upper: tl.constexpr,
|
299
303
|
page_size: tl.constexpr,
|
300
304
|
max_num_extend_tokens: tl.constexpr,
|
@@ -323,13 +327,6 @@ def alloc_extend_kernel(
|
|
323
327
|
sum_num_new_pages = tl.sum(num_new_pages)
|
324
328
|
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
|
325
329
|
|
326
|
-
# Return value
|
327
|
-
if pid == tl.num_programs(0) - 1:
|
328
|
-
merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
|
329
|
-
tl.int64
|
330
|
-
)
|
331
|
-
tl.store(ret_values, merged_value)
|
332
|
-
|
333
330
|
# Part 1: fill the old partial page
|
334
331
|
last_loc = tl.load(last_loc_ptr + pid)
|
335
332
|
num_part1 = (
|
@@ -381,7 +378,6 @@ def alloc_decode_kernel(
|
|
381
378
|
last_loc_ptr,
|
382
379
|
free_page_ptr,
|
383
380
|
out_indices,
|
384
|
-
ret_values,
|
385
381
|
bs_upper: tl.constexpr,
|
386
382
|
page_size: tl.constexpr,
|
387
383
|
):
|
@@ -404,10 +400,6 @@ def alloc_decode_kernel(
|
|
404
400
|
sum_num_new_pages = tl.sum(num_new_pages)
|
405
401
|
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
|
406
402
|
|
407
|
-
# Return value
|
408
|
-
if pid == tl.num_programs(0) - 1:
|
409
|
-
tl.store(ret_values, sum_num_new_pages)
|
410
|
-
|
411
403
|
if num_page_start_loc_self == 0:
|
412
404
|
last_loc = tl.load(last_loc_ptr + pid)
|
413
405
|
tl.store(out_indices + pid, last_loc + 1)
|
@@ -438,7 +430,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
438
430
|
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
439
431
|
self.num_pages = size // page_size
|
440
432
|
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
441
|
-
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
442
433
|
self.seen_max_num_extend_tokens_next_power_of_2 = 1
|
443
434
|
self.clear()
|
444
435
|
|
@@ -468,7 +459,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
468
459
|
def alloc_extend(
|
469
460
|
self,
|
470
461
|
prefix_lens: torch.Tensor,
|
462
|
+
prefix_lens_cpu: torch.Tensor,
|
471
463
|
seq_lens: torch.Tensor,
|
464
|
+
seq_lens_cpu: torch.Tensor,
|
472
465
|
last_loc: torch.Tensor,
|
473
466
|
extend_num_tokens: int,
|
474
467
|
):
|
@@ -497,7 +490,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
497
490
|
last_loc,
|
498
491
|
self.free_pages,
|
499
492
|
out_indices,
|
500
|
-
self.ret_values,
|
501
493
|
next_power_of_2(bs),
|
502
494
|
self.page_size,
|
503
495
|
self.seen_max_num_extend_tokens_next_power_of_2,
|
@@ -506,8 +498,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
506
498
|
if self.debug_mode:
|
507
499
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
508
500
|
|
509
|
-
|
510
|
-
|
501
|
+
num_new_pages = get_num_new_pages(
|
502
|
+
seq_lens=seq_lens_cpu,
|
503
|
+
page_size=self.page_size,
|
504
|
+
prefix_lens=prefix_lens_cpu,
|
505
|
+
)
|
511
506
|
if num_new_pages > len(self.free_pages):
|
512
507
|
return None
|
513
508
|
|
@@ -517,6 +512,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
517
512
|
def alloc_decode(
|
518
513
|
self,
|
519
514
|
seq_lens: torch.Tensor,
|
515
|
+
seq_lens_cpu: torch.Tensor,
|
520
516
|
last_loc: torch.Tensor,
|
521
517
|
):
|
522
518
|
if self.debug_mode:
|
@@ -534,7 +530,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
534
530
|
last_loc,
|
535
531
|
self.free_pages,
|
536
532
|
out_indices,
|
537
|
-
self.ret_values,
|
538
533
|
next_power_of_2(bs),
|
539
534
|
self.page_size,
|
540
535
|
)
|
@@ -542,7 +537,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
542
537
|
if self.debug_mode:
|
543
538
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
544
539
|
|
545
|
-
num_new_pages =
|
540
|
+
num_new_pages = get_num_new_pages(
|
541
|
+
seq_lens=seq_lens_cpu,
|
542
|
+
page_size=self.page_size,
|
543
|
+
decode=True,
|
544
|
+
)
|
546
545
|
if num_new_pages > len(self.free_pages):
|
547
546
|
return None
|
548
547
|
|
@@ -1,13 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import TYPE_CHECKING
|
4
|
-
|
5
3
|
import torch
|
6
4
|
|
7
5
|
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
|
8
|
-
|
9
|
-
if TYPE_CHECKING:
|
10
|
-
from sglang.srt.mem_cache.memory_pool import KVCache
|
6
|
+
from sglang.srt.utils import get_num_new_pages
|
11
7
|
|
12
8
|
|
13
9
|
def alloc_extend_kernel_ascend(
|
@@ -69,7 +65,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
69
65
|
def alloc_extend(
|
70
66
|
self,
|
71
67
|
prefix_lens: torch.Tensor,
|
68
|
+
prefix_lens_cpu: torch.Tensor,
|
72
69
|
seq_lens: torch.Tensor,
|
70
|
+
seq_lens_cpu: torch.Tensor,
|
73
71
|
last_loc: torch.Tensor,
|
74
72
|
extend_num_tokens: int,
|
75
73
|
):
|
@@ -79,42 +77,54 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
79
77
|
)
|
80
78
|
|
81
79
|
num_new_pages = (
|
82
|
-
(
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
.item()
|
88
|
-
)
|
89
|
-
if self.need_sort and num_new_pages > len(self.free_pages):
|
80
|
+
(seq_lens + self.page_size - 1) // self.page_size
|
81
|
+
- (prefix_lens + self.page_size - 1) // self.page_size
|
82
|
+
).sum()
|
83
|
+
num_new_pages_item = num_new_pages.item()
|
84
|
+
if self.need_sort and num_new_pages_item > len(self.free_pages):
|
90
85
|
self.merge_and_sort_free()
|
91
86
|
|
92
|
-
if
|
87
|
+
if num_new_pages_item > len(self.free_pages):
|
93
88
|
return None
|
94
89
|
|
95
90
|
out_indices = torch.empty(
|
96
|
-
(extend_num_tokens,), dtype=torch.
|
91
|
+
(extend_num_tokens,), dtype=torch.int64, device=self.device
|
97
92
|
)
|
98
93
|
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
94
|
+
if num_new_pages_item < 200:
|
95
|
+
import sgl_kernel_npu
|
96
|
+
|
97
|
+
torch.ops.npu.alloc_extend(
|
98
|
+
prefix_lens,
|
99
|
+
seq_lens,
|
100
|
+
last_loc,
|
101
|
+
self.free_pages,
|
102
|
+
self.page_size,
|
103
|
+
out_indices,
|
104
|
+
num_new_pages,
|
105
|
+
)
|
106
|
+
|
107
|
+
else:
|
108
|
+
alloc_extend_kernel_ascend(
|
109
|
+
prefix_lens,
|
110
|
+
seq_lens,
|
111
|
+
last_loc,
|
112
|
+
self.free_pages,
|
113
|
+
out_indices,
|
114
|
+
self.page_size,
|
115
|
+
self.device,
|
116
|
+
)
|
108
117
|
|
109
118
|
if self.debug_mode:
|
110
119
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
111
120
|
|
112
|
-
self.free_pages = self.free_pages[
|
121
|
+
self.free_pages = self.free_pages[num_new_pages_item:]
|
113
122
|
return out_indices
|
114
123
|
|
115
124
|
def alloc_decode(
|
116
125
|
self,
|
117
126
|
seq_lens: torch.Tensor,
|
127
|
+
seq_lens_cpu: torch.Tensor,
|
118
128
|
last_loc: torch.Tensor,
|
119
129
|
):
|
120
130
|
if self.debug_mode:
|
@@ -122,8 +132,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
122
132
|
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
123
133
|
)
|
124
134
|
|
125
|
-
|
126
|
-
|
135
|
+
num_new_pages = get_num_new_pages(
|
136
|
+
seq_lens=seq_lens_cpu,
|
137
|
+
page_size=self.page_size,
|
138
|
+
decode=True,
|
139
|
+
)
|
127
140
|
|
128
141
|
if num_new_pages > len(self.free_pages):
|
129
142
|
self.merge_and_sort_free()
|
@@ -131,6 +144,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
131
144
|
if num_new_pages > len(self.free_pages):
|
132
145
|
return None
|
133
146
|
|
147
|
+
need_new_pages = (seq_lens % self.page_size == 1).int()
|
134
148
|
end_new_pages = torch.cumsum(need_new_pages, 0)
|
135
149
|
start_new_pages = end_new_pages - need_new_pages
|
136
150
|
if num_new_pages == 0:
|
@@ -28,6 +28,13 @@ class ChunkCache(BasePrefixCache):
|
|
28
28
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
29
29
|
self.page_size = page_size
|
30
30
|
|
31
|
+
# NOTE (csy): this is to determine if a cache has prefix matching feature.
|
32
|
+
# Chunk cache always return True to indicate no prefix matching.
|
33
|
+
# TODO (csy): Using a prefix cache trait to replace this
|
34
|
+
@property
|
35
|
+
def disable(self):
|
36
|
+
return True
|
37
|
+
|
31
38
|
def reset(self):
|
32
39
|
pass
|
33
40
|
|
@@ -38,7 +45,7 @@ class ChunkCache(BasePrefixCache):
|
|
38
45
|
last_host_node=None,
|
39
46
|
)
|
40
47
|
|
41
|
-
def cache_finished_req(self, req: Req):
|
48
|
+
def cache_finished_req(self, req: Req, insert: bool = True):
|
42
49
|
kv_indices = self.req_to_token_pool.req_to_token[
|
43
50
|
req.req_pool_idx,
|
44
51
|
# For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
|
@@ -53,7 +60,7 @@ class ChunkCache(BasePrefixCache):
|
|
53
60
|
]
|
54
61
|
|
55
62
|
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
56
|
-
req.prefix_indices = kv_indices
|
63
|
+
req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
57
64
|
|
58
65
|
def evict(self, num_tokens: int):
|
59
66
|
pass
|
@@ -0,0 +1,23 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from typing import TYPE_CHECKING, List, Tuple, Union
|
5
|
+
|
6
|
+
if TYPE_CHECKING:
|
7
|
+
from sglang.srt.mem_cache.radix_cache import TreeNode
|
8
|
+
|
9
|
+
|
10
|
+
class EvictionStrategy(ABC):
|
11
|
+
@abstractmethod
|
12
|
+
def get_priority(self, node: "TreeNode") -> Union[float, Tuple]:
|
13
|
+
pass
|
14
|
+
|
15
|
+
|
16
|
+
class LRUStrategy(EvictionStrategy):
|
17
|
+
def get_priority(self, node: "TreeNode") -> float:
|
18
|
+
return node.last_access_time
|
19
|
+
|
20
|
+
|
21
|
+
class LFUStrategy(EvictionStrategy):
|
22
|
+
def get_priority(self, node: "TreeNode") -> Tuple[int, float]:
|
23
|
+
return (node.hit_count, node.last_access_time)
|