sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -11
- sglang/bench_serving.py +149 -1
- sglang/check_env.py +3 -3
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +32 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +151 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +58 -24
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +22 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +129 -94
- sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +6 -1
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +81 -35
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +44 -16
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +291 -72
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +60 -28
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +159 -90
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +2 -277
- sglang/srt/models/deepseek_v2.py +132 -37
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +93 -31
- sglang/srt/models/llama4.py +54 -7
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +4 -16
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +58 -62
- sglang/srt/openai_api/protocol.py +38 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +93 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +123 -10
- sglang/test/runners.py +4 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +32 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -54,7 +54,11 @@ from sglang.srt.disaggregation.utils import (
|
|
54
54
|
TransferBackend,
|
55
55
|
get_kv_class,
|
56
56
|
)
|
57
|
-
from sglang.srt.hf_transformers_utils import
|
57
|
+
from sglang.srt.hf_transformers_utils import (
|
58
|
+
get_processor,
|
59
|
+
get_tokenizer,
|
60
|
+
get_tokenizer_from_processor,
|
61
|
+
)
|
58
62
|
from sglang.srt.managers.io_struct import (
|
59
63
|
AbortReq,
|
60
64
|
BatchEmbeddingOut,
|
@@ -86,6 +90,8 @@ from sglang.srt.managers.io_struct import (
|
|
86
90
|
ResumeMemoryOccupationReqInput,
|
87
91
|
ResumeMemoryOccupationReqOutput,
|
88
92
|
SessionParams,
|
93
|
+
SlowDownReqInput,
|
94
|
+
SlowDownReqOutput,
|
89
95
|
TokenizedEmbeddingReqInput,
|
90
96
|
TokenizedGenerateReqInput,
|
91
97
|
UpdateWeightFromDiskReqInput,
|
@@ -161,17 +167,7 @@ class TokenizerManager:
|
|
161
167
|
# Read model args
|
162
168
|
self.model_path = server_args.model_path
|
163
169
|
self.served_model_name = server_args.served_model_name
|
164
|
-
self.model_config = ModelConfig(
|
165
|
-
server_args.model_path,
|
166
|
-
trust_remote_code=server_args.trust_remote_code,
|
167
|
-
revision=server_args.revision,
|
168
|
-
context_length=server_args.context_length,
|
169
|
-
model_override_args=server_args.json_model_override_args,
|
170
|
-
is_embedding=server_args.is_embedding,
|
171
|
-
enable_multimodal=server_args.enable_multimodal,
|
172
|
-
dtype=server_args.dtype,
|
173
|
-
quantization=server_args.quantization,
|
174
|
-
)
|
170
|
+
self.model_config = ModelConfig.from_server_args(server_args)
|
175
171
|
|
176
172
|
self.is_generation = self.model_config.is_generation
|
177
173
|
self.is_image_gen = self.model_config.is_image_gen
|
@@ -199,7 +195,7 @@ class TokenizerManager:
|
|
199
195
|
self.tokenizer = self.processor = None
|
200
196
|
else:
|
201
197
|
self.processor = _processor
|
202
|
-
self.tokenizer = self.processor
|
198
|
+
self.tokenizer = get_tokenizer_from_processor(self.processor)
|
203
199
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
204
200
|
else:
|
205
201
|
self.mm_processor = get_dummy_processor()
|
@@ -265,6 +261,9 @@ class TokenizerManager:
|
|
265
261
|
self.resume_memory_occupation_communicator = _Communicator(
|
266
262
|
self.send_to_scheduler, server_args.dp_size
|
267
263
|
)
|
264
|
+
self.slow_down_communicator = _Communicator(
|
265
|
+
self.send_to_scheduler, server_args.dp_size
|
266
|
+
)
|
268
267
|
self.flush_cache_communicator = _Communicator(
|
269
268
|
self.send_to_scheduler, server_args.dp_size
|
270
269
|
)
|
@@ -318,6 +317,10 @@ class TokenizerManager:
|
|
318
317
|
ResumeMemoryOccupationReqOutput,
|
319
318
|
self.resume_memory_occupation_communicator.handle_recv,
|
320
319
|
),
|
320
|
+
(
|
321
|
+
SlowDownReqOutput,
|
322
|
+
self.slow_down_communicator.handle_recv,
|
323
|
+
),
|
321
324
|
(
|
322
325
|
FlushCacheReqOutput,
|
323
326
|
self.flush_cache_communicator.handle_recv,
|
@@ -876,6 +879,14 @@ class TokenizerManager:
|
|
876
879
|
self.auto_create_handle_loop()
|
877
880
|
await self.resume_memory_occupation_communicator(obj)
|
878
881
|
|
882
|
+
async def slow_down(
|
883
|
+
self,
|
884
|
+
obj: SlowDownReqInput,
|
885
|
+
request: Optional[fastapi.Request] = None,
|
886
|
+
):
|
887
|
+
self.auto_create_handle_loop()
|
888
|
+
await self.slow_down_communicator(obj)
|
889
|
+
|
879
890
|
async def open_session(
|
880
891
|
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
881
892
|
):
|
sglang/srt/managers/tp_worker.py
CHANGED
@@ -15,12 +15,17 @@
|
|
15
15
|
|
16
16
|
import logging
|
17
17
|
import threading
|
18
|
-
from typing import Optional, Tuple
|
18
|
+
from typing import Optional, Tuple, Union
|
19
19
|
|
20
20
|
import torch
|
21
21
|
|
22
22
|
from sglang.srt.configs.model_config import ModelConfig
|
23
|
-
from sglang.srt.
|
23
|
+
from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
|
24
|
+
from sglang.srt.hf_transformers_utils import (
|
25
|
+
get_processor,
|
26
|
+
get_tokenizer,
|
27
|
+
get_tokenizer_from_processor,
|
28
|
+
)
|
24
29
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
25
30
|
from sglang.srt.managers.io_struct import (
|
26
31
|
GetWeightsByNameReqInput,
|
@@ -31,7 +36,7 @@ from sglang.srt.managers.io_struct import (
|
|
31
36
|
)
|
32
37
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
33
38
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
34
|
-
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
39
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
35
40
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
36
41
|
from sglang.srt.server_args import ServerArgs
|
37
42
|
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
|
@@ -47,6 +52,7 @@ class TpModelWorker:
|
|
47
52
|
server_args: ServerArgs,
|
48
53
|
gpu_id: int,
|
49
54
|
tp_rank: int,
|
55
|
+
pp_rank: int,
|
50
56
|
dp_rank: Optional[int],
|
51
57
|
nccl_port: int,
|
52
58
|
is_draft_worker: bool = False,
|
@@ -54,30 +60,29 @@ class TpModelWorker:
|
|
54
60
|
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
55
61
|
):
|
56
62
|
# Parse args
|
63
|
+
self.tp_size = server_args.tp_size
|
57
64
|
self.tp_rank = tp_rank
|
65
|
+
self.pp_rank = pp_rank
|
58
66
|
|
59
67
|
# Init model and tokenizer
|
60
|
-
self.model_config = ModelConfig(
|
61
|
-
|
68
|
+
self.model_config = ModelConfig.from_server_args(
|
69
|
+
server_args,
|
70
|
+
model_path=(
|
62
71
|
server_args.model_path
|
63
72
|
if not is_draft_worker
|
64
73
|
else server_args.speculative_draft_model_path
|
65
74
|
),
|
66
|
-
|
67
|
-
revision=server_args.revision,
|
68
|
-
context_length=server_args.context_length,
|
69
|
-
model_override_args=server_args.json_model_override_args,
|
70
|
-
is_embedding=server_args.is_embedding,
|
71
|
-
enable_multimodal=server_args.enable_multimodal,
|
72
|
-
dtype=server_args.dtype,
|
73
|
-
quantization=server_args.quantization,
|
75
|
+
is_draft_model=is_draft_worker,
|
74
76
|
)
|
77
|
+
|
75
78
|
self.model_runner = ModelRunner(
|
76
79
|
model_config=self.model_config,
|
77
80
|
mem_fraction_static=server_args.mem_fraction_static,
|
78
81
|
gpu_id=gpu_id,
|
79
82
|
tp_rank=tp_rank,
|
80
83
|
tp_size=server_args.tp_size,
|
84
|
+
pp_rank=pp_rank,
|
85
|
+
pp_size=server_args.pp_size,
|
81
86
|
nccl_port=nccl_port,
|
82
87
|
server_args=server_args,
|
83
88
|
is_draft_worker=is_draft_worker,
|
@@ -94,7 +99,7 @@ class TpModelWorker:
|
|
94
99
|
trust_remote_code=server_args.trust_remote_code,
|
95
100
|
revision=server_args.revision,
|
96
101
|
)
|
97
|
-
self.tokenizer = self.processor
|
102
|
+
self.tokenizer = get_tokenizer_from_processor(self.processor)
|
98
103
|
else:
|
99
104
|
self.tokenizer = get_tokenizer(
|
100
105
|
server_args.tokenizer_path,
|
@@ -104,6 +109,10 @@ class TpModelWorker:
|
|
104
109
|
)
|
105
110
|
self.device = self.model_runner.device
|
106
111
|
|
112
|
+
# Init nccl groups
|
113
|
+
self.pp_group = get_pp_group()
|
114
|
+
self.world_group = get_world_group()
|
115
|
+
|
107
116
|
# Profile number of tokens
|
108
117
|
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
|
109
118
|
self.max_prefill_tokens = server_args.max_prefill_tokens
|
@@ -129,8 +138,9 @@ class TpModelWorker:
|
|
129
138
|
# Sync random seed across TP workers
|
130
139
|
self.random_seed = broadcast_pyobj(
|
131
140
|
[server_args.random_seed],
|
132
|
-
self.tp_rank,
|
133
|
-
self.
|
141
|
+
self.tp_size * self.pp_rank + tp_rank,
|
142
|
+
self.world_group.cpu_group,
|
143
|
+
src=self.world_group.ranks[0],
|
134
144
|
)[0]
|
135
145
|
set_random_seed(self.random_seed)
|
136
146
|
|
@@ -155,11 +165,14 @@ class TpModelWorker:
|
|
155
165
|
def get_pad_input_ids_func(self):
|
156
166
|
return getattr(self.model_runner.model, "pad_input_ids", None)
|
157
167
|
|
158
|
-
def
|
159
|
-
return self.model_runner.tp_group
|
168
|
+
def get_tp_group(self):
|
169
|
+
return self.model_runner.tp_group
|
170
|
+
|
171
|
+
def get_attention_tp_group(self):
|
172
|
+
return self.model_runner.attention_tp_group
|
160
173
|
|
161
174
|
def get_attention_tp_cpu_group(self):
|
162
|
-
return self.model_runner.attention_tp_group
|
175
|
+
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
|
163
176
|
|
164
177
|
def get_memory_pool(self):
|
165
178
|
return (
|
@@ -171,19 +184,38 @@ class TpModelWorker:
|
|
171
184
|
self,
|
172
185
|
model_worker_batch: ModelWorkerBatch,
|
173
186
|
skip_sample: bool = False,
|
174
|
-
) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
|
187
|
+
) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
|
175
188
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
176
|
-
logits_output = self.model_runner.forward(forward_batch)
|
177
189
|
|
178
|
-
|
179
|
-
|
190
|
+
pp_proxy_tensors = None
|
191
|
+
if not self.pp_group.is_first_rank:
|
192
|
+
pp_proxy_tensors = PPProxyTensors(
|
193
|
+
self.pp_group.recv_tensor_dict(
|
194
|
+
all_gather_group=self.get_attention_tp_group()
|
195
|
+
)
|
196
|
+
)
|
180
197
|
|
181
|
-
if
|
182
|
-
|
183
|
-
|
184
|
-
|
198
|
+
if self.pp_group.is_last_rank:
|
199
|
+
logits_output = self.model_runner.forward(
|
200
|
+
forward_batch, pp_proxy_tensors=pp_proxy_tensors
|
201
|
+
)
|
202
|
+
if model_worker_batch.launch_done is not None:
|
203
|
+
model_worker_batch.launch_done.set()
|
204
|
+
|
205
|
+
if skip_sample:
|
206
|
+
next_token_ids = None
|
207
|
+
else:
|
208
|
+
next_token_ids = self.model_runner.sample(
|
209
|
+
logits_output, model_worker_batch
|
210
|
+
)
|
185
211
|
|
186
|
-
|
212
|
+
return logits_output, next_token_ids
|
213
|
+
else:
|
214
|
+
pp_proxy_tensors = self.model_runner.forward(
|
215
|
+
forward_batch,
|
216
|
+
pp_proxy_tensors=pp_proxy_tensors,
|
217
|
+
)
|
218
|
+
return pp_proxy_tensors.tensors, None
|
187
219
|
|
188
220
|
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
189
221
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
@@ -56,11 +56,14 @@ class TpModelWorkerClient:
|
|
56
56
|
server_args: ServerArgs,
|
57
57
|
gpu_id: int,
|
58
58
|
tp_rank: int,
|
59
|
+
pp_rank: int,
|
59
60
|
dp_rank: Optional[int],
|
60
61
|
nccl_port: int,
|
61
62
|
):
|
62
63
|
# Load the model
|
63
|
-
self.worker = TpModelWorker(
|
64
|
+
self.worker = TpModelWorker(
|
65
|
+
server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
|
66
|
+
)
|
64
67
|
self.max_running_requests = self.worker.max_running_requests
|
65
68
|
self.device = self.worker.device
|
66
69
|
self.gpu_id = gpu_id
|
@@ -91,8 +94,11 @@ class TpModelWorkerClient:
|
|
91
94
|
def get_pad_input_ids_func(self):
|
92
95
|
return self.worker.get_pad_input_ids_func()
|
93
96
|
|
94
|
-
def
|
95
|
-
return self.worker.
|
97
|
+
def get_tp_group(self):
|
98
|
+
return self.worker.get_tp_group()
|
99
|
+
|
100
|
+
def get_attention_tp_group(self):
|
101
|
+
return self.worker.get_attention_tp_group()
|
96
102
|
|
97
103
|
def get_attention_tp_cpu_group(self):
|
98
104
|
return self.worker.get_attention_tp_cpu_group()
|
@@ -24,9 +24,11 @@ class ChunkCache(BasePrefixCache):
|
|
24
24
|
self,
|
25
25
|
req_to_token_pool: ReqToTokenPool,
|
26
26
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
27
|
+
page_size: int,
|
27
28
|
):
|
28
29
|
self.req_to_token_pool = req_to_token_pool
|
29
30
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
31
|
+
self.page_size = page_size
|
30
32
|
|
31
33
|
def reset(self):
|
32
34
|
pass
|
@@ -214,6 +214,8 @@ class MHATokenToKVPool(KVCache):
|
|
214
214
|
layer_num: int,
|
215
215
|
device: str,
|
216
216
|
enable_memory_saver: bool,
|
217
|
+
start_layer: Optional[int] = None,
|
218
|
+
end_layer: Optional[int] = None,
|
217
219
|
):
|
218
220
|
self.size = size
|
219
221
|
self.page_size = page_size
|
@@ -232,6 +234,8 @@ class MHATokenToKVPool(KVCache):
|
|
232
234
|
self.head_dim = head_dim
|
233
235
|
self.layer_num = layer_num
|
234
236
|
self._create_buffers()
|
237
|
+
self.start_layer = start_layer or 0
|
238
|
+
self.end_layer = end_layer or layer_num - 1
|
235
239
|
|
236
240
|
self.layer_transfer_counter = None
|
237
241
|
self.capture_mode = False
|
@@ -281,6 +285,8 @@ class MHATokenToKVPool(KVCache):
|
|
281
285
|
|
282
286
|
# for disagg
|
283
287
|
def get_contiguous_buf_infos(self):
|
288
|
+
# layer_num x [seq_len, head_num, head_dim]
|
289
|
+
# layer_num x [page_num, page_size, head_num, head_dim]
|
284
290
|
kv_data_ptrs = [
|
285
291
|
self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
|
286
292
|
] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
|
@@ -320,24 +326,24 @@ class MHATokenToKVPool(KVCache):
|
|
320
326
|
# transfer prepared data from host to device
|
321
327
|
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
322
328
|
k_data, v_data = flat_data[0], flat_data[1]
|
323
|
-
self.k_buffer[layer_id][indices] = k_data
|
324
|
-
self.v_buffer[layer_id][indices] = v_data
|
329
|
+
self.k_buffer[layer_id - self.start_layer][indices] = k_data
|
330
|
+
self.v_buffer[layer_id - self.start_layer][indices] = v_data
|
325
331
|
|
326
332
|
def get_key_buffer(self, layer_id: int):
|
327
333
|
if self.layer_transfer_counter is not None:
|
328
|
-
self.layer_transfer_counter.wait_until(layer_id)
|
334
|
+
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
329
335
|
|
330
336
|
if self.store_dtype != self.dtype:
|
331
|
-
return self.k_buffer[layer_id].view(self.dtype)
|
332
|
-
return self.k_buffer[layer_id]
|
337
|
+
return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
|
338
|
+
return self.k_buffer[layer_id - self.start_layer]
|
333
339
|
|
334
340
|
def get_value_buffer(self, layer_id: int):
|
335
341
|
if self.layer_transfer_counter is not None:
|
336
|
-
self.layer_transfer_counter.wait_until(layer_id)
|
342
|
+
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
337
343
|
|
338
344
|
if self.store_dtype != self.dtype:
|
339
|
-
return self.v_buffer[layer_id].view(self.dtype)
|
340
|
-
return self.v_buffer[layer_id]
|
345
|
+
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
|
346
|
+
return self.v_buffer[layer_id - self.start_layer]
|
341
347
|
|
342
348
|
def get_kv_buffer(self, layer_id: int):
|
343
349
|
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
@@ -368,13 +374,13 @@ class MHATokenToKVPool(KVCache):
|
|
368
374
|
# Overlap the copy of K and V cache for small batch size
|
369
375
|
current_stream = self.device_module.current_stream()
|
370
376
|
self.alt_stream.wait_stream(current_stream)
|
377
|
+
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
371
378
|
with self.device_module.stream(self.alt_stream):
|
372
|
-
self.
|
373
|
-
self.v_buffer[layer_id][loc] = cache_v
|
379
|
+
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
374
380
|
current_stream.wait_stream(self.alt_stream)
|
375
381
|
else:
|
376
|
-
self.k_buffer[layer_id][loc] = cache_k
|
377
|
-
self.v_buffer[layer_id][loc] = cache_v
|
382
|
+
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
383
|
+
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
378
384
|
|
379
385
|
|
380
386
|
@torch.compile
|
@@ -484,6 +490,8 @@ class MLATokenToKVPool(KVCache):
|
|
484
490
|
layer_num: int,
|
485
491
|
device: str,
|
486
492
|
enable_memory_saver: bool,
|
493
|
+
start_layer: Optional[int] = None,
|
494
|
+
end_layer: Optional[int] = None,
|
487
495
|
):
|
488
496
|
self.size = size
|
489
497
|
self.page_size = page_size
|
@@ -497,6 +505,8 @@ class MLATokenToKVPool(KVCache):
|
|
497
505
|
self.kv_lora_rank = kv_lora_rank
|
498
506
|
self.qk_rope_head_dim = qk_rope_head_dim
|
499
507
|
self.layer_num = layer_num
|
508
|
+
self.start_layer = start_layer or 0
|
509
|
+
self.end_layer = end_layer or layer_num - 1
|
500
510
|
|
501
511
|
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
502
512
|
enable=enable_memory_saver
|
@@ -540,19 +550,21 @@ class MLATokenToKVPool(KVCache):
|
|
540
550
|
|
541
551
|
def get_key_buffer(self, layer_id: int):
|
542
552
|
if self.layer_transfer_counter is not None:
|
543
|
-
self.layer_transfer_counter.wait_until(layer_id)
|
553
|
+
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
544
554
|
|
545
555
|
if self.store_dtype != self.dtype:
|
546
|
-
return self.kv_buffer[layer_id].view(self.dtype)
|
547
|
-
return self.kv_buffer[layer_id]
|
556
|
+
return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
|
557
|
+
return self.kv_buffer[layer_id - self.start_layer]
|
548
558
|
|
549
559
|
def get_value_buffer(self, layer_id: int):
|
550
560
|
if self.layer_transfer_counter is not None:
|
551
|
-
self.layer_transfer_counter.wait_until(layer_id)
|
561
|
+
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
552
562
|
|
553
563
|
if self.store_dtype != self.dtype:
|
554
|
-
return self.kv_buffer[layer_id
|
555
|
-
|
564
|
+
return self.kv_buffer[layer_id - self.start_layer][
|
565
|
+
..., : self.kv_lora_rank
|
566
|
+
].view(self.dtype)
|
567
|
+
return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]
|
556
568
|
|
557
569
|
def get_kv_buffer(self, layer_id: int):
|
558
570
|
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
@@ -568,9 +580,11 @@ class MLATokenToKVPool(KVCache):
|
|
568
580
|
if cache_k.dtype != self.dtype:
|
569
581
|
cache_k = cache_k.to(self.dtype)
|
570
582
|
if self.store_dtype != self.dtype:
|
571
|
-
self.kv_buffer[layer_id][loc] = cache_k.view(
|
583
|
+
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
|
584
|
+
self.store_dtype
|
585
|
+
)
|
572
586
|
else:
|
573
|
-
self.kv_buffer[layer_id][loc] = cache_k
|
587
|
+
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
|
574
588
|
|
575
589
|
def set_mla_kv_buffer(
|
576
590
|
self,
|
@@ -605,7 +619,7 @@ class MLATokenToKVPool(KVCache):
|
|
605
619
|
def transfer_per_layer(self, indices, flat_data, layer_id):
|
606
620
|
# transfer prepared data from host to device
|
607
621
|
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
608
|
-
self.kv_buffer[layer_id][indices] = flat_data
|
622
|
+
self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
|
609
623
|
|
610
624
|
|
611
625
|
class DoubleSparseTokenToKVPool(KVCache):
|
@@ -620,6 +634,8 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|
620
634
|
device: str,
|
621
635
|
heavy_channel_num: int,
|
622
636
|
enable_memory_saver: bool,
|
637
|
+
start_layer: Optional[int] = None,
|
638
|
+
end_layer: Optional[int] = None,
|
623
639
|
):
|
624
640
|
self.size = size
|
625
641
|
self.page_size = page_size
|
@@ -657,17 +673,23 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|
657
673
|
for _ in range(layer_num)
|
658
674
|
]
|
659
675
|
|
676
|
+
self.start_layer = start_layer or 0
|
677
|
+
self.end_layer = end_layer or layer_num - 1
|
678
|
+
|
660
679
|
def get_key_buffer(self, layer_id: int):
|
661
|
-
return self.k_buffer[layer_id]
|
680
|
+
return self.k_buffer[layer_id - self.start_layer]
|
662
681
|
|
663
682
|
def get_value_buffer(self, layer_id: int):
|
664
|
-
return self.v_buffer[layer_id]
|
683
|
+
return self.v_buffer[layer_id - self.start_layer]
|
665
684
|
|
666
685
|
def get_label_buffer(self, layer_id: int):
|
667
|
-
return self.label_buffer[layer_id]
|
686
|
+
return self.label_buffer[layer_id - self.start_layer]
|
668
687
|
|
669
688
|
def get_kv_buffer(self, layer_id: int):
|
670
|
-
return
|
689
|
+
return (
|
690
|
+
self.k_buffer[layer_id - self.start_layer],
|
691
|
+
self.v_buffer[layer_id - self.start_layer],
|
692
|
+
)
|
671
693
|
|
672
694
|
def set_kv_buffer(
|
673
695
|
self,
|
@@ -679,9 +701,9 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|
679
701
|
):
|
680
702
|
# NOTE(Andy): ignore the dtype check
|
681
703
|
layer_id = layer.layer_id
|
682
|
-
self.k_buffer[layer_id][loc] = cache_k
|
683
|
-
self.v_buffer[layer_id][loc] = cache_v
|
684
|
-
self.label_buffer[layer_id][loc] = cache_label
|
704
|
+
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
705
|
+
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
706
|
+
self.label_buffer[layer_id - self.start_layer][loc] = cache_label
|
685
707
|
|
686
708
|
def get_flat_data(self, indices):
|
687
709
|
pass
|
@@ -930,7 +952,7 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
930
952
|
return self.kv_buffer[:, :, indices]
|
931
953
|
|
932
954
|
def get_flat_data_by_layer(self, indices, layer_id):
|
933
|
-
return self.kv_buffer[:, layer_id, indices]
|
955
|
+
return self.kv_buffer[:, layer_id - self.start_layer, indices]
|
934
956
|
|
935
957
|
def assign_flat_data(self, indices, flat_data):
|
936
958
|
self.kv_buffer[:, :, indices] = flat_data
|
@@ -955,12 +977,20 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
955
977
|
for i in range(len(device_indices_cpu)):
|
956
978
|
h_index = host_indices[i * self.page_size]
|
957
979
|
d_index = device_indices_cpu[i]
|
958
|
-
device_pool.k_buffer[layer_id
|
959
|
-
|
980
|
+
device_pool.k_buffer[layer_id - self.start_layer][
|
981
|
+
d_index : d_index + self.page_size
|
982
|
+
].copy_(
|
983
|
+
self.kv_buffer[
|
984
|
+
0, layer_id - self.start_layer, h_index : h_index + self.page_size
|
985
|
+
],
|
960
986
|
non_blocking=True,
|
961
987
|
)
|
962
|
-
device_pool.v_buffer[layer_id
|
963
|
-
|
988
|
+
device_pool.v_buffer[layer_id - self.start_layer][
|
989
|
+
d_index : d_index + self.page_size
|
990
|
+
].copy_(
|
991
|
+
self.kv_buffer[
|
992
|
+
1, layer_id - self.start_layer, h_index : h_index + self.page_size
|
993
|
+
],
|
964
994
|
non_blocking=True,
|
965
995
|
)
|
966
996
|
|
@@ -1015,7 +1045,7 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|
1015
1045
|
return self.kv_buffer[:, indices]
|
1016
1046
|
|
1017
1047
|
def get_flat_data_by_layer(self, indices, layer_id):
|
1018
|
-
return self.kv_buffer[layer_id, indices]
|
1048
|
+
return self.kv_buffer[layer_id - self.start_layer, indices]
|
1019
1049
|
|
1020
1050
|
def assign_flat_data(self, indices, flat_data):
|
1021
1051
|
self.kv_buffer[:, indices] = flat_data
|
@@ -1036,7 +1066,11 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|
1036
1066
|
for i in range(len(device_indices_cpu)):
|
1037
1067
|
h_index = host_indices[i * self.page_size]
|
1038
1068
|
d_index = device_indices_cpu[i]
|
1039
|
-
device_pool.kv_buffer[layer_id
|
1040
|
-
|
1069
|
+
device_pool.kv_buffer[layer_id - self.start_layer][
|
1070
|
+
d_index : d_index + self.page_size
|
1071
|
+
].copy_(
|
1072
|
+
self.kv_buffer[
|
1073
|
+
layer_id - self.start_layer, h_index : h_index + self.page_size
|
1074
|
+
],
|
1041
1075
|
non_blocking=True,
|
1042
1076
|
)
|