sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff compares two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/disaggregation/decode.py +8 -2
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/prefill.py +12 -3
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +52 -21
- sglang/srt/entrypoints/http_server.py +27 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +34 -15
- sglang/srt/managers/scheduler.py +273 -67
- sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
- sglang/srt/managers/tp_worker.py +52 -17
- sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +123 -58
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +49 -8
- sglang/srt/openai_api/protocol.py +13 -1
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +83 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +91 -9
- sglang/test/runners.py +4 -0
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +67 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler_output_processor_mixin.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import threading
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
     EmbeddingBatchResult,
     GenerationBatchResult,
     ScheduleBatch,
+    Scheduler,
 )


@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
     """

     def process_batch_result_prefill(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         skip_stream_req = None

@@ -43,7 +46,11 @@
             )

             if self.enable_overlap:
-                logits_output, next_token_ids =
+                logits_output, next_token_ids = (
+                    self.tp_worker.resolve_last_batch_result(
+                        launch_done,
+                    )
+                )
             else:
                 # Move next_token_ids and logprobs to cpu
                 next_token_ids = next_token_ids.tolist()
@@ -175,9 +182,10 @@
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

     def process_batch_result_decode(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ):
         logits_output, next_token_ids, bid = (
             result.logits_output,
@@ -187,7 +195,9 @@
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.
+            logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
             next_token_logprobs = logits_output.next_token_logprobs
         elif batch.spec_algorithm.is_none():
             # spec decoding handles output logprobs inside verify process.
@@ -268,10 +278,10 @@
             self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
-            self.log_decode_stats()
+            self.log_decode_stats(running_batch=batch)

     def add_input_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         output: LogitsProcessorOutput,
@@ -405,7 +415,7 @@
         assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

     def add_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         pt: int,
@@ -436,7 +446,10 @@
         return num_input_logprobs

     def stream_output(
-        self
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         """Stream the output to detokenizer."""
         if self.is_generation:
@@ -445,7 +458,10 @@
             self.stream_output_embedding(reqs)

     def stream_output_generation(
-        self
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         rids = []
         finished_reasons: List[BaseFinishReason] = []
@@ -593,7 +609,7 @@
             )
         )

-    def stream_output_embedding(self, reqs: List[Req]):
+    def stream_output_embedding(self: Scheduler, reqs: List[Req]):
         rids = []
         finished_reasons: List[BaseFinishReason] = []

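The changes above thread a per-call `launch_done` event from the scheduler into `resolve_last_batch_result`, so that with overlap scheduling the previous batch's result is consumed only after the current batch has been launched. A minimal runnable sketch of that handoff pattern (the queue, helper names, and list-based "tokens" are illustrative, not sglang's actual classes):

    import queue
    import threading
    import time

    output_queue: queue.Queue = queue.Queue()

    def forward_thread(batch: list, launch_done: threading.Event) -> None:
        time.sleep(0.01)                          # stand-in for launching the forward pass
        launch_done.set()                         # signal: this batch has been launched
        output_queue.put([x + 1 for x in batch])  # stand-in for next-token ids

    def resolve_last_batch_result(launch_done: threading.Event) -> list:
        result = output_queue.get()               # the previous batch's output
        launch_done.wait()                        # wait until the current batch is launched
        return result

    launch_done = threading.Event()
    threading.Thread(target=forward_thread, args=([1, 2, 3], launch_done)).start()
    print(resolve_last_batch_result(launch_done))  # [2, 3, 4]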
sglang/srt/managers/tp_worker.py CHANGED
@@ -15,11 +15,12 @@

 import logging
 import threading
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch

 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -31,7 +32,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -47,6 +48,7 @@ class TpModelWorker:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
         is_draft_worker: bool = False,
@@ -54,7 +56,9 @@
         token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
+        self.tp_size = server_args.tp_size
         self.tp_rank = tp_rank
+        self.pp_rank = pp_rank

         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -71,13 +75,17 @@
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            is_draft_model=is_draft_worker,
         )
+
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
+            pp_rank=pp_rank,
+            pp_size=server_args.pp_size,
             nccl_port=nccl_port,
             server_args=server_args,
             is_draft_worker=is_draft_worker,
@@ -104,6 +112,10 @@
         )
         self.device = self.model_runner.device

+        # Init nccl groups
+        self.pp_group = get_pp_group()
+        self.world_group = get_world_group()
+
         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = server_args.max_prefill_tokens
@@ -129,8 +141,9 @@
         # Sync random seed across TP workers
         self.random_seed = broadcast_pyobj(
             [server_args.random_seed],
-            self.tp_rank,
-            self.
+            self.tp_size * self.pp_rank + tp_rank,
+            self.world_group.cpu_group,
+            src=self.world_group.ranks[0],
         )[0]
         set_random_seed(self.random_seed)

@@ -155,11 +168,14 @@
     def get_pad_input_ids_func(self):
         return getattr(self.model_runner.model, "pad_input_ids", None)

-    def
-        return self.model_runner.tp_group
+    def get_tp_group(self):
+        return self.model_runner.tp_group
+
+    def get_attention_tp_group(self):
+        return self.model_runner.attention_tp_group

     def get_attention_tp_cpu_group(self):
-        return self.model_runner.attention_tp_group
+        return getattr(self.model_runner.attention_tp_group, "cpu_group", None)

     def get_memory_pool(self):
         return (
@@ -170,20 +186,39 @@
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
+    ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        if launch_done:
-            launch_done.set()

-        if skip_sample:
-            next_token_ids = None
-        else:
-            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+        pp_proxy_tensors = None
+        if not self.pp_group.is_first_rank:
+            pp_proxy_tensors = PPProxyTensors(
+                self.pp_group.recv_tensor_dict(
+                    all_gather_group=self.get_attention_tp_group()
+                )
+            )

-        return logits_output, next_token_ids
+        if self.pp_group.is_last_rank:
+            logits_output = self.model_runner.forward(
+                forward_batch, pp_proxy_tensors=pp_proxy_tensors
+            )
+            if model_worker_batch.launch_done is not None:
+                model_worker_batch.launch_done.set()
+
+            if skip_sample:
+                next_token_ids = None
+            else:
+                next_token_ids = self.model_runner.sample(
+                    logits_output, model_worker_batch
+                )
+
+            return logits_output, next_token_ids
+        else:
+            pp_proxy_tensors = self.model_runner.forward(
+                forward_batch,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
+            return pp_proxy_tensors.tensors, None

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
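The rewritten `forward_batch_generation` is where pipeline parallelism enters the worker: every rank except the first receives `PPProxyTensors` from the upstream stage, the last rank computes logits and samples, and intermediate ranks return their activations to be sent downstream. A toy sketch of that control flow (the helper functions and the list-based "tensors" are stand-ins for the real model and proxy-tensor machinery):

    from typing import Optional, Tuple

    def embed(tokens: list) -> list: return [float(t) for t in tokens]
    def run_layers(hidden: list) -> list: return [h * 2.0 for h in hidden]
    def sample(logits: list) -> list: return [int(x) % 7 for x in logits]

    def forward_batch_generation(
        pp_rank: int, pp_size: int, tokens: list, recv: Optional[list]
    ) -> Tuple[Optional[list], Optional[list]]:
        # Non-first ranks consume the upstream activations instead of embedding.
        hidden = recv if recv is not None else embed(tokens)
        hidden = run_layers(hidden)
        if pp_rank == pp_size - 1:
            return sample(hidden), None  # last rank: logits -> sampled token ids
        return None, hidden              # other ranks: proxy tensors for the next stage

    # Drive two "stages" sequentially to mimic the send/recv between adjacent ranks.
    _, proxy = forward_batch_generation(0, 2, [1, 2, 3], None)
    ids, _ = forward_batch_generation(1, 2, [], proxy)
    print(ids)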
sglang/srt/managers/tp_worker_overlap_thread.py CHANGED
@@ -56,11 +56,14 @@ class TpModelWorkerClient:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Load the model
-        self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
+        self.worker = TpModelWorker(
+            server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
+        )
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
         self.gpu_id = gpu_id
@@ -91,8 +94,11 @@
     def get_pad_input_ids_func(self):
         return self.worker.get_pad_input_ids_func()

-    def
-        return self.worker.
+    def get_tp_group(self):
+        return self.worker.get_tp_group()
+
+    def get_attention_tp_group(self):
+        return self.worker.get_attention_tp_group()

     def get_attention_tp_cpu_group(self):
         return self.worker.get_attention_tp_cpu_group()
@@ -132,7 +138,6 @@
             batch_pt += 1

         # Create event
-        self.launch_done = threading.Event()
         copy_done = torch.get_device_module(self.device).Event()

         # Resolve future tokens in the input
@@ -141,7 +146,7 @@

         # Run forward
         logits_output, next_token_ids = self.worker.forward_batch_generation(
-            model_worker_batch, self.launch_done
+            model_worker_batch
         )

         # Update the future token ids map
@@ -168,10 +173,16 @@

         self.output_queue.put((copy_done, logits_output, next_token_ids))

-    def
+    def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
+        """
+        This function is called to resolve the last batch result and
+        wait for the current batch to be launched. Used in overlap mode.
+        """
         copy_done, logits_output, next_token_ids = self.output_queue.get()
+
+        if launch_done is not None:
+            launch_done.wait()
         copy_done.synchronize()
-        self.launch_done.wait()

         if logits_output.next_token_logprobs is not None:
             logits_output.next_token_logprobs = (
|
@@ -214,6 +214,8 @@ class MHATokenToKVPool(KVCache):
|
|
214
214
|
layer_num: int,
|
215
215
|
device: str,
|
216
216
|
enable_memory_saver: bool,
|
217
|
+
start_layer: Optional[int] = None,
|
218
|
+
end_layer: Optional[int] = None,
|
217
219
|
):
|
218
220
|
self.size = size
|
219
221
|
self.page_size = page_size
|
@@ -232,6 +234,8 @@ class MHATokenToKVPool(KVCache):
|
|
232
234
|
self.head_dim = head_dim
|
233
235
|
self.layer_num = layer_num
|
234
236
|
self._create_buffers()
|
237
|
+
self.start_layer = start_layer or 0
|
238
|
+
self.end_layer = end_layer or layer_num - 1
|
235
239
|
|
236
240
|
self.layer_transfer_counter = None
|
237
241
|
self.capture_mode = False
|
@@ -281,6 +285,8 @@ class MHATokenToKVPool(KVCache):
|
|
281
285
|
|
282
286
|
# for disagg
|
283
287
|
def get_contiguous_buf_infos(self):
|
288
|
+
# layer_num x [seq_len, head_num, head_dim]
|
289
|
+
# layer_num x [page_num, page_size, head_num, head_dim]
|
284
290
|
kv_data_ptrs = [
|
285
291
|
self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
|
286
292
|
] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
|
@@ -320,24 +326,24 @@ class MHATokenToKVPool(KVCache):
|
|
320
326
|
# transfer prepared data from host to device
|
321
327
|
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
322
328
|
k_data, v_data = flat_data[0], flat_data[1]
|
323
|
-
self.k_buffer[layer_id][indices] = k_data
|
324
|
-
self.v_buffer[layer_id][indices] = v_data
|
329
|
+
self.k_buffer[layer_id - self.start_layer][indices] = k_data
|
330
|
+
self.v_buffer[layer_id - self.start_layer][indices] = v_data
|
325
331
|
|
326
332
|
def get_key_buffer(self, layer_id: int):
|
327
333
|
if self.layer_transfer_counter is not None:
|
328
|
-
self.layer_transfer_counter.wait_until(layer_id)
|
334
|
+
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
329
335
|
|
330
336
|
if self.store_dtype != self.dtype:
|
331
|
-
return self.k_buffer[layer_id].view(self.dtype)
|
332
|
-
return self.k_buffer[layer_id]
|
337
|
+
return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
|
338
|
+
return self.k_buffer[layer_id - self.start_layer]
|
333
339
|
|
334
340
|
def get_value_buffer(self, layer_id: int):
|
335
341
|
if self.layer_transfer_counter is not None:
|
336
|
-
self.layer_transfer_counter.wait_until(layer_id)
|
342
|
+
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
337
343
|
|
338
344
|
if self.store_dtype != self.dtype:
|
339
|
-
return self.v_buffer[layer_id].view(self.dtype)
|
340
|
-
return self.v_buffer[layer_id]
|
345
|
+
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
|
346
|
+
return self.v_buffer[layer_id - self.start_layer]
|
341
347
|
|
342
348
|
def get_kv_buffer(self, layer_id: int):
|
343
349
|
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
@@ -369,12 +375,12 @@ class MHATokenToKVPool(KVCache):
|
|
369
375
|
current_stream = self.device_module.current_stream()
|
370
376
|
self.alt_stream.wait_stream(current_stream)
|
371
377
|
with self.device_module.stream(self.alt_stream):
|
372
|
-
self.k_buffer[layer_id][loc] = cache_k
|
373
|
-
self.v_buffer[layer_id][loc] = cache_v
|
378
|
+
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
379
|
+
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
374
380
|
current_stream.wait_stream(self.alt_stream)
|
375
381
|
else:
|
376
|
-
self.k_buffer[layer_id][loc] = cache_k
|
377
|
-
self.v_buffer[layer_id][loc] = cache_v
|
382
|
+
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
383
|
+
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
378
384
|
|
379
385
|
|
380
386
|
@torch.compile
|
@@ -484,6 +490,8 @@ class MLATokenToKVPool(KVCache):
|
|
484
490
|
layer_num: int,
|
485
491
|
device: str,
|
486
492
|
enable_memory_saver: bool,
|
493
|
+
start_layer: Optional[int] = None,
|
494
|
+
end_layer: Optional[int] = None,
|
487
495
|
):
|
488
496
|
self.size = size
|
489
497
|
self.page_size = page_size
|
@@ -497,6 +505,8 @@ class MLATokenToKVPool(KVCache):
|
|
497
505
|
self.kv_lora_rank = kv_lora_rank
|
498
506
|
self.qk_rope_head_dim = qk_rope_head_dim
|
499
507
|
self.layer_num = layer_num
|
508
|
+
self.start_layer = start_layer or 0
|
509
|
+
self.end_layer = end_layer or layer_num - 1
|
500
510
|
|
501
511
|
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
502
512
|
enable=enable_memory_saver
|
@@ -540,19 +550,21 @@ class MLATokenToKVPool(KVCache):
|
|
540
550
|
|
541
551
|
def get_key_buffer(self, layer_id: int):
|
542
552
|
if self.layer_transfer_counter is not None:
|
543
|
-
self.layer_transfer_counter.wait_until(layer_id)
|
553
|
+
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
544
554
|
|
545
555
|
if self.store_dtype != self.dtype:
|
546
|
-
return self.kv_buffer[layer_id].view(self.dtype)
|
547
|
-
return self.kv_buffer[layer_id]
|
556
|
+
return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
|
557
|
+
return self.kv_buffer[layer_id - self.start_layer]
|
548
558
|
|
549
559
|
def get_value_buffer(self, layer_id: int):
|
550
560
|
if self.layer_transfer_counter is not None:
|
551
|
-
self.layer_transfer_counter.wait_until(layer_id)
|
561
|
+
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
552
562
|
|
553
563
|
if self.store_dtype != self.dtype:
|
554
|
-
return self.kv_buffer[layer_id
|
555
|
-
|
564
|
+
return self.kv_buffer[layer_id - self.start_layer][
|
565
|
+
..., : self.kv_lora_rank
|
566
|
+
].view(self.dtype)
|
567
|
+
return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]
|
556
568
|
|
557
569
|
def get_kv_buffer(self, layer_id: int):
|
558
570
|
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
@@ -568,9 +580,11 @@ class MLATokenToKVPool(KVCache):
|
|
568
580
|
if cache_k.dtype != self.dtype:
|
569
581
|
cache_k = cache_k.to(self.dtype)
|
570
582
|
if self.store_dtype != self.dtype:
|
571
|
-
self.kv_buffer[layer_id][loc] = cache_k.view(
|
583
|
+
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
|
584
|
+
self.store_dtype
|
585
|
+
)
|
572
586
|
else:
|
573
|
-
self.kv_buffer[layer_id][loc] = cache_k
|
587
|
+
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
|
574
588
|
|
575
589
|
def set_mla_kv_buffer(
|
576
590
|
self,
|
@@ -605,7 +619,7 @@ class MLATokenToKVPool(KVCache):
|
|
605
619
|
def transfer_per_layer(self, indices, flat_data, layer_id):
|
606
620
|
# transfer prepared data from host to device
|
607
621
|
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
608
|
-
self.kv_buffer[layer_id][indices] = flat_data
|
622
|
+
self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
|
609
623
|
|
610
624
|
|
611
625
|
class DoubleSparseTokenToKVPool(KVCache):
|
@@ -620,6 +634,8 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|
620
634
|
device: str,
|
621
635
|
heavy_channel_num: int,
|
622
636
|
enable_memory_saver: bool,
|
637
|
+
start_layer: Optional[int] = None,
|
638
|
+
end_layer: Optional[int] = None,
|
623
639
|
):
|
624
640
|
self.size = size
|
625
641
|
self.page_size = page_size
|
@@ -657,17 +673,23 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|
657
673
|
for _ in range(layer_num)
|
658
674
|
]
|
659
675
|
|
676
|
+
self.start_layer = start_layer or 0
|
677
|
+
self.end_layer = end_layer or layer_num - 1
|
678
|
+
|
660
679
|
def get_key_buffer(self, layer_id: int):
|
661
|
-
return self.k_buffer[layer_id]
|
680
|
+
return self.k_buffer[layer_id - self.start_layer]
|
662
681
|
|
663
682
|
def get_value_buffer(self, layer_id: int):
|
664
|
-
return self.v_buffer[layer_id]
|
683
|
+
return self.v_buffer[layer_id - self.start_layer]
|
665
684
|
|
666
685
|
def get_label_buffer(self, layer_id: int):
|
667
|
-
return self.label_buffer[layer_id]
|
686
|
+
return self.label_buffer[layer_id - self.start_layer]
|
668
687
|
|
669
688
|
def get_kv_buffer(self, layer_id: int):
|
670
|
-
return
|
689
|
+
return (
|
690
|
+
self.k_buffer[layer_id - self.start_layer],
|
691
|
+
self.v_buffer[layer_id - self.start_layer],
|
692
|
+
)
|
671
693
|
|
672
694
|
def set_kv_buffer(
|
673
695
|
self,
|
@@ -679,9 +701,9 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|
679
701
|
):
|
680
702
|
# NOTE(Andy): ignore the dtype check
|
681
703
|
layer_id = layer.layer_id
|
682
|
-
self.k_buffer[layer_id][loc] = cache_k
|
683
|
-
self.v_buffer[layer_id][loc] = cache_v
|
684
|
-
self.label_buffer[layer_id][loc] = cache_label
|
704
|
+
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
705
|
+
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
706
|
+
self.label_buffer[layer_id - self.start_layer][loc] = cache_label
|
685
707
|
|
686
708
|
def get_flat_data(self, indices):
|
687
709
|
pass
|
@@ -930,7 +952,7 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
930
952
|
return self.kv_buffer[:, :, indices]
|
931
953
|
|
932
954
|
def get_flat_data_by_layer(self, indices, layer_id):
|
933
|
-
return self.kv_buffer[:, layer_id, indices]
|
955
|
+
return self.kv_buffer[:, layer_id - self.start_layer, indices]
|
934
956
|
|
935
957
|
def assign_flat_data(self, indices, flat_data):
|
936
958
|
self.kv_buffer[:, :, indices] = flat_data
|
@@ -955,12 +977,20 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
955
977
|
for i in range(len(device_indices_cpu)):
|
956
978
|
h_index = host_indices[i * self.page_size]
|
957
979
|
d_index = device_indices_cpu[i]
|
958
|
-
device_pool.k_buffer[layer_id
|
959
|
-
|
980
|
+
device_pool.k_buffer[layer_id - self.start_layer][
|
981
|
+
d_index : d_index + self.page_size
|
982
|
+
].copy_(
|
983
|
+
self.kv_buffer[
|
984
|
+
0, layer_id - self.start_layer, h_index : h_index + self.page_size
|
985
|
+
],
|
960
986
|
non_blocking=True,
|
961
987
|
)
|
962
|
-
device_pool.v_buffer[layer_id
|
963
|
-
|
988
|
+
device_pool.v_buffer[layer_id - self.start_layer][
|
989
|
+
d_index : d_index + self.page_size
|
990
|
+
].copy_(
|
991
|
+
self.kv_buffer[
|
992
|
+
1, layer_id - self.start_layer, h_index : h_index + self.page_size
|
993
|
+
],
|
964
994
|
non_blocking=True,
|
965
995
|
)
|
966
996
|
|
@@ -1015,7 +1045,7 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|
1015
1045
|
return self.kv_buffer[:, indices]
|
1016
1046
|
|
1017
1047
|
def get_flat_data_by_layer(self, indices, layer_id):
|
1018
|
-
return self.kv_buffer[layer_id, indices]
|
1048
|
+
return self.kv_buffer[layer_id - self.start_layer, indices]
|
1019
1049
|
|
1020
1050
|
def assign_flat_data(self, indices, flat_data):
|
1021
1051
|
self.kv_buffer[:, indices] = flat_data
|
@@ -1036,7 +1066,11 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|
1036
1066
|
for i in range(len(device_indices_cpu)):
|
1037
1067
|
h_index = host_indices[i * self.page_size]
|
1038
1068
|
d_index = device_indices_cpu[i]
|
1039
|
-
device_pool.kv_buffer[layer_id
|
1040
|
-
|
1069
|
+
device_pool.kv_buffer[layer_id - self.start_layer][
|
1070
|
+
d_index : d_index + self.page_size
|
1071
|
+
].copy_(
|
1072
|
+
self.kv_buffer[
|
1073
|
+
layer_id - self.start_layer, h_index : h_index + self.page_size
|
1074
|
+
],
|
1041
1075
|
non_blocking=True,
|
1042
1076
|
)
|
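All of the memory-pool changes apply one idea: with pipeline parallelism a KV pool holds buffers only for its stage's layer range [start_layer, end_layer], while callers keep passing global layer ids, so every buffer access subtracts `start_layer`. A minimal sketch of that indexing scheme (hypothetical class, not the sglang API):

    from typing import List, Optional

    class LayerPartitionedKVPool:
        """Holds per-layer KV buffers for one pipeline stage's slice of the model."""

        def __init__(self, layer_num: int, start_layer: Optional[int] = None,
                     end_layer: Optional[int] = None):
            self.start_layer = start_layer or 0
            self.end_layer = end_layer or layer_num - 1
            # Local storage is always indexed 0..layer_num-1.
            self.k_buffer: List[list] = [[] for _ in range(layer_num)]

        def get_key_buffer(self, layer_id: int) -> list:
            # Map a global layer id onto this stage's local buffer index.
            return self.k_buffer[layer_id - self.start_layer]

    # Stage 2 of a 32-layer model owns layers 16..31 (16 local buffers).
    pool = LayerPartitionedKVPool(layer_num=16, start_layer=16, end_layer=31)
    pool.get_key_buffer(20).append("kv")  # global layer 20 -> local index 4
    print(len(pool.k_buffer[4]))          # 1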