sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
|
|
21
21
|
from __future__ import annotations
|
22
22
|
|
23
23
|
import logging
|
24
|
+
import time
|
24
25
|
from collections import deque
|
25
26
|
from dataclasses import dataclass
|
26
27
|
from http import HTTPStatus
|
@@ -45,7 +46,7 @@ from sglang.srt.disaggregation.utils import (
|
|
45
46
|
prepare_abort,
|
46
47
|
)
|
47
48
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
48
|
-
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
49
|
+
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
|
49
50
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
50
51
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
51
52
|
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
@@ -253,6 +254,7 @@ class DecodePreallocQueue:
|
|
253
254
|
prefill_dp_rank=req.data_parallel_rank,
|
254
255
|
)
|
255
256
|
|
257
|
+
req.add_latency(RequestStage.DECODE_PREPARE)
|
256
258
|
self.queue.append(
|
257
259
|
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
|
258
260
|
)
|
@@ -421,8 +423,13 @@ class DecodePreallocQueue:
|
|
421
423
|
kv_indices, self.token_to_kv_pool_allocator.page_size
|
422
424
|
)
|
423
425
|
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
|
426
|
+
|
424
427
|
preallocated_reqs.append(decode_req)
|
425
428
|
indices_to_remove.add(i)
|
429
|
+
decode_req.req.time_stats.decode_transfer_queue_entry_time = (
|
430
|
+
time.perf_counter()
|
431
|
+
)
|
432
|
+
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
|
426
433
|
|
427
434
|
self.queue = [
|
428
435
|
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
@@ -516,11 +523,19 @@ class DecodePreallocQueue:
|
|
516
523
|
dtype=torch.int64,
|
517
524
|
device=self.token_to_kv_pool_allocator.device,
|
518
525
|
),
|
526
|
+
prefix_lens_cpu=torch.tensor(
|
527
|
+
[0],
|
528
|
+
dtype=torch.int64,
|
529
|
+
),
|
519
530
|
seq_lens=torch.tensor(
|
520
531
|
[num_tokens],
|
521
532
|
dtype=torch.int64,
|
522
533
|
device=self.token_to_kv_pool_allocator.device,
|
523
534
|
),
|
535
|
+
seq_lens_cpu=torch.tensor(
|
536
|
+
[num_tokens],
|
537
|
+
dtype=torch.int64,
|
538
|
+
),
|
524
539
|
last_loc=torch.tensor(
|
525
540
|
[-1],
|
526
541
|
dtype=torch.int64,
|
@@ -607,16 +622,23 @@ class DecodeTransferQueue:
|
|
607
622
|
idx = decode_req.metadata_buffer_index
|
608
623
|
(
|
609
624
|
output_id,
|
625
|
+
cached_tokens,
|
610
626
|
output_token_logprobs_val,
|
611
627
|
output_token_logprobs_idx,
|
612
628
|
output_top_logprobs_val,
|
613
629
|
output_top_logprobs_idx,
|
630
|
+
output_topk_p,
|
631
|
+
output_topk_index,
|
614
632
|
output_hidden_states,
|
615
633
|
) = self.metadata_buffers.get_buf(idx)
|
616
634
|
|
617
635
|
decode_req.req.output_ids.append(output_id[0].item())
|
636
|
+
decode_req.req.cached_tokens = cached_tokens[0].item()
|
618
637
|
if not self.spec_algorithm.is_none():
|
638
|
+
decode_req.req.output_topk_p = output_topk_p
|
639
|
+
decode_req.req.output_topk_index = output_topk_index
|
619
640
|
decode_req.req.hidden_states_tensor = output_hidden_states
|
641
|
+
|
620
642
|
if decode_req.req.return_logprob:
|
621
643
|
decode_req.req.output_token_logprobs_val.append(
|
622
644
|
output_token_logprobs_val[0].item()
|
@@ -637,10 +659,17 @@ class DecodeTransferQueue:
|
|
637
659
|
|
638
660
|
if hasattr(decode_req.kv_receiver, "clear"):
|
639
661
|
decode_req.kv_receiver.clear()
|
662
|
+
decode_req.kv_receiver = None
|
663
|
+
|
664
|
+
indices_to_remove.add(i)
|
665
|
+
decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
|
640
666
|
|
641
667
|
# special handling for sampling_params.max_new_tokens == 1
|
642
668
|
if decode_req.req.sampling_params.max_new_tokens == 1:
|
643
669
|
# finish immediately
|
670
|
+
decode_req.req.time_stats.forward_entry_time = (
|
671
|
+
decode_req.req.time_stats.completion_time
|
672
|
+
) = time.perf_counter()
|
644
673
|
decode_req.req.check_finished()
|
645
674
|
self.scheduler.stream_output(
|
646
675
|
[decode_req.req], decode_req.req.return_logprob
|
@@ -648,8 +677,6 @@ class DecodeTransferQueue:
|
|
648
677
|
self.tree_cache.cache_finished_req(decode_req.req)
|
649
678
|
else:
|
650
679
|
transferred_reqs.append(decode_req.req)
|
651
|
-
|
652
|
-
indices_to_remove.add(i)
|
653
680
|
elif poll in [
|
654
681
|
KVPoll.Bootstrapping,
|
655
682
|
KVPoll.WaitingForInput,
|
@@ -662,6 +689,7 @@ class DecodeTransferQueue:
|
|
662
689
|
for i in indices_to_remove:
|
663
690
|
idx = self.queue[i].metadata_buffer_index
|
664
691
|
assert idx != -1
|
692
|
+
self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
|
665
693
|
self.req_to_metadata_buffer_idx_allocator.free(idx)
|
666
694
|
|
667
695
|
self.queue = [
|
@@ -704,12 +732,15 @@ class SchedulerDisaggregationDecodeMixin:
|
|
704
732
|
elif prepare_mlp_sync_flag:
|
705
733
|
batch, _ = self._prepare_idle_batch_and_run(None)
|
706
734
|
|
707
|
-
|
735
|
+
queue_size = (
|
708
736
|
len(self.waiting_queue)
|
709
737
|
+ len(self.disagg_decode_transfer_queue.queue)
|
710
738
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
711
|
-
|
712
|
-
|
739
|
+
)
|
740
|
+
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
741
|
+
queue_size += len(self.decode_offload_manager.ongoing_offload)
|
742
|
+
|
743
|
+
if batch is None and queue_size == 0:
|
713
744
|
self.self_check_during_idle()
|
714
745
|
|
715
746
|
self.last_batch = batch
|
@@ -778,12 +809,15 @@ class SchedulerDisaggregationDecodeMixin:
|
|
778
809
|
)
|
779
810
|
self.process_batch_result(tmp_batch, tmp_result)
|
780
811
|
|
781
|
-
|
812
|
+
queue_size = (
|
782
813
|
len(self.waiting_queue)
|
783
814
|
+ len(self.disagg_decode_transfer_queue.queue)
|
784
815
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
785
|
-
|
786
|
-
|
816
|
+
)
|
817
|
+
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
818
|
+
queue_size += len(self.decode_offload_manager.ongoing_offload)
|
819
|
+
|
820
|
+
if batch is None and queue_size == 0:
|
787
821
|
self.self_check_during_idle()
|
788
822
|
|
789
823
|
self.last_batch = batch
|
@@ -853,6 +887,7 @@ class SchedulerDisaggregationDecodeMixin:
|
|
853
887
|
# we can only add at least `num_not_used_batch` new batch to the running queue
|
854
888
|
if i < num_not_used_batch:
|
855
889
|
can_run_list.append(req)
|
890
|
+
req.add_latency(RequestStage.DECODE_WAITING)
|
856
891
|
req.init_next_round_input(self.tree_cache)
|
857
892
|
else:
|
858
893
|
waiting_queue.append(req)
|
@@ -861,6 +896,9 @@ class SchedulerDisaggregationDecodeMixin:
|
|
861
896
|
if len(can_run_list) == 0:
|
862
897
|
return None
|
863
898
|
|
899
|
+
for req in can_run_list:
|
900
|
+
req.time_stats.forward_entry_time = time.perf_counter()
|
901
|
+
|
864
902
|
# construct a schedule batch with those requests and mark as decode
|
865
903
|
new_batch = ScheduleBatch.init_new(
|
866
904
|
can_run_list,
|
@@ -901,3 +939,6 @@ class SchedulerDisaggregationDecodeMixin:
|
|
901
939
|
self.disagg_decode_transfer_queue.pop_transferred()
|
902
940
|
) # the requests which kv has arrived
|
903
941
|
self.waiting_queue.extend(alloc_reqs)
|
942
|
+
|
943
|
+
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
944
|
+
self.decode_offload_manager.check_offload_progress()
|
@@ -0,0 +1,185 @@
|
|
1
|
+
import logging
|
2
|
+
import threading
|
3
|
+
import time
|
4
|
+
|
5
|
+
import torch
|
6
|
+
|
7
|
+
from sglang import ServerArgs
|
8
|
+
from sglang.srt.managers.cache_controller import HiCacheController
|
9
|
+
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
10
|
+
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
11
|
+
from sglang.srt.mem_cache.memory_pool import (
|
12
|
+
MHATokenToKVPool,
|
13
|
+
MLATokenToKVPool,
|
14
|
+
ReqToTokenPool,
|
15
|
+
)
|
16
|
+
from sglang.srt.mem_cache.memory_pool_host import (
|
17
|
+
MHATokenToKVPoolHost,
|
18
|
+
MLATokenToKVPoolHost,
|
19
|
+
)
|
20
|
+
|
21
|
+
logger = logging.getLogger(__name__)
|
22
|
+
|
23
|
+
|
24
|
+
class DecodeKVCacheOffloadManager:
|
25
|
+
"""Manage decode-side KV cache offloading lifecycle and operations."""
|
26
|
+
|
27
|
+
def __init__(
|
28
|
+
self,
|
29
|
+
req_to_token_pool: ReqToTokenPool,
|
30
|
+
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
31
|
+
tp_group: torch.distributed.ProcessGroup,
|
32
|
+
tree_cache: BasePrefixCache,
|
33
|
+
server_args: ServerArgs,
|
34
|
+
) -> None:
|
35
|
+
self.req_to_token_pool = req_to_token_pool
|
36
|
+
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
37
|
+
self.page_size = server_args.page_size
|
38
|
+
self.server_args = server_args
|
39
|
+
self.request_counter = 0
|
40
|
+
self.tree_cache = tree_cache
|
41
|
+
kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
|
42
|
+
if isinstance(kv_cache, MHATokenToKVPool):
|
43
|
+
self.decode_host_mem_pool = MHATokenToKVPoolHost(
|
44
|
+
kv_cache,
|
45
|
+
server_args.hicache_ratio,
|
46
|
+
server_args.hicache_size,
|
47
|
+
self.page_size,
|
48
|
+
server_args.hicache_mem_layout,
|
49
|
+
)
|
50
|
+
elif isinstance(kv_cache, MLATokenToKVPool):
|
51
|
+
self.decode_host_mem_pool = MLATokenToKVPoolHost(
|
52
|
+
kv_cache,
|
53
|
+
server_args.hicache_ratio,
|
54
|
+
server_args.hicache_size,
|
55
|
+
self.page_size,
|
56
|
+
server_args.hicache_mem_layout,
|
57
|
+
)
|
58
|
+
else:
|
59
|
+
raise ValueError("Unsupported KV cache type for decode offload")
|
60
|
+
|
61
|
+
self.tp_group = tp_group
|
62
|
+
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
63
|
+
self.cache_controller = HiCacheController(
|
64
|
+
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
65
|
+
mem_pool_host=self.decode_host_mem_pool,
|
66
|
+
page_size=self.page_size,
|
67
|
+
tp_group=tp_group,
|
68
|
+
io_backend=server_args.hicache_io_backend,
|
69
|
+
load_cache_event=threading.Event(),
|
70
|
+
storage_backend=server_args.hicache_storage_backend,
|
71
|
+
model_name=server_args.served_model_name,
|
72
|
+
storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
|
73
|
+
)
|
74
|
+
|
75
|
+
self.ongoing_offload = {}
|
76
|
+
self.ongoing_backup = {}
|
77
|
+
logger.info("Enable offload kv cache for decode side")
|
78
|
+
|
79
|
+
def offload_kv_cache(self, req) -> bool:
|
80
|
+
"""Offload a finished request's KV cache to storage."""
|
81
|
+
|
82
|
+
if self.cache_controller is None or self.decode_host_mem_pool is None:
|
83
|
+
return False
|
84
|
+
|
85
|
+
if req.req_pool_idx == -1:
|
86
|
+
return False
|
87
|
+
|
88
|
+
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
|
89
|
+
if token_indices.dim() == 0 or token_indices.numel() == 0:
|
90
|
+
logger.debug(
|
91
|
+
f"Request {req.rid} has invalid token_indices: {token_indices}"
|
92
|
+
)
|
93
|
+
return False
|
94
|
+
|
95
|
+
tokens = req.origin_input_ids + req.output_ids
|
96
|
+
aligned_len = (len(tokens) // self.page_size) * self.page_size
|
97
|
+
if aligned_len == 0:
|
98
|
+
return False
|
99
|
+
|
100
|
+
token_indices = token_indices[:aligned_len]
|
101
|
+
tokens = tokens[:aligned_len]
|
102
|
+
|
103
|
+
# Asynchronously offload KV cache from device to host by cache controller
|
104
|
+
self.request_counter += 1
|
105
|
+
ack_id = self.request_counter
|
106
|
+
host_indices = self.cache_controller.write(
|
107
|
+
device_indices=token_indices.long(),
|
108
|
+
node_id=ack_id,
|
109
|
+
)
|
110
|
+
if host_indices is None:
|
111
|
+
logger.error(f"Not enough host memory for request {req.rid}")
|
112
|
+
return False
|
113
|
+
|
114
|
+
self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
|
115
|
+
return True
|
116
|
+
|
117
|
+
def check_offload_progress(self):
|
118
|
+
"""Check the progress of offload from device to host and backup from host to storage."""
|
119
|
+
cc = self.cache_controller
|
120
|
+
|
121
|
+
qsizes = torch.tensor(
|
122
|
+
[
|
123
|
+
len(cc.ack_write_queue),
|
124
|
+
cc.ack_backup_queue.qsize(),
|
125
|
+
],
|
126
|
+
dtype=torch.int,
|
127
|
+
)
|
128
|
+
if self.tp_world_size > 1:
|
129
|
+
torch.distributed.all_reduce(
|
130
|
+
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
|
131
|
+
)
|
132
|
+
|
133
|
+
n_write, n_backup = map(int, qsizes.tolist())
|
134
|
+
self._check_offload_progress(n_write)
|
135
|
+
self._check_backup_progress(n_backup)
|
136
|
+
|
137
|
+
def _check_offload_progress(self, finish_count):
|
138
|
+
"""Check the progress of offload from device to host."""
|
139
|
+
while finish_count > 0:
|
140
|
+
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
|
141
|
+
finish_event.synchronize()
|
142
|
+
for ack_id in ack_list:
|
143
|
+
req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)
|
144
|
+
|
145
|
+
# Release device
|
146
|
+
self.tree_cache.cache_finished_req(req)
|
147
|
+
|
148
|
+
# Trigger async backup from host to storage by cache controller
|
149
|
+
self._trigger_backup(req.rid, host_indices, tokens, start_time)
|
150
|
+
finish_count -= 1
|
151
|
+
|
152
|
+
def _check_backup_progress(self, finish_count):
|
153
|
+
"""Check the progress of backup from host to storage."""
|
154
|
+
for _ in range(finish_count):
|
155
|
+
storage_operation = self.cache_controller.ack_backup_queue.get()
|
156
|
+
ack_id = storage_operation.id
|
157
|
+
req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)
|
158
|
+
|
159
|
+
# Release host memory
|
160
|
+
self.decode_host_mem_pool.free(host_indices)
|
161
|
+
|
162
|
+
logger.debug(
|
163
|
+
f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
|
164
|
+
)
|
165
|
+
|
166
|
+
def _trigger_backup(self, req_id, host_indices, tokens, start_time):
|
167
|
+
"""Trigger async backup from host to storage by cache controller."""
|
168
|
+
|
169
|
+
# Generate page hashes and write to storage
|
170
|
+
page_hashes = self._compute_prefix_hash(tokens)
|
171
|
+
ack_id = self.cache_controller.write_storage(
|
172
|
+
host_indices,
|
173
|
+
tokens,
|
174
|
+
hash_value=page_hashes,
|
175
|
+
)
|
176
|
+
self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)
|
177
|
+
|
178
|
+
def _compute_prefix_hash(self, tokens):
|
179
|
+
last_hash = ""
|
180
|
+
page_hashes = []
|
181
|
+
for offset in range(0, len(tokens), self.page_size):
|
182
|
+
page_tokens = tokens[offset : offset + self.page_size]
|
183
|
+
last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
|
184
|
+
page_hashes.append(last_hash)
|
185
|
+
return page_hashes
|
@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
|
76
76
|
req_pool_indices, dtype=torch.int64, device=self.device
|
77
77
|
)
|
78
78
|
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
|
79
|
+
self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
|
79
80
|
self.orig_seq_lens = torch.tensor(
|
80
81
|
seq_lens, dtype=torch.int32, device=self.device
|
81
82
|
)
|
@@ -125,31 +126,39 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
|
125
126
|
req.grammar.finished = req.finished()
|
126
127
|
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
127
128
|
|
128
|
-
# Simulate the eagle run.
|
129
|
-
|
130
|
-
# of 0.
|
131
|
-
if not self.spec_algorithm.is_none():
|
129
|
+
# Simulate the eagle run.
|
130
|
+
if self.spec_algorithm.is_eagle():
|
132
131
|
|
133
132
|
b = len(self.reqs)
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
133
|
+
topk = server_args.speculative_eagle_topk
|
134
|
+
topk_p = torch.stack(
|
135
|
+
[
|
136
|
+
torch.as_tensor(
|
137
|
+
req.output_topk_p[:topk],
|
138
|
+
device=self.device,
|
139
|
+
dtype=torch.float32,
|
140
|
+
)
|
141
|
+
for req in self.reqs
|
142
|
+
],
|
143
|
+
dim=0,
|
140
144
|
)
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
+
topk_index = torch.stack(
|
146
|
+
[
|
147
|
+
torch.as_tensor(
|
148
|
+
req.output_topk_index[:topk],
|
149
|
+
device=self.device,
|
150
|
+
dtype=torch.int64,
|
151
|
+
)
|
152
|
+
for req in self.reqs
|
153
|
+
],
|
154
|
+
dim=0,
|
145
155
|
)
|
146
|
-
topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
|
147
156
|
|
148
157
|
hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
|
149
158
|
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
|
150
159
|
|
151
160
|
# local import to avoid circular import
|
152
|
-
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
161
|
+
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
153
162
|
|
154
163
|
spec_info = EagleDraftInput(
|
155
164
|
topk_p=topk_p,
|