sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -12,10 +12,11 @@
|
|
12
12
|
# limitations under the License.
|
13
13
|
# ==============================================================================
|
14
14
|
"""A tensor parallel worker."""
|
15
|
+
from __future__ import annotations
|
15
16
|
|
16
17
|
import logging
|
17
18
|
import threading
|
18
|
-
from typing import Optional, Tuple, Union
|
19
|
+
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
19
20
|
|
20
21
|
import torch
|
21
22
|
|
@@ -45,6 +46,9 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
|
45
46
|
from sglang.srt.server_args import ServerArgs
|
46
47
|
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
|
47
48
|
|
49
|
+
if TYPE_CHECKING:
|
50
|
+
from sglang.srt.managers.cache_controller import LayerDoneCounter
|
51
|
+
|
48
52
|
logger = logging.getLogger(__name__)
|
49
53
|
|
50
54
|
|
@@ -78,6 +82,11 @@ class TpModelWorker:
|
|
78
82
|
if not is_draft_worker
|
79
83
|
else server_args.speculative_draft_model_path
|
80
84
|
),
|
85
|
+
model_revision=(
|
86
|
+
server_args.revision
|
87
|
+
if not is_draft_worker
|
88
|
+
else server_args.speculative_draft_model_revision
|
89
|
+
),
|
81
90
|
is_draft_model=is_draft_worker,
|
82
91
|
)
|
83
92
|
|
@@ -137,7 +146,7 @@ class TpModelWorker:
|
|
137
146
|
assert self.max_running_requests > 0, "max_running_request is zero"
|
138
147
|
self.max_queued_requests = server_args.max_queued_requests
|
139
148
|
assert (
|
140
|
-
self.
|
149
|
+
self.max_queued_requests > 0
|
141
150
|
), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
|
142
151
|
self.max_req_len = min(
|
143
152
|
self.model_config.context_len - 1,
|
@@ -162,10 +171,10 @@ class TpModelWorker:
|
|
162
171
|
|
163
172
|
self.hicache_layer_transfer_counter = None
|
164
173
|
|
165
|
-
def register_hicache_layer_transfer_counter(self, counter):
|
174
|
+
def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
|
166
175
|
self.hicache_layer_transfer_counter = counter
|
167
176
|
|
168
|
-
def set_hicache_consumer(self, consumer_index):
|
177
|
+
def set_hicache_consumer(self, consumer_index: int):
|
169
178
|
if self.hicache_layer_transfer_counter is not None:
|
170
179
|
self.hicache_layer_transfer_counter.set_consumer(consumer_index)
|
171
180
|
|
@@ -225,6 +234,9 @@ class TpModelWorker:
|
|
225
234
|
) -> Tuple[
|
226
235
|
Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
|
227
236
|
]:
|
237
|
+
# update the consumer index of hicache to the running batch
|
238
|
+
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
|
239
|
+
|
228
240
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
229
241
|
|
230
242
|
pp_proxy_tensors = None
|
@@ -12,13 +12,14 @@
|
|
12
12
|
# limitations under the License.
|
13
13
|
# ==============================================================================
|
14
14
|
"""A tensor parallel worker."""
|
15
|
+
from __future__ import annotations
|
15
16
|
|
16
17
|
import dataclasses
|
17
18
|
import logging
|
18
19
|
import signal
|
19
20
|
import threading
|
20
21
|
from queue import Queue
|
21
|
-
from typing import Optional, Tuple
|
22
|
+
from typing import TYPE_CHECKING, List, Optional, Tuple
|
22
23
|
|
23
24
|
import psutil
|
24
25
|
import torch
|
@@ -38,6 +39,9 @@ from sglang.srt.server_args import ServerArgs
|
|
38
39
|
from sglang.srt.utils import DynamicGradMode, get_compiler_backend
|
39
40
|
from sglang.utils import get_exception_traceback
|
40
41
|
|
42
|
+
if TYPE_CHECKING:
|
43
|
+
from sglang.srt.managers.cache_controller import LayerDoneCounter
|
44
|
+
|
41
45
|
logger = logging.getLogger(__name__)
|
42
46
|
|
43
47
|
|
@@ -79,7 +83,7 @@ class TpModelWorkerClient:
|
|
79
83
|
)
|
80
84
|
|
81
85
|
# Launch threads
|
82
|
-
self.input_queue = Queue()
|
86
|
+
self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
|
83
87
|
self.output_queue = Queue()
|
84
88
|
self.forward_stream = torch.get_device_module(self.device).Stream()
|
85
89
|
self.forward_thread = threading.Thread(
|
@@ -93,13 +97,9 @@ class TpModelWorkerClient:
|
|
93
97
|
|
94
98
|
self.hicache_layer_transfer_counter = None
|
95
99
|
|
96
|
-
def register_hicache_layer_transfer_counter(self, counter):
|
100
|
+
def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
|
97
101
|
self.hicache_layer_transfer_counter = counter
|
98
102
|
|
99
|
-
def set_hicache_consumer(self, consumer_index):
|
100
|
-
if self.hicache_layer_transfer_counter is not None:
|
101
|
-
self.hicache_layer_transfer_counter.set_consumer(consumer_index)
|
102
|
-
|
103
103
|
def get_worker_info(self):
|
104
104
|
return self.worker.get_worker_info()
|
105
105
|
|
@@ -147,7 +147,7 @@ class TpModelWorkerClient:
|
|
147
147
|
@DynamicGradMode()
|
148
148
|
def forward_thread_func_(self):
|
149
149
|
batch_pt = 0
|
150
|
-
batch_lists = [None] * 2
|
150
|
+
batch_lists: List = [None] * 2
|
151
151
|
|
152
152
|
while True:
|
153
153
|
model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
|
@@ -169,8 +169,6 @@ class TpModelWorkerClient:
|
|
169
169
|
input_ids = model_worker_batch.input_ids
|
170
170
|
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
171
171
|
|
172
|
-
# update the consumer index of hicache to the running batch
|
173
|
-
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
|
174
172
|
# Run forward
|
175
173
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
176
174
|
self.worker.forward_batch_generation(
|
@@ -283,7 +283,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
283
283
|
self.swa_attn_allocator.clear()
|
284
284
|
self.full_attn_allocator.clear()
|
285
285
|
self.full_to_swa_index_mapping.fill_(0)
|
286
|
-
self.
|
286
|
+
self.is_not_in_free_group = True
|
287
287
|
self.free_group = []
|
288
288
|
|
289
289
|
|
@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
|
|
47
47
|
self.req_to_token_pool.free(req.req_pool_idx)
|
48
48
|
self.token_to_kv_pool_allocator.free(kv_indices)
|
49
49
|
|
50
|
-
def cache_unfinished_req(self, req: Req):
|
50
|
+
def cache_unfinished_req(self, req: Req, chunked=False):
|
51
51
|
kv_indices = self.req_to_token_pool.req_to_token[
|
52
52
|
req.req_pool_idx, : len(req.fill_ids)
|
53
53
|
]
|
@@ -27,6 +27,7 @@ class HiCacheStorageConfig:
|
|
27
27
|
tp_rank: int
|
28
28
|
tp_size: int
|
29
29
|
is_mla_model: bool
|
30
|
+
is_page_first_layout: bool
|
30
31
|
model_name: Optional[str]
|
31
32
|
extra_config: Optional[dict] = None
|
32
33
|
|
@@ -113,6 +114,9 @@ class HiCacheStorage(ABC):
|
|
113
114
|
return i
|
114
115
|
return len(keys)
|
115
116
|
|
117
|
+
def get_stats(self):
|
118
|
+
return None
|
119
|
+
|
116
120
|
|
117
121
|
class HiCacheFile(HiCacheStorage):
|
118
122
|
|
@@ -121,18 +125,24 @@ class HiCacheFile(HiCacheStorage):
|
|
121
125
|
):
|
122
126
|
self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
|
123
127
|
|
124
|
-
tp_rank, tp_size,
|
128
|
+
tp_rank, tp_size, model_name, is_mla_model = (
|
125
129
|
storage_config.tp_rank,
|
126
130
|
storage_config.tp_size,
|
131
|
+
storage_config.model_name,
|
127
132
|
storage_config.is_mla_model,
|
128
133
|
)
|
129
|
-
|
134
|
+
model_name = "-".join(model_name.split("/")) if model_name else ""
|
135
|
+
if is_mla_model:
|
136
|
+
self.config_suffix = f"_{model_name}"
|
137
|
+
else:
|
138
|
+
self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
|
139
|
+
|
130
140
|
if not os.path.exists(self.file_path) and tp_rank == 0:
|
131
141
|
os.makedirs(self.file_path)
|
132
142
|
logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
|
133
143
|
|
134
144
|
def _get_suffixed_key(self, key: str) -> str:
|
135
|
-
return key + self.
|
145
|
+
return key + self.config_suffix
|
136
146
|
|
137
147
|
def get(
|
138
148
|
self,
|
@@ -143,13 +153,11 @@ class HiCacheFile(HiCacheStorage):
|
|
143
153
|
key = self._get_suffixed_key(key)
|
144
154
|
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
145
155
|
try:
|
146
|
-
|
147
|
-
with open(tensor_path, "rb") as f:
|
148
|
-
target_location.
|
149
|
-
|
150
|
-
|
151
|
-
.untyped_storage()
|
152
|
-
)
|
156
|
+
expected = target_location.numel() * target_location.element_size()
|
157
|
+
with open(tensor_path, "rb", buffering=0) as f:
|
158
|
+
buf = memoryview(target_location.view(torch.uint8).contiguous().numpy())
|
159
|
+
if f.readinto(buf) != expected:
|
160
|
+
raise IOError(f"Short read for {key}")
|
153
161
|
return target_location
|
154
162
|
except FileNotFoundError:
|
155
163
|
logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
|
@@ -175,11 +183,12 @@ class HiCacheFile(HiCacheStorage):
|
|
175
183
|
target_location: Optional[Any] = None,
|
176
184
|
target_sizes: Optional[Any] = None,
|
177
185
|
) -> bool:
|
178
|
-
key = self._get_suffixed_key(key)
|
179
|
-
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
180
186
|
if self.exists(key):
|
181
187
|
logger.debug(f"Key {key} already exists. Skipped.")
|
182
188
|
return True
|
189
|
+
|
190
|
+
key = self._get_suffixed_key(key)
|
191
|
+
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
183
192
|
try:
|
184
193
|
value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
|
185
194
|
return True
|
@@ -204,21 +213,14 @@ class HiCacheFile(HiCacheStorage):
|
|
204
213
|
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
205
214
|
return os.path.exists(tensor_path)
|
206
215
|
|
207
|
-
def
|
208
|
-
key = self._get_suffixed_key(key)
|
209
|
-
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
210
|
-
try:
|
211
|
-
os.remove(tensor_path)
|
212
|
-
except FileNotFoundError:
|
213
|
-
logger.warning(f"Key {key} does not exist. Cannot delete.")
|
214
|
-
return
|
215
|
-
|
216
|
-
def clear(self) -> None:
|
216
|
+
def clear(self) -> bool:
|
217
217
|
try:
|
218
218
|
for filename in os.listdir(self.file_path):
|
219
219
|
file_path = os.path.join(self.file_path, filename)
|
220
220
|
if os.path.isfile(file_path):
|
221
221
|
os.remove(file_path)
|
222
222
|
logger.info("Cleared all entries in HiCacheFile storage.")
|
223
|
+
return True
|
223
224
|
except Exception as e:
|
224
225
|
logger.error(f"Failed to clear HiCacheFile storage: {e}")
|
226
|
+
return False
|