sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -22,13 +22,9 @@ import torch
|
|
22
22
|
|
23
23
|
from sglang.srt.configs.model_config import ModelConfig
|
24
24
|
from sglang.srt.distributed import get_pp_group, get_world_group
|
25
|
-
from sglang.srt.hf_transformers_utils import (
|
26
|
-
get_processor,
|
27
|
-
get_tokenizer,
|
28
|
-
get_tokenizer_from_processor,
|
29
|
-
)
|
30
25
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
31
26
|
from sglang.srt.managers.io_struct import (
|
27
|
+
DestroyWeightsUpdateGroupReqInput,
|
32
28
|
GetWeightsByNameReqInput,
|
33
29
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
34
30
|
InitWeightsUpdateGroupReqInput,
|
@@ -42,11 +38,20 @@ from sglang.srt.managers.io_struct import (
|
|
42
38
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
43
39
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
44
40
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
45
|
-
from sglang.srt.model_executor.forward_batch_info import
|
41
|
+
from sglang.srt.model_executor.forward_batch_info import (
|
42
|
+
ForwardBatch,
|
43
|
+
ForwardBatchOutput,
|
44
|
+
PPProxyTensors,
|
45
|
+
)
|
46
46
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
47
|
-
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
48
47
|
from sglang.srt.server_args import ServerArgs
|
49
48
|
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
|
49
|
+
from sglang.srt.utils.hf_transformers_utils import (
|
50
|
+
get_processor,
|
51
|
+
get_tokenizer,
|
52
|
+
get_tokenizer_from_processor,
|
53
|
+
)
|
54
|
+
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
|
50
55
|
|
51
56
|
if TYPE_CHECKING:
|
52
57
|
from sglang.srt.managers.cache_controller import LayerDoneCounter
|
@@ -90,7 +95,6 @@ class TpModelWorker:
|
|
90
95
|
else server_args.speculative_draft_model_revision
|
91
96
|
),
|
92
97
|
is_draft_model=is_draft_worker,
|
93
|
-
tp_rank=tp_rank,
|
94
98
|
)
|
95
99
|
|
96
100
|
self.model_runner = ModelRunner(
|
@@ -149,8 +153,8 @@ class TpModelWorker:
|
|
149
153
|
assert self.max_running_requests > 0, "max_running_request is zero"
|
150
154
|
self.max_queued_requests = server_args.max_queued_requests
|
151
155
|
assert (
|
152
|
-
self.max_queued_requests
|
153
|
-
), "
|
156
|
+
self.max_queued_requests is None or self.max_queued_requests >= 1
|
157
|
+
), "If configured, max_queued_requests must be at least 1 for any work to be scheduled."
|
154
158
|
self.max_req_len = min(
|
155
159
|
self.model_config.context_len - 1,
|
156
160
|
self.max_total_num_tokens - 1,
|
@@ -233,10 +237,8 @@ class TpModelWorker:
|
|
233
237
|
self,
|
234
238
|
model_worker_batch: ModelWorkerBatch,
|
235
239
|
launch_done: Optional[threading.Event] = None,
|
236
|
-
|
237
|
-
) ->
|
238
|
-
Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
|
239
|
-
]:
|
240
|
+
is_verify: bool = False,
|
241
|
+
) -> ForwardBatchOutput:
|
240
242
|
# update the consumer index of hicache to the running batch
|
241
243
|
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
|
242
244
|
|
@@ -257,29 +259,31 @@ class TpModelWorker:
|
|
257
259
|
if launch_done is not None:
|
258
260
|
launch_done.set()
|
259
261
|
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
self.model_runner.compute_logprobs_only(
|
269
|
-
logits_output, model_worker_batch
|
270
|
-
)
|
271
|
-
else:
|
272
|
-
next_token_ids = self.model_runner.sample(
|
262
|
+
skip_sample = is_verify or model_worker_batch.is_prefill_only
|
263
|
+
next_token_ids = None
|
264
|
+
|
265
|
+
if not skip_sample:
|
266
|
+
next_token_ids = self.model_runner.sample(logits_output, forward_batch)
|
267
|
+
elif model_worker_batch.return_logprob and not is_verify:
|
268
|
+
# NOTE: Compute logprobs without full sampling
|
269
|
+
self.model_runner.compute_logprobs_only(
|
273
270
|
logits_output, model_worker_batch
|
274
271
|
)
|
275
272
|
|
276
|
-
return
|
273
|
+
return ForwardBatchOutput(
|
274
|
+
logits_output=logits_output,
|
275
|
+
next_token_ids=next_token_ids,
|
276
|
+
can_run_cuda_graph=can_run_cuda_graph,
|
277
|
+
)
|
277
278
|
else:
|
278
279
|
pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
|
279
280
|
forward_batch,
|
280
281
|
pp_proxy_tensors=pp_proxy_tensors,
|
281
282
|
)
|
282
|
-
return
|
283
|
+
return ForwardBatchOutput(
|
284
|
+
pp_proxy_tensors=pp_proxy_tensors,
|
285
|
+
can_run_cuda_graph=can_run_cuda_graph,
|
286
|
+
)
|
283
287
|
|
284
288
|
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
285
289
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
@@ -304,6 +308,12 @@ class TpModelWorker:
|
|
304
308
|
)
|
305
309
|
return success, message
|
306
310
|
|
311
|
+
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
|
312
|
+
success, message = self.model_runner.destroy_weights_update_group(
|
313
|
+
recv_req.group_name,
|
314
|
+
)
|
315
|
+
return success, message
|
316
|
+
|
307
317
|
def init_weights_send_group_for_remote_instance(
|
308
318
|
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
309
319
|
):
|
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -25,6 +25,7 @@ import psutil
|
|
25
25
|
import torch
|
26
26
|
|
27
27
|
from sglang.srt.managers.io_struct import (
|
28
|
+
DestroyWeightsUpdateGroupReqInput,
|
28
29
|
GetWeightsByNameReqInput,
|
29
30
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
30
31
|
InitWeightsUpdateGroupReqInput,
|
@@ -35,10 +36,12 @@ from sglang.srt.managers.io_struct import (
|
|
35
36
|
UpdateWeightsFromDistributedReqInput,
|
36
37
|
UpdateWeightsFromTensorReqInput,
|
37
38
|
)
|
39
|
+
from sglang.srt.managers.overlap_utils import FutureMap
|
38
40
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
39
41
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
42
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
|
40
43
|
from sglang.srt.server_args import ServerArgs
|
41
|
-
from sglang.srt.utils import DynamicGradMode
|
44
|
+
from sglang.srt.utils import DynamicGradMode
|
42
45
|
from sglang.utils import get_exception_traceback
|
43
46
|
|
44
47
|
if TYPE_CHECKING:
|
@@ -47,15 +50,6 @@ if TYPE_CHECKING:
|
|
47
50
|
logger = logging.getLogger(__name__)
|
48
51
|
|
49
52
|
|
50
|
-
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
51
|
-
def resolve_future_token_ids(input_ids, future_token_ids_map):
|
52
|
-
input_ids[:] = torch.where(
|
53
|
-
input_ids < 0,
|
54
|
-
future_token_ids_map[torch.clamp(-input_ids, min=0)],
|
55
|
-
input_ids,
|
56
|
-
)
|
57
|
-
|
58
|
-
|
59
53
|
class TpModelWorkerClient:
|
60
54
|
"""A tensor parallel model worker."""
|
61
55
|
|
@@ -78,11 +72,7 @@ class TpModelWorkerClient:
|
|
78
72
|
self.gpu_id = gpu_id
|
79
73
|
|
80
74
|
# Init future mappings
|
81
|
-
self.
|
82
|
-
self.future_token_ids_limit = self.max_running_requests * 3
|
83
|
-
self.future_token_ids_map = torch.empty(
|
84
|
-
(self.max_running_requests * 5,), dtype=torch.int64, device=self.device
|
85
|
-
)
|
75
|
+
self.future_map = FutureMap(self.max_running_requests, self.device)
|
86
76
|
|
87
77
|
# Launch threads
|
88
78
|
self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
|
@@ -152,7 +142,7 @@ class TpModelWorkerClient:
|
|
152
142
|
batch_lists: List = [None] * 2
|
153
143
|
|
154
144
|
while True:
|
155
|
-
model_worker_batch,
|
145
|
+
model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
|
156
146
|
if not model_worker_batch:
|
157
147
|
break
|
158
148
|
|
@@ -168,17 +158,18 @@ class TpModelWorkerClient:
|
|
168
158
|
copy_done = torch.get_device_module(self.device).Event()
|
169
159
|
|
170
160
|
# Resolve future tokens in the input
|
171
|
-
|
172
|
-
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
161
|
+
self.future_map.resolve_future(model_worker_batch)
|
173
162
|
|
174
163
|
# Run forward
|
164
|
+
forward_batch_output = self.worker.forward_batch_generation(
|
165
|
+
model_worker_batch,
|
166
|
+
model_worker_batch.launch_done,
|
167
|
+
)
|
168
|
+
|
175
169
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
# Skip sampling for prefill-only requests
|
180
|
-
skip_sample=model_worker_batch.is_prefill_only,
|
181
|
-
)
|
170
|
+
forward_batch_output.logits_output,
|
171
|
+
forward_batch_output.next_token_ids,
|
172
|
+
forward_batch_output.can_run_cuda_graph,
|
182
173
|
)
|
183
174
|
|
184
175
|
# Update the future token ids map
|
@@ -186,9 +177,9 @@ class TpModelWorkerClient:
|
|
186
177
|
if model_worker_batch.is_prefill_only:
|
187
178
|
# For prefill-only requests, create dummy token IDs on CPU
|
188
179
|
next_token_ids = torch.zeros(bs, dtype=torch.long)
|
189
|
-
|
190
|
-
|
191
|
-
|
180
|
+
|
181
|
+
# store the future indices into future map
|
182
|
+
self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
|
192
183
|
|
193
184
|
# Copy results to the CPU
|
194
185
|
if model_worker_batch.return_logprob:
|
@@ -239,7 +230,7 @@ class TpModelWorkerClient:
|
|
239
230
|
|
240
231
|
def forward_batch_generation(
|
241
232
|
self, model_worker_batch: ModelWorkerBatch
|
242
|
-
) ->
|
233
|
+
) -> ForwardBatchOutput:
|
243
234
|
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
|
244
235
|
sampling_info = model_worker_batch.sampling_info
|
245
236
|
sampling_info.update_penalties()
|
@@ -254,21 +245,18 @@ class TpModelWorkerClient:
|
|
254
245
|
sync_event.record(self.scheduler_stream)
|
255
246
|
|
256
247
|
# Push a new batch to the queue
|
257
|
-
self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
|
258
|
-
|
259
|
-
# Allocate output future objects
|
260
248
|
bs = len(model_worker_batch.seq_lens)
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
249
|
+
cur_future_map_ct = self.future_map.update_ct(bs)
|
250
|
+
self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
|
251
|
+
|
252
|
+
# get this forward batch's future token ids
|
253
|
+
future_next_token_ids = self.future_map.update_next_future(
|
254
|
+
cur_future_map_ct, bs
|
255
|
+
)
|
256
|
+
return ForwardBatchOutput(
|
257
|
+
next_token_ids=future_next_token_ids,
|
258
|
+
can_run_cuda_graph=False,
|
267
259
|
)
|
268
|
-
self.future_token_ids_ct = (
|
269
|
-
self.future_token_ids_ct + bs
|
270
|
-
) % self.future_token_ids_limit
|
271
|
-
return None, future_next_token_ids, False
|
272
260
|
|
273
261
|
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
274
262
|
success, message = self.worker.update_weights_from_disk(recv_req)
|
@@ -278,6 +266,10 @@ class TpModelWorkerClient:
|
|
278
266
|
success, message = self.worker.init_weights_update_group(recv_req)
|
279
267
|
return success, message
|
280
268
|
|
269
|
+
def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
|
270
|
+
success, message = self.worker.destroy_weights_update_group(recv_req)
|
271
|
+
return success, message
|
272
|
+
|
281
273
|
def init_weights_send_group_for_remote_instance(
|
282
274
|
self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
|
283
275
|
):
|
sglang/srt/managers/utils.py
CHANGED
@@ -2,11 +2,10 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
import multiprocessing as mp
|
5
|
-
from http import HTTPStatus
|
6
5
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
7
6
|
|
8
7
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
9
|
-
from sglang.srt.managers.schedule_batch import
|
8
|
+
from sglang.srt.managers.schedule_batch import Req
|
10
9
|
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
|
11
10
|
|
12
11
|
if TYPE_CHECKING:
|
@@ -97,46 +96,3 @@ def get_logprob_from_pp_outputs(
|
|
97
96
|
]
|
98
97
|
|
99
98
|
return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
|
100
|
-
|
101
|
-
|
102
|
-
class DPBalanceMeta:
|
103
|
-
"""
|
104
|
-
This class will be use in scheduler and dp controller
|
105
|
-
"""
|
106
|
-
|
107
|
-
def __init__(self, num_workers: int):
|
108
|
-
self.num_workers = num_workers
|
109
|
-
self._manager = mp.Manager()
|
110
|
-
self.mutex = self._manager.Lock()
|
111
|
-
|
112
|
-
init_local_tokens = [0] * self.num_workers
|
113
|
-
init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
|
114
|
-
|
115
|
-
self.shared_state = self._manager.Namespace()
|
116
|
-
self.shared_state.local_tokens = self._manager.list(init_local_tokens)
|
117
|
-
self.shared_state.onfly_info = self._manager.list(init_onfly_info)
|
118
|
-
|
119
|
-
def destructor(self):
|
120
|
-
# we must destructor this class manually
|
121
|
-
self._manager.shutdown()
|
122
|
-
|
123
|
-
def get_shared_onfly(self) -> List[Dict[int, int]]:
|
124
|
-
return [dict(d) for d in self.shared_state.onfly_info]
|
125
|
-
|
126
|
-
def set_shared_onfly_info(self, data: List[Dict[int, int]]):
|
127
|
-
self.shared_state.onfly_info = data
|
128
|
-
|
129
|
-
def get_shared_local_tokens(self) -> List[int]:
|
130
|
-
return list(self.shared_state.local_tokens)
|
131
|
-
|
132
|
-
def set_shared_local_tokens(self, data: List[int]):
|
133
|
-
self.shared_state.local_tokens = data
|
134
|
-
|
135
|
-
def __getstate__(self):
|
136
|
-
state = self.__dict__.copy()
|
137
|
-
del state["_manager"]
|
138
|
-
return state
|
139
|
-
|
140
|
-
def __setstate__(self, state):
|
141
|
-
self.__dict__.update(state)
|
142
|
-
self._manager = None
|
sglang/srt/mem_cache/allocator.py
CHANGED
@@ -27,7 +27,7 @@ import triton
|
|
27
27
|
import triton.language as tl
|
28
28
|
|
29
29
|
from sglang.srt.mem_cache.memory_pool import SWAKVPool
|
30
|
-
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
30
|
+
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
|
31
31
|
|
32
32
|
if TYPE_CHECKING:
|
33
33
|
from sglang.srt.mem_cache.memory_pool import KVCache
|
@@ -294,7 +294,6 @@ def alloc_extend_kernel(
|
|
294
294
|
last_loc_ptr,
|
295
295
|
free_page_ptr,
|
296
296
|
out_indices,
|
297
|
-
ret_values,
|
298
297
|
bs_upper: tl.constexpr,
|
299
298
|
page_size: tl.constexpr,
|
300
299
|
max_num_extend_tokens: tl.constexpr,
|
@@ -323,13 +322,6 @@ def alloc_extend_kernel(
|
|
323
322
|
sum_num_new_pages = tl.sum(num_new_pages)
|
324
323
|
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
|
325
324
|
|
326
|
-
# Return value
|
327
|
-
if pid == tl.num_programs(0) - 1:
|
328
|
-
merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
|
329
|
-
tl.int64
|
330
|
-
)
|
331
|
-
tl.store(ret_values, merged_value)
|
332
|
-
|
333
325
|
# Part 1: fill the old partial page
|
334
326
|
last_loc = tl.load(last_loc_ptr + pid)
|
335
327
|
num_part1 = (
|
@@ -381,7 +373,6 @@ def alloc_decode_kernel(
|
|
381
373
|
last_loc_ptr,
|
382
374
|
free_page_ptr,
|
383
375
|
out_indices,
|
384
|
-
ret_values,
|
385
376
|
bs_upper: tl.constexpr,
|
386
377
|
page_size: tl.constexpr,
|
387
378
|
):
|
@@ -404,10 +395,6 @@ def alloc_decode_kernel(
|
|
404
395
|
sum_num_new_pages = tl.sum(num_new_pages)
|
405
396
|
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
|
406
397
|
|
407
|
-
# Return value
|
408
|
-
if pid == tl.num_programs(0) - 1:
|
409
|
-
tl.store(ret_values, sum_num_new_pages)
|
410
|
-
|
411
398
|
if num_page_start_loc_self == 0:
|
412
399
|
last_loc = tl.load(last_loc_ptr + pid)
|
413
400
|
tl.store(out_indices + pid, last_loc + 1)
|
@@ -438,7 +425,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
438
425
|
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
439
426
|
self.num_pages = size // page_size
|
440
427
|
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
441
|
-
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
442
428
|
self.seen_max_num_extend_tokens_next_power_of_2 = 1
|
443
429
|
self.clear()
|
444
430
|
|
@@ -468,7 +454,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
468
454
|
def alloc_extend(
|
469
455
|
self,
|
470
456
|
prefix_lens: torch.Tensor,
|
457
|
+
prefix_lens_cpu: torch.Tensor,
|
471
458
|
seq_lens: torch.Tensor,
|
459
|
+
seq_lens_cpu: torch.Tensor,
|
472
460
|
last_loc: torch.Tensor,
|
473
461
|
extend_num_tokens: int,
|
474
462
|
):
|
@@ -497,7 +485,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
497
485
|
last_loc,
|
498
486
|
self.free_pages,
|
499
487
|
out_indices,
|
500
|
-
self.ret_values,
|
501
488
|
next_power_of_2(bs),
|
502
489
|
self.page_size,
|
503
490
|
self.seen_max_num_extend_tokens_next_power_of_2,
|
@@ -506,8 +493,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
506
493
|
if self.debug_mode:
|
507
494
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
508
495
|
|
509
|
-
|
510
|
-
|
496
|
+
num_new_pages = get_num_new_pages(
|
497
|
+
seq_lens=seq_lens_cpu,
|
498
|
+
page_size=self.page_size,
|
499
|
+
prefix_lens=prefix_lens_cpu,
|
500
|
+
)
|
511
501
|
if num_new_pages > len(self.free_pages):
|
512
502
|
return None
|
513
503
|
|
@@ -517,6 +507,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
517
507
|
def alloc_decode(
|
518
508
|
self,
|
519
509
|
seq_lens: torch.Tensor,
|
510
|
+
seq_lens_cpu: torch.Tensor,
|
520
511
|
last_loc: torch.Tensor,
|
521
512
|
):
|
522
513
|
if self.debug_mode:
|
@@ -534,7 +525,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
534
525
|
last_loc,
|
535
526
|
self.free_pages,
|
536
527
|
out_indices,
|
537
|
-
self.ret_values,
|
538
528
|
next_power_of_2(bs),
|
539
529
|
self.page_size,
|
540
530
|
)
|
@@ -542,7 +532,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
542
532
|
if self.debug_mode:
|
543
533
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
544
534
|
|
545
|
-
num_new_pages =
|
535
|
+
num_new_pages = get_num_new_pages(
|
536
|
+
seq_lens=seq_lens_cpu,
|
537
|
+
page_size=self.page_size,
|
538
|
+
decode=True,
|
539
|
+
)
|
546
540
|
if num_new_pages > len(self.free_pages):
|
547
541
|
return None
|
548
542
|
|
@@ -1,13 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import TYPE_CHECKING
|
4
|
-
|
5
3
|
import torch
|
6
4
|
|
7
5
|
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
|
8
|
-
|
9
|
-
if TYPE_CHECKING:
|
10
|
-
from sglang.srt.mem_cache.memory_pool import KVCache
|
6
|
+
from sglang.srt.utils import get_num_new_pages
|
11
7
|
|
12
8
|
|
13
9
|
def alloc_extend_kernel_ascend(
|
@@ -69,7 +65,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
69
65
|
def alloc_extend(
|
70
66
|
self,
|
71
67
|
prefix_lens: torch.Tensor,
|
68
|
+
prefix_lens_cpu: torch.Tensor,
|
72
69
|
seq_lens: torch.Tensor,
|
70
|
+
seq_lens_cpu: torch.Tensor,
|
73
71
|
last_loc: torch.Tensor,
|
74
72
|
extend_num_tokens: int,
|
75
73
|
):
|
@@ -79,42 +77,54 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
79
77
|
)
|
80
78
|
|
81
79
|
num_new_pages = (
|
82
|
-
(
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
.item()
|
88
|
-
)
|
89
|
-
if self.need_sort and num_new_pages > len(self.free_pages):
|
80
|
+
(seq_lens + self.page_size - 1) // self.page_size
|
81
|
+
- (prefix_lens + self.page_size - 1) // self.page_size
|
82
|
+
).sum()
|
83
|
+
num_new_pages_item = num_new_pages.item()
|
84
|
+
if self.need_sort and num_new_pages_item > len(self.free_pages):
|
90
85
|
self.merge_and_sort_free()
|
91
86
|
|
92
|
-
if
|
87
|
+
if num_new_pages_item > len(self.free_pages):
|
93
88
|
return None
|
94
89
|
|
95
90
|
out_indices = torch.empty(
|
96
|
-
(extend_num_tokens,), dtype=torch.
|
91
|
+
(extend_num_tokens,), dtype=torch.int64, device=self.device
|
97
92
|
)
|
98
93
|
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
94
|
+
if num_new_pages_item < 200:
|
95
|
+
import sgl_kernel_npu
|
96
|
+
|
97
|
+
torch.ops.npu.alloc_extend(
|
98
|
+
prefix_lens,
|
99
|
+
seq_lens,
|
100
|
+
last_loc,
|
101
|
+
self.free_pages,
|
102
|
+
self.page_size,
|
103
|
+
out_indices,
|
104
|
+
num_new_pages,
|
105
|
+
)
|
106
|
+
|
107
|
+
else:
|
108
|
+
alloc_extend_kernel_ascend(
|
109
|
+
prefix_lens,
|
110
|
+
seq_lens,
|
111
|
+
last_loc,
|
112
|
+
self.free_pages,
|
113
|
+
out_indices,
|
114
|
+
self.page_size,
|
115
|
+
self.device,
|
116
|
+
)
|
108
117
|
|
109
118
|
if self.debug_mode:
|
110
119
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
111
120
|
|
112
|
-
self.free_pages = self.free_pages[
|
121
|
+
self.free_pages = self.free_pages[num_new_pages_item:]
|
113
122
|
return out_indices
|
114
123
|
|
115
124
|
def alloc_decode(
|
116
125
|
self,
|
117
126
|
seq_lens: torch.Tensor,
|
127
|
+
seq_lens_cpu: torch.Tensor,
|
118
128
|
last_loc: torch.Tensor,
|
119
129
|
):
|
120
130
|
if self.debug_mode:
|
@@ -122,8 +132,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
122
132
|
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
123
133
|
)
|
124
134
|
|
125
|
-
|
126
|
-
|
135
|
+
num_new_pages = get_num_new_pages(
|
136
|
+
seq_lens=seq_lens_cpu,
|
137
|
+
page_size=self.page_size,
|
138
|
+
decode=True,
|
139
|
+
)
|
127
140
|
|
128
141
|
if num_new_pages > len(self.free_pages):
|
129
142
|
self.merge_and_sort_free()
|
@@ -131,6 +144,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
131
144
|
if num_new_pages > len(self.free_pages):
|
132
145
|
return None
|
133
146
|
|
147
|
+
need_new_pages = (seq_lens % self.page_size == 1).int()
|
134
148
|
end_new_pages = torch.cumsum(need_new_pages, 0)
|
135
149
|
start_new_pages = end_new_pages - need_new_pages
|
136
150
|
if num_new_pages == 0:
|
@@ -28,6 +28,13 @@ class ChunkCache(BasePrefixCache):
|
|
28
28
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
29
29
|
self.page_size = page_size
|
30
30
|
|
31
|
+
# NOTE (csy): this is to determine if a cache has prefix matching feature.
|
32
|
+
# Chunk cache always return True to indicate no prefix matching.
|
33
|
+
# TODO (csy): Using a prefix cache trait to replace this
|
34
|
+
@property
|
35
|
+
def disable(self):
|
36
|
+
return True
|
37
|
+
|
31
38
|
def reset(self):
|
32
39
|
pass
|
33
40
|
|
@@ -38,7 +45,7 @@ class ChunkCache(BasePrefixCache):
|
|
38
45
|
last_host_node=None,
|
39
46
|
)
|
40
47
|
|
41
|
-
def cache_finished_req(self, req: Req):
|
48
|
+
def cache_finished_req(self, req: Req, insert: bool = True):
|
42
49
|
kv_indices = self.req_to_token_pool.req_to_token[
|
43
50
|
req.req_pool_idx,
|
44
51
|
# For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
|
@@ -0,0 +1,23 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from typing import TYPE_CHECKING, List, Tuple, Union
|
5
|
+
|
6
|
+
if TYPE_CHECKING:
|
7
|
+
from sglang.srt.mem_cache.radix_cache import TreeNode
|
8
|
+
|
9
|
+
|
10
|
+
class EvictionStrategy(ABC):
|
11
|
+
@abstractmethod
|
12
|
+
def get_priority(self, node: "TreeNode") -> Union[float, Tuple]:
|
13
|
+
pass
|
14
|
+
|
15
|
+
|
16
|
+
class LRUStrategy(EvictionStrategy):
|
17
|
+
def get_priority(self, node: "TreeNode") -> float:
|
18
|
+
return node.last_access_time
|
19
|
+
|
20
|
+
|
21
|
+
class LFUStrategy(EvictionStrategy):
|
22
|
+
def get_priority(self, node: "TreeNode") -> Tuple[int, float]:
|
23
|
+
return (node.hit_count, node.last_access_time)
|