sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py CHANGED
@@ -38,7 +38,7 @@ import threading
 from enum import Enum, auto
 from http import HTTPStatus
 from itertools import chain
-from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

 import numpy as np
 import torch
@@ -52,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.moe import is_tbo_enabled
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -60,7 +59,7 @@ from sglang.srt.mem_cache.allocator import (
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -99,6 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "sampling_backend",
     "speculative_accept_threshold_single",
     "speculative_accept_threshold_acc",
+    "speculative_attention_mode",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
@@ -911,7 +911,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     is_prefill_only: bool = False

     # hicache pointer for synchronizing data loading from CPU to GPU
-    hicache_consumer_index: int =
+    hicache_consumer_index: int = -1

     @classmethod
     def init_new(
@@ -962,8 +962,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def is_empty(self):
         return len(self.reqs) == 0

-    def alloc_req_slots(self, num_reqs: int):
-        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+    def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
+        if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
+            req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
+        else:
+            req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
                 "alloc_req_slots runs out of memory. "
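
Note: the new `alloc_req_slots` signature dispatches on the type of the request-to-token pool: a `HybridReqToTokenPool` also receives the request objects so it can track per-request state, while the plain pool keeps its count-only interface. Below is a minimal, self-contained sketch of that dispatch pattern; `SimplePool` and `HybridPool` are simplified stand-ins for illustration, not the real sglang pool classes.

# Sketch of the isinstance-based allocation dispatch, under simplified assumptions.
from typing import Any, List, Optional


class SimplePool:
    def __init__(self, size: int) -> None:
        self.free_slots = list(range(size))

    def alloc(self, num: int) -> Optional[List[int]]:
        if num > len(self.free_slots):
            return None  # caller raises "out of memory"
        out, self.free_slots = self.free_slots[:num], self.free_slots[num:]
        return out


class HybridPool(SimplePool):
    def alloc(self, num: int, reqs: Optional[List[Any]] = None) -> Optional[List[int]]:
        # A hybrid pool can use the request objects to bind extra per-request
        # state (hypothetically, linear-attention/mamba slots) to the indices.
        indices = super().alloc(num)
        if indices is not None and reqs is not None:
            self.bound = dict(zip(indices, reqs))  # hypothetical bookkeeping
        return indices


def alloc_req_slots(pool: SimplePool, num_reqs: int, reqs: Optional[List[Any]] = None):
    # Hybrid pools take the request list; plain pools only take the count.
    if isinstance(pool, HybridPool):
        return pool.alloc(num_reqs, reqs)
    return pool.alloc(num_reqs)


assert alloc_req_slots(SimplePool(4), 2) == [0, 1]
assert alloc_req_slots(HybridPool(4), 2, reqs=["req-a", "req-b"]) == [0, 1]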
@@ -1138,7 +1141,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

         # Allocate req slots
         bs = len(self.reqs)
-        req_pool_indices = self.alloc_req_slots(bs)
+        req_pool_indices = self.alloc_req_slots(bs, self.reqs)

         # Init tensors
         reqs = self.reqs
@@ -1372,21 +1375,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)

-    def new_page_count_next_decode(self):
+    def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
         page_size = self.token_to_kv_pool_allocator.page_size
+        requests = (
+            self.reqs
+            if selected_indices is None
+            else [self.reqs[i] for i in selected_indices]
+        )
         if page_size == 1:
-            return len(self.reqs)
+            return len(requests)
         # In the decoding phase, the length of a request's KV cache should be
         # the total length of the request minus 1
         return (
-            sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+            sum(1 for req in requests if req.seqlen % page_size == 0)
             if self.enable_overlap
-            else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+            else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
         )

-    def check_decode_mem(self, buf_multiplier=1):
+    def check_decode_mem(
+        self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
+    ):
         num_tokens = (
-            self.new_page_count_next_decode()
+            self.new_page_count_next_decode(selected_indices)
             * buf_multiplier
             * self.token_to_kv_pool_allocator.page_size
         )
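
Note: the reworked `new_page_count_next_decode` only differs from the old version in that it can restrict the count to a subset of requests via `selected_indices`; the page arithmetic itself is unchanged. A standalone sketch of that arithmetic follows (illustrative only, not the sglang implementation):

# How many of the selected requests need a fresh KV-cache page on the next decode step.
from typing import List, Optional


def new_page_count_next_decode(
    seq_lens: List[int],
    page_size: int,
    enable_overlap: bool = False,
    selected_indices: Optional[List[int]] = None,
) -> int:
    lens = (
        seq_lens
        if selected_indices is None
        else [seq_lens[i] for i in selected_indices]
    )
    if page_size == 1:
        # Every decoding request appends exactly one token, i.e. one new slot each.
        return len(lens)
    # With overlap scheduling the stored length is used as-is; otherwise the KV
    # length during decode is seq_len - 1, so shift by one before the boundary check.
    offset = 0 if enable_overlap else 1
    return sum(1 for length in lens if (length - offset) % page_size == 0)


# page_size=4: only requests whose (possibly shifted) length sits exactly on a
# page boundary need a new page allocated for the next token.
assert new_page_count_next_decode([4, 5, 8], page_size=4, enable_overlap=True) == 2
assert new_page_count_next_decode([4, 6, 9], page_size=4, selected_indices=[1, 2]) == 1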
@@ -1412,34 +1422,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             reverse=True,
         )

-        def get_required_tokens(num_reqs: int):
-            headroom_for_spec_decode = 0
-            if server_args.speculative_algorithm:
-                headroom_for_spec_decode += (
-                    num_reqs
-                    * server_args.speculative_eagle_topk
-                    * server_args.speculative_num_steps
-                    + num_reqs * server_args.speculative_num_draft_tokens
-                )
-            return (
-                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
-            )
-
-        def _get_available_size():
-            if self.is_hybrid:
-                return min(
-                    self.token_to_kv_pool_allocator.full_available_size(),
-                    self.token_to_kv_pool_allocator.swa_available_size(),
-                )
-            else:
-                return self.token_to_kv_pool_allocator.available_size()
-
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         first_iter = True
-        while (
-            _get_available_size() < get_required_tokens(len(sorted_indices))
-            or first_iter
+        while first_iter or (
+            not self.check_decode_mem(selected_indices=sorted_indices)
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
@@ -1493,10 +1480,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             else:
                 self.tree_cache.dec_lock_ref(req.last_node)

-            # NOTE(lsyin): we should use the newly evictable memory instantly.
-            num_tokens = len(sorted_indices) * global_config.retract_decode_steps
-            self._evict_tree_cache_if_needed(num_tokens)
-
             req.reset_for_retract()

         if len(retracted_reqs) == 0:
@@ -1540,7 +1523,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.DECODE
         bs = len(self.reqs)

-        if self.spec_algorithm.is_eagle():
+        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
             return
@@ -1917,7 +1900,7 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
-    hicache_consumer_index: int =
+    hicache_consumer_index: int = -1

     # Overlap event
     launch_done: Optional[threading.Event] = None
sglang/srt/managers/schedule_policy.py CHANGED
@@ -380,8 +380,9 @@ class PrefillAdder:
         self.log_input_tokens += extend_input_len

     def add_chunked_req(self, req: Req):
-        truncated = req.extend_input_len > self.rem_chunk_tokens
-        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+        _rem_tokens = min(self.rem_chunk_tokens, int(self.rem_total_tokens))
+        truncated = req.extend_input_len > _rem_tokens
+        req.extend_input_len = min(req.extend_input_len, _rem_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
         self.can_run_list.append(req)
         self._update_prefill_budget(
@@ -549,7 +550,7 @@ class PrefillAdder:
             )
         else:
             # Make sure at least one page is available
-            trunc_len = self.rem_chunk_tokens
+            trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size
             if trunc_len <= 0:
                 return AddReqResult.OTHER

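Note: this `PrefillAdder` change rounds the truncation length down to a whole number of pages so a truncated prefill chunk always ends on a page boundary; if less than one page remains, the request is deferred (`AddReqResult.OTHER`). The arithmetic, as a tiny standalone sketch (illustrative only):

# Round the remaining chunk budget down to a multiple of the page size.
def page_aligned_trunc_len(rem_chunk_tokens: int, page_size: int) -> int:
    return rem_chunk_tokens // page_size * page_size


assert page_aligned_trunc_len(8190, 64) == 8128  # 127 full pages fit
assert page_aligned_trunc_len(63, 64) == 0       # less than one page left -> defer the request
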
sglang/srt/managers/scheduler.py CHANGED
@@ -67,6 +67,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
+    ClearHiCacheReqInput,
+    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
@@ -80,6 +84,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
+    MultiTokenizerRegisterReq,
+    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -135,7 +141,7 @@ from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -152,6 +158,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     is_cpu,
     kill_itself_when_parent_died,
+    numa_bind_to_node,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
     require_mlp_sync,
@@ -253,7 +260,6 @@ class Scheduler(
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
-
         if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -343,6 +349,18 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
         else:
             self.draft_worker = None

@@ -396,7 +414,7 @@ class Scheduler(
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}, "
-            f"available_gpu_mem={avail_mem:.2f} GB"
+            f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
         )

         # Init memory pool and cache
@@ -483,7 +501,7 @@ class Scheduler(
             enable=server_args.enable_memory_saver
         )
         self.offload_tags = set()
-        self.
+        self.init_profiler()

         self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
         self.input_blocker = (
@@ -495,6 +513,7 @@ class Scheduler(
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
+        self.init_dp_balance(dp_balance_meta)

         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -510,7 +529,10 @@ class Scheduler(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+                (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
+                (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
+                (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
                 (CloseSessionReqInput, self.close_session),
@@ -533,18 +555,10 @@ class Scheduler(
                 (ExpertDistributionReq, self.expert_distribution_handle),
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
+                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
             ]
         )

-        self.balance_meta = dp_balance_meta
-        if (
-            server_args.enable_dp_attention
-            and server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-
-            self.recv_dp_balance_id_this_term = []
-
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
@@ -621,8 +635,11 @@ class Scheduler(
             hicache_write_policy=server_args.hicache_write_policy,
             hicache_io_backend=server_args.hicache_io_backend,
             hicache_mem_layout=server_args.hicache_mem_layout,
+            enable_metrics=self.enable_metrics,
             hicache_storage_backend=server_args.hicache_storage_backend,
             hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
+            model_name=server_args.served_model_name,
+            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
         )
         self.tp_worker.register_hicache_layer_transfer_counter(
             self.tree_cache.cache_controller.layer_done_counter
@@ -651,6 +668,21 @@ class Scheduler(
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
             )
+        elif server_args.enable_lmcache:
+            from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
+                LMCRadixCache,
+            )
+
+            self.tree_cache = LMCRadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
+                disable=server_args.disable_radix_cache,
+                model_config=self.model_config,
+                tp_size=self.tp_size,
+                rank=self.tp_rank,
+                tp_group=self.tp_group,
+            )
         else:
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,
@@ -1018,14 +1050,26 @@ class Scheduler(
                 req
                 for req in recv_reqs
                 if isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
             control_reqs = [
                 req
                 for req in recv_reqs
                 if not isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
         else:
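
Note: the receive loop now counts the batched input types as work requests too, so incoming messages are partitioned against a four-type tuple: work requests go through normal scheduling, everything else is handled as a control message. A simplified, self-contained sketch of that partitioning follows; the dataclasses are stand-ins for illustration, not the real sglang io_struct classes.

# Sketch of splitting received messages into work requests and control requests.
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class TokenizedGenerateReqInput:
    rid: str


@dataclass
class BatchTokenizedGenerateReqInput:
    reqs: List[Any] = field(default_factory=list)


@dataclass
class FlushCacheReqInput:
    pass


WORK_REQUEST_TYPES: Tuple[type, ...] = (
    TokenizedGenerateReqInput,
    BatchTokenizedGenerateReqInput,
)


def split_requests(recv_reqs: List[Any]) -> Tuple[List[Any], List[Any]]:
    work_reqs = [r for r in recv_reqs if isinstance(r, WORK_REQUEST_TYPES)]
    control_reqs = [r for r in recv_reqs if not isinstance(r, WORK_REQUEST_TYPES)]
    return work_reqs, control_reqs


work, control = split_requests(
    [TokenizedGenerateReqInput("a"), FlushCacheReqInput(), BatchTokenizedGenerateReqInput()]
)
assert len(work) == 2 and len(control) == 1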
@@ -1080,6 +1124,17 @@ class Scheduler(
                 )
                 self.send_to_tokenizer.send_pyobj(abort_req)
                 continue
+
+            # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWrapper):
+                worker_id = recv_req.worker_id
+                recv_req = recv_req.obj
+                output = self._request_dispatcher(recv_req)
+                if output is not None:
+                    output = MultiTokenizerWrapper(worker_id, output)
+                    self.send_to_tokenizer.send_pyobj(output)
+                continue
+
             output = self._request_dispatcher(recv_req)
             if output is not None:
                 if isinstance(output, RpcReqOutput):
@@ -1092,11 +1147,7 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+        self.maybe_update_dp_balance_data(recv_req)

         # Create a new request
         if (
|
|
1253
1304
|
else:
|
1254
1305
|
self._add_request_to_queue(req)
|
1255
1306
|
|
1307
|
+
def handle_batch_generate_request(
|
1308
|
+
self,
|
1309
|
+
recv_req: BatchTokenizedGenerateReqInput,
|
1310
|
+
):
|
1311
|
+
"""Handle optimized batch generate request."""
|
1312
|
+
logger.debug(f"Processing batch generate request with {len(recv_req)} requests")
|
1313
|
+
|
1314
|
+
# Process each request in the batch
|
1315
|
+
for tokenized_req in recv_req:
|
1316
|
+
self.handle_generate_request(tokenized_req)
|
1317
|
+
|
1256
1318
|
def _add_request_to_queue(self, req: Req):
|
1257
1319
|
req.queue_time_start = time.perf_counter()
|
1258
1320
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
@@ -1269,10 +1331,11 @@ class Scheduler(
     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
             req.init_next_round_input(self.tree_cache)
-
-
-
-
+            if req.last_node.backuped:
+                # only to initiate the prefetch if the last node is backuped
+                # otherwise, the allocated GPU memory must be locked for integrity
+                last_hash = req.last_host_node.get_last_hash_value()
+                matched_len = len(req.prefix_indices) + req.host_hit_length
                 new_input_tokens = req.fill_ids[matched_len:]
                 self.tree_cache.prefetch_from_storage(
                     req.rid, req.last_host_node, new_input_tokens, last_hash
|
|
1335
1398
|
req.logprob_start_len = len(req.origin_input_ids) - 1
|
1336
1399
|
self._add_request_to_queue(req)
|
1337
1400
|
|
1401
|
+
def handle_batch_embedding_request(
|
1402
|
+
self,
|
1403
|
+
recv_req: BatchTokenizedEmbeddingReqInput,
|
1404
|
+
):
|
1405
|
+
"""Handle optimized batch embedding request."""
|
1406
|
+
logger.debug(
|
1407
|
+
f"Processing batch embedding request with {len(recv_req)} requests"
|
1408
|
+
)
|
1409
|
+
|
1410
|
+
# Process each request in the batch
|
1411
|
+
for tokenized_req in recv_req:
|
1412
|
+
self.handle_embedding_request(tokenized_req)
|
1413
|
+
|
1338
1414
|
def self_check_during_idle(self):
|
1339
1415
|
self.check_memory()
|
1340
1416
|
self.check_tree_cache()
|
@@ -1362,9 +1438,11 @@ class Scheduler(
         _, _, available_size, evictable_size = self._get_token_info()
         protected_size = self.tree_cache.protected_size()
         memory_leak = (available_size + evictable_size) != (
+            # self.max_total_num_tokens
+            # if not self.enable_hierarchical_cache
+            # else self.max_total_num_tokens - protected_size
             self.max_total_num_tokens
-            if not self.enable_hierarchical_cache
-            else self.max_total_num_tokens - protected_size
+            - protected_size
         )
         token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

@@ -1460,9 +1538,14 @@ class Scheduler(
             # Move the chunked request out of the batch so that we can merge
             # only finished requests to running_batch.
             chunked_req_to_exclude.add(self.chunked_req)
-            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
+            if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+                self.req_to_token_pool.free(
+                    self.chunked_req.req_pool_idx, free_mamba_cache=False
+                )
+            else:
+                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.last_batch.chunked_req is not None:
                 # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
@@ -1509,11 +1592,7 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
-            if (
-                self.server_args.load_balance_method == "minimum_tokens"
-                and self.forward_ct % 40 == 0
-            ):
-                self.handle_dp_balance_data(ret)
+            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1733,10 +1812,6 @@ class Scheduler(
         if self.spec_algorithm.is_none():
             model_worker_batch = batch.get_model_worker_batch()

-            # update the consumer index of hicache to the running batch
-            self.tp_worker.set_hicache_consumer(
-                model_worker_batch.hicache_consumer_index
-            )
             if self.pp_group.is_last_rank:
                 logits_output, next_token_ids, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -1838,86 +1913,6 @@ class Scheduler(
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

-    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
-        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
-            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-            recv_list = self.recv_dp_balance_id_this_term
-            assert len(recv_list) <= 511, (
-                "The number of requests received this round is too large. "
-                "Please increase gather_tensor_size and onfly_info_size."
-            )
-            # The maximum size of the tensor used for gathering data from all workers.
-            gather_tensor_size = 512
-
-            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-            recv_tensor[0] = holding_tokens_list
-            recv_tensor[1] = len(
-                recv_list
-            )  # The first element is the length of the list.
-            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
-                recv_list, dtype=torch.int32
-            )
-
-            if self.tp_rank == 0:
-                gathered_list = [
-                    torch.zeros(gather_tensor_size, dtype=torch.int32)
-                    for _ in range(self.balance_meta.num_workers)
-                ]
-            else:
-                gathered_list = None
-
-            torch.distributed.gather(
-                recv_tensor, gathered_list, group=self.tp_cpu_group
-            )
-
-            gathered_id_list_per_worker = None
-            if self.tp_rank == 0:
-                gathered_id_list_per_worker = []
-                holding_tokens_list = []
-                for tensor in gathered_list:
-                    holding_tokens_list.append(tensor[0].item())
-                    list_length = tensor[1].item()
-                    gathered_id_list_per_worker.append(
-                        tensor[2 : list_length + 2].tolist()
-                    )
-
-            return gathered_id_list_per_worker, holding_tokens_list
-
-        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
-            meta = self.balance_meta
-
-            with meta.mutex:
-                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-                assert len(new_recv_rid_lists) == len(
-                    onfly_list
-                ), "num_worker not equal"
-                # 1.Check if the rid received by each worker this round is present in onfly.
-                # If it is, remove the corresponding onfly item.
-                worker_id = 0
-                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                    for new_recv_rid in new_recv_rids:
-                        assert (
-                            new_recv_rid in on_fly_reqs
-                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                        del on_fly_reqs[new_recv_rid]
-                    worker_id += 1
-                # 2. Atomically write local_tokens and onfly into shm under the mutex
-                meta.set_shared_onfly_info(onfly_list)
-                meta.set_shared_local_tokens(local_tokens)
-
-        holding_tokens = self.get_load()
-
-        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
-            holding_tokens
-        )
-
-        self.recv_dp_balance_id_this_term.clear()
-        if self.tp_rank == 0:  # only first worker write info
-            write_shared_dp_balance_info(
-                new_recv_dp_balance_id_list, holding_token_list
-            )
-
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
@@ -2164,6 +2159,16 @@ class Scheduler(
         success = self.flush_cache()
         return FlushCacheReqOutput(success=success)

+    def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
+        if self.enable_hierarchical_cache:
+            self.tree_cache.clear_storage_backend()
+            logger.info("Hierarchical cache cleared successfully!")
+            if_success = True
+        else:
+            logging.warning("Hierarchical cache is not enabled.")
+            if_success = False
+        return ClearHiCacheReqOutput(success=if_success)
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if (
@@ -2248,10 +2253,9 @@ class Scheduler(
             "token_capacity": int(self.max_total_num_tokens),
         }

-
-
-
-        )
+        ret["memory_usage"]["graph"] = round(
+            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+        )

         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
@@ -2334,7 +2338,14 @@ class Scheduler(
             # This only works for requests that have not started anything.
             # We still need to send something back to TokenizerManager to clean up the state.
             req = self.waiting_queue.pop(i)
+            if self.enable_hicache_storage:
+                # to release prefetch events associated with the request
+                self.tree_cache.release_aborted_request(req.rid)
             self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+            # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
+            if self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.tree_cache.cache_finished_req(req)
+
             logger.debug(f"Abort queued request. {req.rid=}")

         # Delete the requests in the grammar queue
@@ -2414,6 +2425,10 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result

+    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
+        self.send_to_detokenizer.send_pyobj(recv_req)
+        return recv_req
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
@@ -2513,7 +2528,15 @@ def is_health_check_generate_req(recv_req):


 def is_work_request(recv_req):
-    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
+    return isinstance(
+        recv_req,
+        (
+            TokenizedGenerateReqInput,
+            TokenizedEmbeddingReqInput,
+            BatchTokenizedGenerateReqInput,
+            BatchTokenizedEmbeddingReqInput,
+        ),
+    )


 def run_scheduler_process(
@@ -2527,6 +2550,9 @@ def run_scheduler_process(
     pipe_writer,
     balance_meta: Optional[DPBalanceMeta] = None,
 ):
+    if (numa_node := server_args.numa_node) is not None:
+        numa_bind_to_node(numa_node[gpu_id])
+
     # Generate the prefix
     prefix = ""
     if dp_rank is not None: