sglang-0.5.2rc2-py3-none-any.whl → sglang-0.5.3rc0-py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from __future__ import annotations
+
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 """
@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
@@ -47,6 +52,10 @@ if _is_npu:
     import torch_npu
 
 
+def get_tensor_size_bytes(t: torch.Tensor):
+    return np.prod(t.shape) * t.dtype.itemsize
+
+
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
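The new `get_tensor_size_bytes` helper above is simply element count times per-element width. A minimal standalone check, with an arbitrary example buffer shape (not taken from the diff):

```python
import numpy as np
import torch

def get_tensor_size_bytes(t: torch.Tensor):
    # same formula as the helper added in this diff: #elements * bytes per element
    return np.prod(t.shape) * t.dtype.itemsize

buf = torch.zeros(4096, 8, 128, dtype=torch.float16)  # hypothetical KV buffer shape
print(get_tensor_size_bytes(buf))                     # 8388608 bytes (8 MiB)
print(buf.numel() * buf.element_size())               # same value via torch's own accessors
```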
@@ -97,6 +106,211 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
+class MambaPool:
+    def __init__(
+        self,
+        size: int,
+        conv_dtype: torch.dtype,
+        ssm_dtype: torch.dtype,
+        num_mamba_layers: int,
+        conv_state_shape: Tuple[int, int],
+        temporal_state_shape: Tuple[int, int],
+        device: str,
+        speculative_num_draft_tokens: Optional[int] = None,
+    ):
+        conv_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + conv_state_shape,
+            dtype=conv_dtype,
+            device=device,
+        )
+        temporal_state = torch.zeros(
+            size=(num_mamba_layers, size + 1) + temporal_state_shape,
+            dtype=ssm_dtype,
+            device=device,
+        )
+        if speculative_num_draft_tokens is not None:
+            # Cache intermediate SSM states per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
+            intermediate_ssm_state_cache = torch.zeros(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    temporal_state_shape[0],
+                    temporal_state_shape[1],
+                    temporal_state_shape[2],
+                ),
+                dtype=ssm_dtype,
+                device="cuda",
+            )
+            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
+            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
+            intermediate_conv_window_cache = torch.zeros(
+                size=(
+                    num_mamba_layers,
+                    size + 1,
+                    speculative_num_draft_tokens,
+                    conv_state_shape[0],
+                    conv_state_shape[1],
+                ),
+                dtype=conv_dtype,
+                device="cuda",
+            )
+            self.mamba_cache = (
+                conv_state,
+                temporal_state,
+                intermediate_ssm_state_cache,
+                intermediate_conv_window_cache,
+            )
+            logger.info(
+                f"Mamba Cache is allocated. "
+                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+                f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
+                f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
+            )
+        else:
+            self.mamba_cache = (conv_state, temporal_state)
+            logger.info(
+                f"Mamba Cache is allocated. "
+                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+            )
+        self.size = size
+        self.free_slots = list(range(size))
+        self.mem_usage = self.get_mamba_size() / GB
+
+    def get_mamba_params_all_layers(self):
+        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
+
+    def get_mamba_params(self, layer_id: int):
+        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
+
+    def get_mamba_size(self):
+        return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    def alloc(self, need_size: int) -> Optional[List[int]]:
+        if need_size > len(self.free_slots):
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        return select_index
+
+    def free(self, free_index: Union[int, List[int]]):
+        if isinstance(free_index, (int,)):
+            self.free_slots.append(free_index)
+        else:
+            self.free_slots.extend(free_index)
+        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
+
+    def clear(self):
+        self.free_slots = list(range(self.size))
+
+
+class HybridReqToTokenPool(ReqToTokenPool):
+    """A memory pool that maps a request to its token locations."""
+
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+        conv_dtype: torch.dtype,
+        ssm_dtype: torch.dtype,
+        mamba_layers: List[int],
+        conv_state_shape: Tuple[int, int],
+        temporal_state_shape: Tuple[int, int],
+        speculative_num_draft_tokens: int,
+    ):
+        super().__init__(
+            size=size,
+            max_context_len=max_context_len,
+            device=device,
+            enable_memory_saver=enable_memory_saver,
+        )
+
+        self.mamba_pool = MambaPool(
+            size,
+            conv_dtype,
+            ssm_dtype,
+            len(mamba_layers),
+            conv_state_shape,
+            temporal_state_shape,
+            device,
+            speculative_num_draft_tokens,
+        )
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
+
+        self.device = device
+        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
+            size, dtype=torch.int32, device=self.device
+        )
+
+        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
+        self.mamba_index_to_rid_mapping: Dict[int, str] = {}
+
+    # For chunk prefill req, we do not need to allocate mamba cache,
+    # We could use allocated mamba cache instead.
+    def alloc(
+        self, need_size: int, reqs: Optional[List["Req"]] = None
+    ) -> Optional[List[int]]:
+        select_index = super().alloc(need_size)
+        if select_index == None:
+            return None
+
+        mamba_index = []
+        for req in reqs:
+            rid = req.rid
+            if rid in self.rid_to_mamba_index_mapping:
+                mid = self.rid_to_mamba_index_mapping[rid]
+            elif (mid := self.mamba_pool.alloc(1)) is not None:
+                mid = mid[0]
+                self.rid_to_mamba_index_mapping[rid] = mid
+                self.mamba_index_to_rid_mapping[mid] = rid
+            mamba_index.append(mid)
+        assert len(select_index) == len(
+            mamba_index
+        ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size."
+        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
+            mamba_index, dtype=torch.int32, device=self.device
+        )
+        return select_index
+
+    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
+        return self.req_index_to_mamba_index_mapping[req_indices]
+
+    def get_mamba_params(self, layer_id: int):
+        assert layer_id in self.mamba_map
+        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
+
+    def get_mamba_params_all_layers(self):
+        return self.mamba_pool.get_mamba_params_all_layers()
+
+    # For chunk prefill, we can not free mamba cache, we need use it in the future
+    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
+        super().free(free_index)
+        if free_mamba_cache:
+            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
+            mamba_index_list = mamba_index.tolist()
+            if isinstance(mamba_index_list, int):
+                mamba_index_list = [mamba_index_list]
+            self.mamba_pool.free(mamba_index_list)
+            for mid in mamba_index_list:
+                rid = self.mamba_index_to_rid_mapping[mid]
+                self.mamba_index_to_rid_mapping.pop(mid)
+                self.rid_to_mamba_index_mapping.pop(rid)
+
+    def clear(self):
+        super().clear()
+        self.mamba_pool.clear()
+
+
 class KVCache(abc.ABC):
     @abc.abstractmethod
     def __init__(
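`MambaPool` above keeps one conv-state and one SSM-state slot per request and hands out slot indices from a free list; `HybridReqToTokenPool.alloc` then reuses a request's existing slot across chunked prefill via the rid mappings. A minimal sketch of just the free-list bookkeeping (`FreeListPool` is an illustrative stand-in, not a class from the diff; the real `free` also zeroes the freed conv/SSM states):

```python
from typing import List, Optional, Union

class FreeListPool:
    """Slot bookkeeping in the style of MambaPool: integer slot indices on a free list."""

    def __init__(self, size: int):
        self.size = size
        self.free_slots = list(range(size))

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None  # caller must handle allocation failure
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: Union[int, List[int]]):
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

pool = FreeListPool(size=4)
slots = pool.alloc(2)          # [0, 1]
pool.free(slots)               # both slots become reusable
assert pool.available_size() == 4
```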
@@ -130,6 +344,29 @@ class KVCache(abc.ABC):
         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
 
+        # default state for optional layer-wise transfer control
+        self.layer_transfer_counter = None
+
+    def _finalize_allocation_log(self, num_tokens: int):
+        """Common logging and mem_usage computation for KV cache allocation.
+        Supports both tuple (K, V) size returns and single KV size returns.
+        """
+        kv_size_bytes = self.get_kv_size_bytes()
+        if isinstance(kv_size_bytes, tuple):
+            k_size, v_size = kv_size_bytes
+            k_size_GB = k_size / GB
+            v_size_GB = v_size / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
+            )
+            self.mem_usage = k_size_GB + v_size_GB
+        else:
+            kv_size_GB = kv_size_bytes / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
+            )
+            self.mem_usage = kv_size_GB
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
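`_finalize_allocation_log` centralizes the allocation log line and `mem_usage` bookkeeping that each pool previously duplicated; it branches on whether `get_kv_size_bytes()` returns a `(K, V)` tuple (MHA-style pools) or a single combined size (MLA-style pools). A standalone distillation of that branching, not the method itself:

```python
GB = 1024 * 1024 * 1024

def summarize_kv_alloc(kv_size_bytes):
    # tuple -> separate K and V buffers; scalar -> one combined KV buffer
    if isinstance(kv_size_bytes, tuple):
        k_size, v_size = kv_size_bytes
        return f"K {k_size / GB:.2f} GB, V {v_size / GB:.2f} GB", (k_size + v_size) / GB
    return f"KV {kv_size_bytes / GB:.2f} GB", kv_size_bytes / GB

print(summarize_kv_alloc((2 * GB, 2 * GB)))  # ('K 2.00 GB, V 2.00 GB', 4.0)
print(summarize_kv_alloc(3 * GB))            # ('KV 3.00 GB', 3.0)
```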
@@ -152,7 +389,7 @@
     ) -> None:
         raise NotImplementedError()
 
-    def register_layer_transfer_counter(self, layer_transfer_counter):
+    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
         self.layer_transfer_counter = layer_transfer_counter
 
     def get_cpu_copy(self, indices):
@@ -205,15 +442,9 @@ class MHATokenToKVPool(KVCache):
 
         self._create_buffers()
 
-        self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
-
-        k_size, v_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
-        )
-        self.mem_usage = (k_size + v_size) / GB
+        self._finalize_allocation_log(size)
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -269,10 +500,10 @@ class MHATokenToKVPool(KVCache):
         assert hasattr(self, "v_buffer")
         k_size_bytes = 0
         for k_cache in self.k_buffer:
-            k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+            k_size_bytes += get_tensor_size_bytes(k_cache)
         v_size_bytes = 0
         for v_cache in self.v_buffer:
-            v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+            v_size_bytes += get_tensor_size_bytes(v_cache)
         return k_size_bytes, v_size_bytes
 
     # for disagg
@@ -352,7 +583,6 @@ class MHATokenToKVPool(KVCache):
         # same applies to get_value_buffer and get_kv_buffer
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
-
         return self._get_key_buffer(layer_id)
 
     def _get_value_buffer(self, layer_id: int):
@@ -420,41 +650,31 @@ class MHATokenToKVPool(KVCache):
         )
 
 
-class SWAKVPool(KVCache):
-    """KV cache with separate pools for full and SWA attention layers."""
+class HybridLinearKVPool(KVCache):
+    """KV cache with separate pools for full and linear attention layers."""
 
     def __init__(
         self,
         size: int,
-        size_swa: int,
         dtype: torch.dtype,
+        page_size: int,
         head_num: int,
         head_dim: int,
-        swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
     ):
         self.size = size
-        self.size_swa = size_swa
         self.dtype = dtype
         self.device = device
-        self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
-        self.page_size = 1
+        self.page_size = page_size
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
-
-
-
-
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
-            layer_num=self.swa_layer_nums,
-            device=device,
-            enable_memory_saver=False,
-        )
+        if _is_npu:
+            TokenToKVPoolClass = AscendTokenToKVPool
+        else:
+            TokenToKVPoolClass = MHATokenToKVPool
         self.full_kv_pool = TokenToKVPoolClass(
             size=size,
             page_size=self.page_size,
@@ -465,6 +685,93 @@ class SWAKVPool(KVCache):
             device=device,
             enable_memory_saver=False,
         )
+        self.full_attention_layer_id_mapping = {
+            id: i for i, id in enumerate(full_attention_layer_ids)
+        }
+        k_size, v_size = self.get_kv_size_bytes()
+        self.mem_usage = (k_size + v_size) / GB
+
+    def get_kv_size_bytes(self):
+        return self.full_kv_pool.get_kv_size_bytes()
+
+    def get_contiguous_buf_infos(self):
+        return self.full_kv_pool.get_contiguous_buf_infos()
+
+    def _transfer_full_attention_id(self, layer_id: int):
+        if layer_id not in self.full_attention_layer_id_mapping:
+            raise ValueError(
+                f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
+            )
+        return self.full_attention_layer_id_mapping[layer_id]
+
+    def get_key_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_key_buffer(layer_id)
+
+    def get_value_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_value_buffer(layer_id)
+
+    def get_kv_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool.get_kv_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+    ):
+        layer_id = self._transfer_full_attention_id(layer.layer_id)
+        self.full_kv_pool.set_kv_buffer(
+            None,
+            loc,
+            cache_k,
+            cache_v,
+            k_scale,
+            v_scale,
+            layer_id_override=layer_id,
+        )
+
+    def get_v_head_dim(self):
+        return self.full_kv_pool.get_value_buffer(0).shape[-1]
+
+
+class SWAKVPool(KVCache):
+    """KV cache with separate pools for full and SWA attention layers."""
+
+    def __init__(
+        self,
+        size: int,
+        size_swa: int,
+        swa_attention_layer_ids: List[int],
+        full_attention_layer_ids: List[int],
+        enable_kvcache_transpose: bool,
+        token_to_kv_pool_class: KVCache = MHATokenToKVPool,
+        **kwargs,
+    ):
+        self.size = size
+        self.size_swa = size_swa
+        self.swa_layer_nums = len(swa_attention_layer_ids)
+        self.full_layer_nums = len(full_attention_layer_ids)
+        kwargs["page_size"] = 1
+        kwargs["enable_memory_saver"] = False
+        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
+        assert not enable_kvcache_transpose
+
+        self.swa_kv_pool = token_to_kv_pool_class(
+            size=size_swa,
+            layer_num=self.swa_layer_nums,
+            **kwargs,
+        )
+        self.full_kv_pool = token_to_kv_pool_class(
+            size=size,
+            layer_num=self.full_layer_nums,
+            **kwargs,
+        )
         self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
         for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
             self.layers_mapping[global_layer_id] = (full_attn_layer_id, False)
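`HybridLinearKVPool` only allocates KV slots for the full-attention layers of a hybrid model and remaps global layer ids onto that smaller pool via `_transfer_full_attention_id`; linear-attention layers are expected never to touch it. A minimal sketch of the remapping, with made-up example layer ids:

```python
full_attention_layer_ids = [3, 7, 11]  # hypothetical hybrid-model layout
mapping = {layer_id: i for i, layer_id in enumerate(full_attention_layer_ids)}

def transfer_full_attention_id(layer_id: int) -> int:
    if layer_id not in mapping:
        raise ValueError(f"{layer_id=} not in full attention layers: {mapping.keys()}")
    return mapping[layer_id]

assert transfer_full_attention_id(7) == 1    # global layer 7 -> index 1 in the full-attn pool
assert transfer_full_attention_id(11) == 2
```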
@@ -613,8 +920,12 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         cache_v: torch.Tensor,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        layer_id_override: Optional[int] = None,
     ):
-        layer_id = layer.layer_id
+        if layer_id_override is not None:
+            layer_id = layer_id_override
+        else:
+            layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
                 cache_k.div_(k_scale)
@@ -768,19 +1079,13 @@ class MLATokenToKVPool(KVCache):
             dtype=torch.uint64,
             device=self.device,
         )
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
         kv_size_bytes = 0
         for kv_cache in self.kv_buffer:
-            kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
+            kv_size_bytes += get_tensor_size_bytes(kv_cache)
         return kv_size_bytes
 
     # for disagg
@@ -936,22 +1241,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             device=self.device,
         )
 
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "k_buffer")
         assert hasattr(self, "v_buffer")
         kv_size_bytes = 0
         for k_cache in self.k_buffer:
-            kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+            kv_size_bytes += get_tensor_size_bytes(k_cache)
         for v_cache in self.v_buffer:
-            kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+            kv_size_bytes += get_tensor_size_bytes(v_cache)
         return kv_size_bytes
 
     def get_kv_buffer(self, layer_id: int):