sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -12,10 +12,11 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations
 
 import logging
 import threading
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
 
@@ -45,6 +46,9 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)
 
 
@@ -78,6 +82,11 @@ class TpModelWorker:
                 if not is_draft_worker
                 else server_args.speculative_draft_model_path
             ),
+            model_revision=(
+                server_args.revision
+                if not is_draft_worker
+                else server_args.speculative_draft_model_revision
+            ),
             is_draft_model=is_draft_worker,
         )
 
@@ -137,7 +146,7 @@ class TpModelWorker:
         assert self.max_running_requests > 0, "max_running_request is zero"
         self.max_queued_requests = server_args.max_queued_requests
         assert (
-            self.
+            self.max_queued_requests > 0
         ), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
         self.max_req_len = min(
             self.model_config.context_len - 1,
@@ -162,10 +171,10 @@
 
         self.hicache_layer_transfer_counter = None
 
-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter
 
-    def set_hicache_consumer(self, consumer_index):
+    def set_hicache_consumer(self, consumer_index: int):
         if self.hicache_layer_transfer_counter is not None:
             self.hicache_layer_transfer_counter.set_consumer(consumer_index)
 
@@ -225,6 +234,9 @@ class TpModelWorker:
     ) -> Tuple[
         Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
     ]:
+        # update the consumer index of hicache to the running batch
+        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
 
         pp_proxy_tensors = None
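The hunks above gate the `LayerDoneCounter` import behind `TYPE_CHECKING` while still annotating `register_hicache_layer_transfer_counter`. Below is a minimal sketch of that pattern; the `Worker` class is a hypothetical stand-in, not the real `TpModelWorker`, and the snippet runs even without sglang installed because the guarded import never executes at runtime.

```python
from __future__ import annotations  # annotations stay strings at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers, so the cache_controller module is never
    # imported at runtime and no import cycle is created.
    from sglang.srt.managers.cache_controller import LayerDoneCounter


class Worker:  # hypothetical stand-in for TpModelWorker
    def __init__(self) -> None:
        self.hicache_layer_transfer_counter = None

    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter) -> None:
        # At runtime the annotation is just the string "LayerDoneCounter".
        self.hicache_layer_transfer_counter = counter
```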
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -12,13 +12,14 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations
 
 import dataclasses
 import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import psutil
 import torch
@@ -38,6 +39,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)
 
 
@@ -79,7 +83,7 @@ class TpModelWorkerClient:
         )
 
         # Launch threads
-        self.input_queue = Queue()
+        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
         self.output_queue = Queue()
         self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
@@ -93,13 +97,9 @@ class TpModelWorkerClient:
 
         self.hicache_layer_transfer_counter = None
 
-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter
 
-    def set_hicache_consumer(self, consumer_index):
-        if self.hicache_layer_transfer_counter is not None:
-            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
-
     def get_worker_info(self):
         return self.worker.get_worker_info()
 
@@ -147,7 +147,7 @@ class TpModelWorkerClient:
     @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
-        batch_lists = [None] * 2
+        batch_lists: List = [None] * 2
 
         while True:
             model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
@@ -169,8 +169,6 @@
             input_ids = model_worker_batch.input_ids
             resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
-            # update the consumer index of hicache to the running batch
-            self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
             # Run forward
             logits_output, next_token_ids, can_run_cuda_graph = (
                 self.worker.forward_batch_generation(
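Note that the `set_hicache_consumer` call removed here reappears in `TpModelWorker.forward_batch_generation` above, so the consumer index is now updated in the worker itself rather than in the overlap thread. The other change of interest is the parameterized input queue. Below is a minimal sketch of that form, assuming Python 3.9+ (where `queue.Queue` is subscriptable at runtime); `Item` is a hypothetical alias standing in for the real `(ModelWorkerBatch, int, torch.Event)` tuple.

```python
from queue import Queue
from typing import Tuple

# Hypothetical payload type standing in for
# (model_worker_batch, future_token_ids_ct, sync_event).
Item = Tuple[str, int]

# Queue[...]() both executes on Python 3.9+ and tells type checkers what the
# queue carries; the type parameter is erased at runtime.
input_queue = Queue[Item]()
input_queue.put(("batch-0", 7))

batch_name, counter = input_queue.get()
print(batch_name, counter)
```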
sglang/srt/mem_cache/allocator.py
CHANGED
@@ -283,7 +283,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         self.swa_attn_allocator.clear()
         self.full_attn_allocator.clear()
         self.full_to_swa_index_mapping.fill_(0)
-        self.
+        self.is_not_in_free_group = True
         self.free_group = []
 
 
sglang/srt/mem_cache/chunk_cache.py
CHANGED
@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(req.fill_ids)
         ]
sglang/srt/mem_cache/hicache_storage.py
CHANGED
@@ -2,6 +2,7 @@ import hashlib
 import logging
 import os
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Any, List, Optional
 
 import torch
@@ -9,17 +10,6 @@ import torch
 logger = logging.getLogger(__name__)
 
 
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    is_dp_attention_enabled,
-)
-
-
 def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     hasher = hashlib.sha256()
 
@@ -32,6 +22,16 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     return hasher.hexdigest()
 
 
+@dataclass
+class HiCacheStorageConfig:
+    tp_rank: int
+    tp_size: int
+    is_mla_model: bool
+    is_page_first_layout: bool
+    model_name: Optional[str]
+    extra_config: Optional[dict] = None
+
+
 class HiCacheStorage(ABC):
     """
     HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
@@ -60,7 +60,7 @@ class HiCacheStorage(ABC):
         keys: List[str],
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) -> List[torch.Tensor | None]:
+    ) -> List[torch.Tensor | None] | int:
         """
         Retrieve values for multiple keys.
         Returns a list of tensors or None for each key.
@@ -96,32 +96,53 @@ class HiCacheStorage(ABC):
         pass
 
     @abstractmethod
-    def exists(self, key: str) -> bool
+    def exists(self, key: str) -> bool:
         """
         Check if the key exists in the storage.
         Returns True if the key exists, False otherwise.
         """
         pass
 
+    def batch_exists(self, keys: List[str]) -> int:
+        """
+        Check if the keys exist in the storage.
+        return the number of consecutive existing keys from the start.
+        Can be overridden by subclasses for more efficient implementation.
+        """
+        for i in range(len(keys)):
+            if not self.exists(keys[i]):
+                return i
+        return len(keys)
+
+    def get_stats(self):
+        return None
+
 
 class HiCacheFile(HiCacheStorage):
 
-    def __init__(
+    def __init__(
+        self, storage_config: HiCacheStorageConfig, file_path: str = "/tmp/hicache"
+    ):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
-
-
-
+
+        tp_rank, tp_size, model_name, is_mla_model = (
+            storage_config.tp_rank,
+            storage_config.tp_size,
+            storage_config.model_name,
+            storage_config.is_mla_model,
+        )
+        model_name = "-".join(model_name.split("/")) if model_name else ""
+        if is_mla_model:
+            self.config_suffix = f"_{model_name}"
         else:
-
-            tp_size = get_tensor_model_parallel_world_size()
+            self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
 
-        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
 
     def _get_suffixed_key(self, key: str) -> str:
-        return key + self.
+        return key + self.config_suffix
 
     def get(
         self,
@@ -132,13 +153,11 @@ class HiCacheFile(HiCacheStorage):
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-
-            with open(tensor_path, "rb") as f:
-                target_location.
-
-
-                .untyped_storage()
-            )
+            expected = target_location.numel() * target_location.element_size()
+            with open(tensor_path, "rb", buffering=0) as f:
+                buf = memoryview(target_location.view(torch.uint8).contiguous().numpy())
+                if f.readinto(buf) != expected:
+                    raise IOError(f"Short read for {key}")
             return target_location
         except FileNotFoundError:
             logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
@@ -164,11 +183,12 @@ class HiCacheFile(HiCacheStorage):
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
-        key = self._get_suffixed_key(key)
-        tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
             logger.debug(f"Key {key} already exists. Skipped.")
             return True
+
+        key = self._get_suffixed_key(key)
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
             return True
@@ -193,21 +213,14 @@ class HiCacheFile(HiCacheStorage):
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         return os.path.exists(tensor_path)
 
-    def
-        key = self._get_suffixed_key(key)
-        tensor_path = os.path.join(self.file_path, f"{key}.bin")
-        try:
-            os.remove(tensor_path)
-        except FileNotFoundError:
-            logger.warning(f"Key {key} does not exist. Cannot delete.")
-            return
-
-    def clear(self) -> None:
+    def clear(self) -> bool:
         try:
             for filename in os.listdir(self.file_path):
                 file_path = os.path.join(self.file_path, filename)
                 if os.path.isfile(file_path):
                     os.remove(file_path)
             logger.info("Cleared all entries in HiCacheFile storage.")
+            return True
         except Exception as e:
             logger.error(f"Failed to clear HiCacheFile storage: {e}")
+            return False
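Taken together, these hunks drop the tensor-parallel state lookups from the file backend in favor of an explicit `HiCacheStorageConfig`, and give every backend a default `batch_exists` that returns the number of consecutive keys, counted from the start of the list, that are already stored. Below is a minimal usage sketch, assuming sglang 0.5.2 is installed; the storage path, model name, and tensor contents are illustrative only, and the argument names follow the hunks above.

```python
import torch

from sglang.srt.mem_cache.hicache_storage import HiCacheFile, HiCacheStorageConfig

# Explicit config object replaces the old distributed/TP-state queries.
cfg = HiCacheStorageConfig(
    tp_rank=0,
    tp_size=1,
    is_mla_model=False,
    is_page_first_layout=True,
    model_name="org/model",
)
backend = HiCacheFile(cfg, file_path="/tmp/hicache-demo")

value = torch.arange(16, dtype=torch.float32)
backend.set("page-0", value)

# The default batch_exists counts consecutive hits from the front of the
# list, so the miss at index 1 makes this print 1.
print(backend.batch_exists(["page-0", "missing-page"]))

# get() reads the file into a caller-provided buffer and returns it.
out = torch.empty_like(value)
print(backend.get("page-0", target_location=out))
```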