sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -12,10 +12,11 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations
 
 import logging
 import threading
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
 
@@ -29,8 +30,10 @@ from sglang.srt.hf_transformers_utils import (
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
+    InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
+    SendWeightsToRemoteInstanceReqInput,
     UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
@@ -45,6 +48,9 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)
 
 
@@ -78,7 +84,13 @@ class TpModelWorker:
                 if not is_draft_worker
                 else server_args.speculative_draft_model_path
             ),
+            model_revision=(
+                server_args.revision
+                if not is_draft_worker
+                else server_args.speculative_draft_model_revision
+            ),
             is_draft_model=is_draft_worker,
+            tp_rank=tp_rank,
         )
 
         self.model_runner = ModelRunner(
@@ -137,7 +149,7 @@ class TpModelWorker:
         assert self.max_running_requests > 0, "max_running_request is zero"
         self.max_queued_requests = server_args.max_queued_requests
         assert (
-            self.
+            self.max_queued_requests > 0
         ), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
         self.max_req_len = min(
             self.model_config.context_len - 1,
@@ -162,10 +174,10 @@ class TpModelWorker:
 
         self.hicache_layer_transfer_counter = None
 
-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter
 
-    def set_hicache_consumer(self, consumer_index):
+    def set_hicache_consumer(self, consumer_index: int):
         if self.hicache_layer_transfer_counter is not None:
             self.hicache_layer_transfer_counter.set_consumer(consumer_index)
 
@@ -225,6 +237,9 @@ class TpModelWorker:
     ) -> Tuple[
         Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
     ]:
+        # update the consumer index of hicache to the running batch
+        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
 
         pp_proxy_tensors = None
@@ -244,6 +259,15 @@ class TpModelWorker:
 
         if skip_sample:
             next_token_ids = None
+            # For prefill-only requests, we still need to compute logprobs even when sampling is skipped
+            if (
+                model_worker_batch.is_prefill_only
+                and model_worker_batch.return_logprob
+            ):
+                # Compute logprobs without full sampling
+                self.model_runner.compute_logprobs_only(
+                    logits_output, model_worker_batch
+                )
         else:
             next_token_ids = self.model_runner.sample(
                 logits_output, model_worker_batch
@@ -280,6 +304,31 @@ class TpModelWorker:
         )
         return success, message
 
+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        success, message = (
+            self.model_runner.init_weights_send_group_for_remote_instance(
+                recv_req.master_address,
+                recv_req.ports,
+                recv_req.group_rank,
+                recv_req.world_size,
+                recv_req.group_name,
+                recv_req.backend,
+            )
+        )
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        success, message = self.model_runner.send_weights_to_remote_instance(
+            recv_req.master_address,
+            recv_req.ports,
+            recv_req.group_name,
+        )
+        return success, message
+
     def update_weights_from_distributed(
         self, recv_req: UpdateWeightsFromDistributedReqInput
     ):
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -12,21 +12,24 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations
 
 import dataclasses
 import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import psutil
 import torch
 
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
+    InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
+    SendWeightsToRemoteInstanceReqInput,
     UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
@@ -38,6 +41,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)
 
 
@@ -79,7 +85,7 @@ class TpModelWorkerClient:
         )
 
         # Launch threads
-        self.input_queue = Queue()
+        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
         self.output_queue = Queue()
         self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
@@ -93,13 +99,9 @@ class TpModelWorkerClient:
 
         self.hicache_layer_transfer_counter = None
 
-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter
 
-    def set_hicache_consumer(self, consumer_index):
-        if self.hicache_layer_transfer_counter is not None:
-            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
-
     def get_worker_info(self):
         return self.worker.get_worker_info()
 
@@ -147,7 +149,7 @@ class TpModelWorkerClient:
     @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
-        batch_lists = [None] * 2
+        batch_lists: List = [None] * 2
 
         while True:
             model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
@@ -169,26 +171,31 @@ class TpModelWorkerClient:
             input_ids = model_worker_batch.input_ids
             resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
-            # update the consumer index of hicache to the running batch
-            self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
             # Run forward
             logits_output, next_token_ids, can_run_cuda_graph = (
                 self.worker.forward_batch_generation(
-                    model_worker_batch,
+                    model_worker_batch,
+                    model_worker_batch.launch_done,
+                    # Skip sampling for prefill-only requests
+                    skip_sample=model_worker_batch.is_prefill_only,
                 )
             )
 
             # Update the future token ids map
             bs = len(model_worker_batch.seq_lens)
+            if model_worker_batch.is_prefill_only:
+                # For prefill-only requests, create dummy token IDs on CPU
+                next_token_ids = torch.zeros(bs, dtype=torch.long)
             self.future_token_ids_map[
                 future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
             ] = next_token_ids
 
             # Copy results to the CPU
             if model_worker_batch.return_logprob:
-                logits_output.next_token_logprobs
-                logits_output.next_token_logprobs
-
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs.to("cpu", non_blocking=True)
+                    )
             if logits_output.input_token_logprobs is not None:
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.to("cpu", non_blocking=True)
@@ -197,7 +204,9 @@ class TpModelWorkerClient:
                 logits_output.hidden_states = logits_output.hidden_states.to(
                     "cpu", non_blocking=True
                 )
-
+            # Only copy to CPU if not already on CPU
+            if next_token_ids.device.type != "cpu":
+                next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()
 
             self.output_queue.put(
@@ -221,10 +230,10 @@ class TpModelWorkerClient:
             logits_output.next_token_logprobs = (
                 logits_output.next_token_logprobs.tolist()
             )
-
-
-
-
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = tuple(
+                    logits_output.input_token_logprobs.tolist()
+                )
         next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids, can_run_cuda_graph
 
@@ -269,6 +278,20 @@ class TpModelWorkerClient:
         success, message = self.worker.init_weights_update_group(recv_req)
         return success, message
 
+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        success, message = self.worker.init_weights_send_group_for_remote_instance(
+            recv_req
+        )
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        success, message = self.worker.send_weights_to_remote_instance(recv_req)
+        return success, message
+
     def update_weights_from_distributed(
         self, recv_req: UpdateWeightsFromDistributedReqInput
     ):
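One detail worth noting in the overlap worker change above: the future-token-ids map is always written with bs entries per batch, so prefill-only batches substitute CPU zeros instead of skipping the write. A simplified, self-contained sketch of that bookkeeping, with a plain tensor standing in for the real map:

```python
import torch

# Shared map that the overlap thread fills ahead of the scheduler reading it.
future_token_ids_map = torch.zeros(16, dtype=torch.long)
future_token_ids_ct = 4
bs = 3

is_prefill_only = True
if is_prefill_only:
    # Dummy ids on CPU, mirroring the new branch in forward_thread_func_.
    next_token_ids = torch.zeros(bs, dtype=torch.long)
else:
    next_token_ids = torch.tensor([11, 12, 13])  # sampled ids in the normal path

# The slice is always bs entries wide, so the write stays valid either way.
future_token_ids_map[
    future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
] = next_token_ids
```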
sglang/srt/mem_cache/hicache_storage.py
CHANGED
@@ -103,20 +103,6 @@ class HiCacheStorage(ABC):
         """
         pass
 
-    @abstractmethod
-    def delete(self, key: str) -> bool:
-        """
-        Delete the entry associated with the given key.
-        """
-        pass
-
-    @abstractmethod
-    def clear(self) -> bool:
-        """
-        Clear all entries in the storage.
-        """
-        pass
-
     def batch_exists(self, keys: List[str]) -> int:
         """
         Check if the keys exist in the storage.
@@ -128,6 +114,9 @@ class HiCacheStorage(ABC):
                 return i
         return len(keys)
 
+    def get_stats(self):
+        return None
+
 
 class HiCacheFile(HiCacheStorage):
 
@@ -224,15 +213,6 @@ class HiCacheFile(HiCacheStorage):
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         return os.path.exists(tensor_path)
 
-    def delete(self, key: str) -> None:
-        key = self._get_suffixed_key(key)
-        tensor_path = os.path.join(self.file_path, f"{key}.bin")
-        try:
-            os.remove(tensor_path)
-        except FileNotFoundError:
-            logger.warning(f"Key {key} does not exist. Cannot delete.")
-            return
-
     def clear(self) -> bool:
         try:
             for filename in os.listdir(self.file_path):
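With delete() and clear() no longer abstract on HiCacheStorage, and a get_stats() hook that defaults to returning None, a backend only overrides what it supports. A rough sketch of the stats side, assuming a plain dict is an acceptable return value; HiRadixCache simply forwards whatever get_stats() returns to StorageMetricsCollector.log_storage_metrics(), as shown in the hiradix_cache.py hunks below:

```python
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage


class CountingStorageBackend(HiCacheStorage):
    """Sketch only: the abstract get/set/exists methods that HiCacheStorage
    still requires are omitted here for brevity."""

    def __init__(self):
        self.hit_count = 0
        self.entry_count = 0

    def get_stats(self):
        # Whatever is returned here is passed to
        # StorageMetricsCollector.log_storage_metrics(); the base-class
        # default of None means "nothing to report".
        return {"hits": self.hit_count, "entries": self.entry_count}
```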
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -20,6 +20,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
     MLATokenToKVPoolHost,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+from sglang.srt.metrics.collector import StorageMetricsCollector
 
 logger = logging.getLogger(__name__)
 
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
         hicache_write_policy: str,
         hicache_io_backend: str,
         hicache_mem_layout: str,
+        enable_metrics: bool,
         hicache_storage_backend: Optional[str] = None,
         hicache_storage_prefetch_policy: Optional[str] = "best_effort",
         model_name: Optional[str] = None,
@@ -73,6 +75,8 @@ class HiRadixCache(RadixCache):
         self.tp_group = tp_cache_group
         self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
         self.enable_storage = hicache_storage_backend is not None
+        self.enable_storage_metrics = self.enable_storage and enable_metrics
+
         # todo: customizable storage prefetch threshold and timeout
         self.prefetch_threshold = 256
         self.prefetch_timeout = 3  # seconds
@@ -92,6 +96,14 @@ class HiRadixCache(RadixCache):
                 model_name=model_name,
                 storage_backend_extra_config=storage_backend_extra_config,
             )
+            if self.enable_storage_metrics:
+                # TODO: support pp
+                labels = {
+                    "storage_backend": hicache_storage_backend,
+                    "tp_rank": self.cache_controller.tp_rank,
+                    "dp_rank": self.cache_controller.dp_rank,
+                }
+                self.metrics_collector = StorageMetricsCollector(labels=labels)
 
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
@@ -122,11 +134,24 @@ class HiRadixCache(RadixCache):
             height += 1
         return height
 
-    def clear_storage_backend(self):
+    def clear_storage_backend(self) -> bool:
         if self.enable_storage:
-
-
-
+            try:
+                # Check if the storage backend has a clear method (for nixl backends)
+                if hasattr(self.cache_controller.storage_backend, "clear"):
+                    self.cache_controller.storage_backend.clear()
+                    logger.info(
+                        "Hierarchical cache storage backend cleared successfully!"
+                    )
+                    return True
+                else:
+                    logger.warning(
+                        f"Storage backend {type(self.cache_controller.storage_backend).__name__} does not support clear operation."
+                    )
+                    return False
+            except Exception as e:
+                logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
+                return False
         else:
             logger.warning("Hierarchical cache storage backend is not enabled.")
             return False
@@ -176,41 +201,57 @@ class HiRadixCache(RadixCache):
         if write_back:
             # blocking till all write back complete
             while len(self.ongoing_write_through) > 0:
-
-
+                for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+                    finish_event.synchronize()
+                    for ack_id in ack_list:
+                        del self.ongoing_write_through[ack_id]
+                self.cache_controller.ack_write_queue.clear()
+                assert len(self.ongoing_write_through) == 0
             return
-
-
-        )
+
+        # NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
+        if len(self.ongoing_write_through) == 0:
+            return
+
+        finish_count = 0
+        for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+            if not finish_event.query():
+                break
+            finish_count += 1
+        queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
         if self.tp_world_size > 1:
-            #
+            # synchronize TP workers to make the same update to radix cache
            torch.distributed.all_reduce(
                 queue_size,
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
             )
-
-
-
-            self.
-
-
-            self.
+
+        finish_count = int(queue_size.item())
+        while finish_count > 0:
+            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
+            finish_event.synchronize()
+            for ack_id in ack_list:
+                backuped_node = self.ongoing_write_through.pop(ack_id)
+                self.dec_lock_ref(backuped_node)
+                if self.enable_storage:
+                    self.write_backup_storage(backuped_node)
+            finish_count -= 1
 
     def loading_check(self):
-
-
-
-
-        self.dec_lock_ref(end_node)
-        while end_node != start_node:
-            assert end_node.loading
-            end_node.loading = False
-            end_node = end_node.parent
-        # clear the reference
-        del self.ongoing_load_back[ack_id]
-        except Exception:
+        finish_count = 0
+        for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
+            if not finish_event.query():
+                # the KV cache loading is still ongoing
                 break
+            finish_count += 1
+            # no need to sync across TP workers as batch forwarding is synced
+            for ack_id in ack_list:
+                end_node = self.ongoing_load_back.pop(ack_id)
+                self.dec_lock_ref(end_node)
+
+        # ACK until all events are processed
+        del self.cache_controller.ack_load_queue[:finish_count]
 
     def evictable_size(self):
         return self.evictable_size_
@@ -335,12 +376,11 @@ class HiRadixCache(RadixCache):
             # no sufficient GPU memory to load back KV caches
             return None
 
-        self.ongoing_load_back[last_hit_node.id] =
+        self.ongoing_load_back[last_hit_node.id] = last_hit_node
         offset = 0
         for node in nodes_to_load:
             node.value = device_indices[offset : offset + len(node.host_value)]
             offset += len(node.host_value)
-            node.loading = True
         self.evictable_size_ += len(device_indices)
         self.inc_lock_ref(last_hit_node)
 
@@ -369,16 +409,22 @@ class HiRadixCache(RadixCache):
             last_node,
         )
 
-    def ready_to_load_host_cache(self):
-
-
-
+    def ready_to_load_host_cache(self) -> int:
+        """
+        Notify the cache controller to start the KV cache loading.
+        Return the consumer index for the schedule batch manager to track.
+        """
+        return self.cache_controller.start_loading()
 
     def check_hicache_events(self):
         self.writing_check()
         self.loading_check()
         if self.enable_storage:
             self.drain_storage_control_queues()
+            if self.enable_storage_metrics:
+                self.metrics_collector.log_storage_metrics(
+                    self.cache_controller.storage_backend.get_stats()
+                )
 
     def drain_storage_control_queues(self):
         """
@@ -414,10 +460,13 @@ class HiRadixCache(RadixCache):
 
         # process backup acks
         for _ in range(n_backup):
-
+            operation = cc.ack_backup_queue.get()
+            ack_id = operation.id
             entry = self.ongoing_backup.pop(ack_id, None)
             if entry is not None:
                 entry.release_host()
+                if self.enable_storage_metrics:
+                    self.metrics_collector.log_backuped_tokens(operation.completed_tokens)
 
         # release host memory
         host_indices_list = []
@@ -450,15 +499,22 @@ class HiRadixCache(RadixCache):
             # unknown prefetch stop policy, just return True
             return True
 
+        operation_terminated = operation.is_terminated()
         if self.tp_world_size > 1:
-
+            states = torch.tensor(
+                [1 - int(can_terminate), int(operation_terminated)],
+                dtype=torch.int,
+            )
             torch.distributed.all_reduce(
-
-                op=torch.distributed.ReduceOp.
+                states,
+                op=torch.distributed.ReduceOp.MAX,
                 group=self.tp_group,
             )
-            can_terminate =
-
+            can_terminate = states[0].item() == 0
+            operation_terminated = states[1].item() == 1
+        # the operation should be terminated if it is already terminated on any TP worker
+        # or it meets the termination condition on all TP workers
+        can_terminate = can_terminate or operation_terminated
         return can_terminate
 
     def check_prefetch_progress(self, req_id: str) -> bool:
@@ -485,7 +541,7 @@ class HiRadixCache(RadixCache):
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
 
         min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
             completed_tokens_tensor = torch.tensor(
                 min_completed_tokens, dtype=torch.int
@@ -515,6 +571,11 @@ class HiRadixCache(RadixCache):
             del self.ongoing_prefetch[req_id]
             self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
 
+            if self.enable_storage_metrics:
+                self.metrics_collector.log_prefetched_tokens(
+                    min_completed_tokens - matched_length
+                )
+
         return True
 
     def match_prefix(self, key: List[int], **kwargs):
@@ -658,7 +719,6 @@ class HiRadixCache(RadixCache):
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
-        new_node.loading = child.loading
         new_node.hit_count = child.hit_count
 
         # split value and host value if exists