sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/disaggregation/decode.py +8 -2
  9. sglang/srt/disaggregation/fake/__init__.py +1 -0
  10. sglang/srt/disaggregation/fake/conn.py +88 -0
  11. sglang/srt/disaggregation/prefill.py +12 -3
  12. sglang/srt/disaggregation/utils.py +16 -2
  13. sglang/srt/entrypoints/engine.py +52 -21
  14. sglang/srt/entrypoints/http_server.py +27 -2
  15. sglang/srt/function_call_parser.py +97 -0
  16. sglang/srt/hf_transformers_utils.py +2 -0
  17. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  18. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  19. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  20. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  21. sglang/srt/layers/attention/utils.py +1 -1
  22. sglang/srt/layers/dp_attention.py +5 -2
  23. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
  41. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  42. sglang/srt/layers/quantization/__init__.py +2 -2
  43. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  44. sglang/srt/layers/quantization/fp8.py +20 -22
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/utils.py +35 -0
  47. sglang/srt/lora/layers.py +35 -9
  48. sglang/srt/lora/lora_manager.py +84 -35
  49. sglang/srt/managers/data_parallel_controller.py +52 -34
  50. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  51. sglang/srt/managers/schedule_batch.py +34 -15
  52. sglang/srt/managers/scheduler.py +273 -67
  53. sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
  54. sglang/srt/managers/tp_worker.py +52 -17
  55. sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
  56. sglang/srt/mem_cache/memory_pool.py +70 -36
  57. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  58. sglang/srt/model_executor/forward_batch_info.py +31 -1
  59. sglang/srt/model_executor/model_runner.py +123 -58
  60. sglang/srt/models/deepseek_nextn.py +1 -257
  61. sglang/srt/models/deepseek_v2.py +78 -18
  62. sglang/srt/models/kimi_vl.py +308 -0
  63. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  64. sglang/srt/models/llama.py +92 -30
  65. sglang/srt/models/llama4.py +2 -1
  66. sglang/srt/models/llama_eagle.py +4 -1
  67. sglang/srt/models/llama_eagle3.py +4 -1
  68. sglang/srt/models/qwen2_moe.py +8 -3
  69. sglang/srt/models/qwen2_vl.py +0 -12
  70. sglang/srt/models/qwen3_moe.py +8 -3
  71. sglang/srt/openai_api/adapter.py +49 -8
  72. sglang/srt/openai_api/protocol.py +13 -1
  73. sglang/srt/reasoning_parser.py +25 -1
  74. sglang/srt/server_args.py +83 -24
  75. sglang/srt/speculative/eagle_worker.py +3 -2
  76. sglang/srt/utils.py +91 -9
  77. sglang/test/runners.py +4 -0
  78. sglang/test/send_one.py +84 -28
  79. sglang/test/test_utils.py +67 -0
  80. sglang/version.py +1 -1
  81. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  82. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
  83. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  84. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  85. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import threading
3
4
  from typing import TYPE_CHECKING, List, Optional, Tuple, Union
4
5
 
5
6
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
11
12
  EmbeddingBatchResult,
12
13
  GenerationBatchResult,
13
14
  ScheduleBatch,
15
+ Scheduler,
14
16
  )
15
17
 
16
18
 
@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
21
23
  """
22
24
 
23
25
  def process_batch_result_prefill(
24
- self,
26
+ self: Scheduler,
25
27
  batch: ScheduleBatch,
26
28
  result: Union[GenerationBatchResult, EmbeddingBatchResult],
29
+ launch_done: Optional[threading.Event] = None,
27
30
  ):
28
31
  skip_stream_req = None
29
32
 
@@ -43,7 +46,11 @@ class SchedulerOutputProcessorMixin:
43
46
  )
44
47
 
45
48
  if self.enable_overlap:
46
- logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
49
+ logits_output, next_token_ids = (
50
+ self.tp_worker.resolve_last_batch_result(
51
+ launch_done,
52
+ )
53
+ )
47
54
  else:
48
55
  # Move next_token_ids and logprobs to cpu
49
56
  next_token_ids = next_token_ids.tolist()
@@ -175,9 +182,10 @@ class SchedulerOutputProcessorMixin:
175
182
  self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
176
183
 
177
184
  def process_batch_result_decode(
178
- self,
185
+ self: Scheduler,
179
186
  batch: ScheduleBatch,
180
187
  result: GenerationBatchResult,
188
+ launch_done: Optional[threading.Event] = None,
181
189
  ):
182
190
  logits_output, next_token_ids, bid = (
183
191
  result.logits_output,
@@ -187,7 +195,9 @@ class SchedulerOutputProcessorMixin:
187
195
  self.num_generated_tokens += len(batch.reqs)
188
196
 
189
197
  if self.enable_overlap:
190
- logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
198
+ logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
199
+ launch_done
200
+ )
191
201
  next_token_logprobs = logits_output.next_token_logprobs
192
202
  elif batch.spec_algorithm.is_none():
193
203
  # spec decoding handles output logprobs inside verify process.
@@ -268,10 +278,10 @@ class SchedulerOutputProcessorMixin:
268
278
  self.attn_tp_rank == 0
269
279
  and self.forward_ct_decode % self.server_args.decode_log_interval == 0
270
280
  ):
271
- self.log_decode_stats()
281
+ self.log_decode_stats(running_batch=batch)
272
282
 
273
283
  def add_input_logprob_return_values(
274
- self,
284
+ self: Scheduler,
275
285
  i: int,
276
286
  req: Req,
277
287
  output: LogitsProcessorOutput,
@@ -405,7 +415,7 @@ class SchedulerOutputProcessorMixin:
405
415
  assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
406
416
 
407
417
  def add_logprob_return_values(
408
- self,
418
+ self: Scheduler,
409
419
  i: int,
410
420
  req: Req,
411
421
  pt: int,
@@ -436,7 +446,10 @@ class SchedulerOutputProcessorMixin:
436
446
  return num_input_logprobs
437
447
 
438
448
  def stream_output(
439
- self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
449
+ self: Scheduler,
450
+ reqs: List[Req],
451
+ return_logprob: bool,
452
+ skip_req: Optional[Req] = None,
440
453
  ):
441
454
  """Stream the output to detokenizer."""
442
455
  if self.is_generation:
@@ -445,7 +458,10 @@ class SchedulerOutputProcessorMixin:
445
458
  self.stream_output_embedding(reqs)
446
459
 
447
460
  def stream_output_generation(
448
- self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
461
+ self: Scheduler,
462
+ reqs: List[Req],
463
+ return_logprob: bool,
464
+ skip_req: Optional[Req] = None,
449
465
  ):
450
466
  rids = []
451
467
  finished_reasons: List[BaseFinishReason] = []
@@ -593,7 +609,7 @@ class SchedulerOutputProcessorMixin:
593
609
  )
594
610
  )
595
611
 
596
- def stream_output_embedding(self, reqs: List[Req]):
612
+ def stream_output_embedding(self: Scheduler, reqs: List[Req]):
597
613
  rids = []
598
614
  finished_reasons: List[BaseFinishReason] = []
599
615
 
@@ -15,11 +15,12 @@
15
15
 
16
16
  import logging
17
17
  import threading
18
- from typing import Optional, Tuple
18
+ from typing import Optional, Tuple, Union
19
19
 
20
20
  import torch
21
21
 
22
22
  from sglang.srt.configs.model_config import ModelConfig
23
+ from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
23
24
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
24
25
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
25
26
  from sglang.srt.managers.io_struct import (
@@ -31,7 +32,7 @@ from sglang.srt.managers.io_struct import (
31
32
  )
32
33
  from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
33
34
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
34
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
35
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
35
36
  from sglang.srt.model_executor.model_runner import ModelRunner
36
37
  from sglang.srt.server_args import ServerArgs
37
38
  from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -47,6 +48,7 @@ class TpModelWorker:
47
48
  server_args: ServerArgs,
48
49
  gpu_id: int,
49
50
  tp_rank: int,
51
+ pp_rank: int,
50
52
  dp_rank: Optional[int],
51
53
  nccl_port: int,
52
54
  is_draft_worker: bool = False,
@@ -54,7 +56,9 @@ class TpModelWorker:
54
56
  token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
55
57
  ):
56
58
  # Parse args
59
+ self.tp_size = server_args.tp_size
57
60
  self.tp_rank = tp_rank
61
+ self.pp_rank = pp_rank
58
62
 
59
63
  # Init model and tokenizer
60
64
  self.model_config = ModelConfig(
@@ -71,13 +75,17 @@ class TpModelWorker:
71
75
  enable_multimodal=server_args.enable_multimodal,
72
76
  dtype=server_args.dtype,
73
77
  quantization=server_args.quantization,
78
+ is_draft_model=is_draft_worker,
74
79
  )
80
+
75
81
  self.model_runner = ModelRunner(
76
82
  model_config=self.model_config,
77
83
  mem_fraction_static=server_args.mem_fraction_static,
78
84
  gpu_id=gpu_id,
79
85
  tp_rank=tp_rank,
80
86
  tp_size=server_args.tp_size,
87
+ pp_rank=pp_rank,
88
+ pp_size=server_args.pp_size,
81
89
  nccl_port=nccl_port,
82
90
  server_args=server_args,
83
91
  is_draft_worker=is_draft_worker,
@@ -104,6 +112,10 @@ class TpModelWorker:
104
112
  )
105
113
  self.device = self.model_runner.device
106
114
 
115
+ # Init nccl groups
116
+ self.pp_group = get_pp_group()
117
+ self.world_group = get_world_group()
118
+
107
119
  # Profile number of tokens
108
120
  self.max_total_num_tokens = self.model_runner.max_total_num_tokens
109
121
  self.max_prefill_tokens = server_args.max_prefill_tokens
@@ -129,8 +141,9 @@ class TpModelWorker:
129
141
  # Sync random seed across TP workers
130
142
  self.random_seed = broadcast_pyobj(
131
143
  [server_args.random_seed],
132
- self.tp_rank,
133
- self.model_runner.tp_group.cpu_group,
144
+ self.tp_size * self.pp_rank + tp_rank,
145
+ self.world_group.cpu_group,
146
+ src=self.world_group.ranks[0],
134
147
  )[0]
135
148
  set_random_seed(self.random_seed)
136
149
 
@@ -155,11 +168,14 @@ class TpModelWorker:
155
168
  def get_pad_input_ids_func(self):
156
169
  return getattr(self.model_runner.model, "pad_input_ids", None)
157
170
 
158
- def get_tp_cpu_group(self):
159
- return self.model_runner.tp_group.cpu_group
171
+ def get_tp_group(self):
172
+ return self.model_runner.tp_group
173
+
174
+ def get_attention_tp_group(self):
175
+ return self.model_runner.attention_tp_group
160
176
 
161
177
  def get_attention_tp_cpu_group(self):
162
- return self.model_runner.attention_tp_group.cpu_group
178
+ return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
163
179
 
164
180
  def get_memory_pool(self):
165
181
  return (
@@ -170,20 +186,39 @@ class TpModelWorker:
170
186
  def forward_batch_generation(
171
187
  self,
172
188
  model_worker_batch: ModelWorkerBatch,
173
- launch_done: Optional[threading.Event] = None,
174
189
  skip_sample: bool = False,
175
- ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
190
+ ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
176
191
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
177
- logits_output = self.model_runner.forward(forward_batch)
178
- if launch_done:
179
- launch_done.set()
180
192
 
181
- if skip_sample:
182
- next_token_ids = None
183
- else:
184
- next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
193
+ pp_proxy_tensors = None
194
+ if not self.pp_group.is_first_rank:
195
+ pp_proxy_tensors = PPProxyTensors(
196
+ self.pp_group.recv_tensor_dict(
197
+ all_gather_group=self.get_attention_tp_group()
198
+ )
199
+ )
185
200
 
186
- return logits_output, next_token_ids
201
+ if self.pp_group.is_last_rank:
202
+ logits_output = self.model_runner.forward(
203
+ forward_batch, pp_proxy_tensors=pp_proxy_tensors
204
+ )
205
+ if model_worker_batch.launch_done is not None:
206
+ model_worker_batch.launch_done.set()
207
+
208
+ if skip_sample:
209
+ next_token_ids = None
210
+ else:
211
+ next_token_ids = self.model_runner.sample(
212
+ logits_output, model_worker_batch
213
+ )
214
+
215
+ return logits_output, next_token_ids
216
+ else:
217
+ pp_proxy_tensors = self.model_runner.forward(
218
+ forward_batch,
219
+ pp_proxy_tensors=pp_proxy_tensors,
220
+ )
221
+ return pp_proxy_tensors.tensors, None
187
222
 
188
223
  def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
189
224
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
@@ -56,11 +56,14 @@ class TpModelWorkerClient:
56
56
  server_args: ServerArgs,
57
57
  gpu_id: int,
58
58
  tp_rank: int,
59
+ pp_rank: int,
59
60
  dp_rank: Optional[int],
60
61
  nccl_port: int,
61
62
  ):
62
63
  # Load the model
63
- self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
64
+ self.worker = TpModelWorker(
65
+ server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
66
+ )
64
67
  self.max_running_requests = self.worker.max_running_requests
65
68
  self.device = self.worker.device
66
69
  self.gpu_id = gpu_id
@@ -91,8 +94,11 @@ class TpModelWorkerClient:
91
94
  def get_pad_input_ids_func(self):
92
95
  return self.worker.get_pad_input_ids_func()
93
96
 
94
- def get_tp_cpu_group(self):
95
- return self.worker.get_tp_cpu_group()
97
+ def get_tp_group(self):
98
+ return self.worker.get_tp_group()
99
+
100
+ def get_attention_tp_group(self):
101
+ return self.worker.get_attention_tp_group()
96
102
 
97
103
  def get_attention_tp_cpu_group(self):
98
104
  return self.worker.get_attention_tp_cpu_group()
@@ -132,7 +138,6 @@ class TpModelWorkerClient:
132
138
  batch_pt += 1
133
139
 
134
140
  # Create event
135
- self.launch_done = threading.Event()
136
141
  copy_done = torch.get_device_module(self.device).Event()
137
142
 
138
143
  # Resolve future tokens in the input
@@ -141,7 +146,7 @@ class TpModelWorkerClient:
141
146
 
142
147
  # Run forward
143
148
  logits_output, next_token_ids = self.worker.forward_batch_generation(
144
- model_worker_batch, self.launch_done
149
+ model_worker_batch
145
150
  )
146
151
 
147
152
  # Update the future token ids map
@@ -168,10 +173,16 @@ class TpModelWorkerClient:
168
173
 
169
174
  self.output_queue.put((copy_done, logits_output, next_token_ids))
170
175
 
171
- def resolve_batch_result(self, bid: int):
176
+ def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
177
+ """
178
+ This function is called to resolve the last batch result and
179
+ wait for the current batch to be launched. Used in overlap mode.
180
+ """
172
181
  copy_done, logits_output, next_token_ids = self.output_queue.get()
182
+
183
+ if launch_done is not None:
184
+ launch_done.wait()
173
185
  copy_done.synchronize()
174
- self.launch_done.wait()
175
186
 
176
187
  if logits_output.next_token_logprobs is not None:
177
188
  logits_output.next_token_logprobs = (
@@ -214,6 +214,8 @@ class MHATokenToKVPool(KVCache):
214
214
  layer_num: int,
215
215
  device: str,
216
216
  enable_memory_saver: bool,
217
+ start_layer: Optional[int] = None,
218
+ end_layer: Optional[int] = None,
217
219
  ):
218
220
  self.size = size
219
221
  self.page_size = page_size
@@ -232,6 +234,8 @@ class MHATokenToKVPool(KVCache):
232
234
  self.head_dim = head_dim
233
235
  self.layer_num = layer_num
234
236
  self._create_buffers()
237
+ self.start_layer = start_layer or 0
238
+ self.end_layer = end_layer or layer_num - 1
235
239
 
236
240
  self.layer_transfer_counter = None
237
241
  self.capture_mode = False
@@ -281,6 +285,8 @@ class MHATokenToKVPool(KVCache):
281
285
 
282
286
  # for disagg
283
287
  def get_contiguous_buf_infos(self):
288
+ # layer_num x [seq_len, head_num, head_dim]
289
+ # layer_num x [page_num, page_size, head_num, head_dim]
284
290
  kv_data_ptrs = [
285
291
  self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
286
292
  ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
@@ -320,24 +326,24 @@ class MHATokenToKVPool(KVCache):
320
326
  # transfer prepared data from host to device
321
327
  flat_data = flat_data.to(device=self.device, non_blocking=False)
322
328
  k_data, v_data = flat_data[0], flat_data[1]
323
- self.k_buffer[layer_id][indices] = k_data
324
- self.v_buffer[layer_id][indices] = v_data
329
+ self.k_buffer[layer_id - self.start_layer][indices] = k_data
330
+ self.v_buffer[layer_id - self.start_layer][indices] = v_data
325
331
 
326
332
  def get_key_buffer(self, layer_id: int):
327
333
  if self.layer_transfer_counter is not None:
328
- self.layer_transfer_counter.wait_until(layer_id)
334
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
329
335
 
330
336
  if self.store_dtype != self.dtype:
331
- return self.k_buffer[layer_id].view(self.dtype)
332
- return self.k_buffer[layer_id]
337
+ return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
338
+ return self.k_buffer[layer_id - self.start_layer]
333
339
 
334
340
  def get_value_buffer(self, layer_id: int):
335
341
  if self.layer_transfer_counter is not None:
336
- self.layer_transfer_counter.wait_until(layer_id)
342
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
337
343
 
338
344
  if self.store_dtype != self.dtype:
339
- return self.v_buffer[layer_id].view(self.dtype)
340
- return self.v_buffer[layer_id]
345
+ return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
346
+ return self.v_buffer[layer_id - self.start_layer]
341
347
 
342
348
  def get_kv_buffer(self, layer_id: int):
343
349
  return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -369,12 +375,12 @@ class MHATokenToKVPool(KVCache):
369
375
  current_stream = self.device_module.current_stream()
370
376
  self.alt_stream.wait_stream(current_stream)
371
377
  with self.device_module.stream(self.alt_stream):
372
- self.k_buffer[layer_id][loc] = cache_k
373
- self.v_buffer[layer_id][loc] = cache_v
378
+ self.k_buffer[layer_id - self.start_layer][loc] = cache_k
379
+ self.v_buffer[layer_id - self.start_layer][loc] = cache_v
374
380
  current_stream.wait_stream(self.alt_stream)
375
381
  else:
376
- self.k_buffer[layer_id][loc] = cache_k
377
- self.v_buffer[layer_id][loc] = cache_v
382
+ self.k_buffer[layer_id - self.start_layer][loc] = cache_k
383
+ self.v_buffer[layer_id - self.start_layer][loc] = cache_v
378
384
 
379
385
 
380
386
  @torch.compile
@@ -484,6 +490,8 @@ class MLATokenToKVPool(KVCache):
484
490
  layer_num: int,
485
491
  device: str,
486
492
  enable_memory_saver: bool,
493
+ start_layer: Optional[int] = None,
494
+ end_layer: Optional[int] = None,
487
495
  ):
488
496
  self.size = size
489
497
  self.page_size = page_size
@@ -497,6 +505,8 @@ class MLATokenToKVPool(KVCache):
497
505
  self.kv_lora_rank = kv_lora_rank
498
506
  self.qk_rope_head_dim = qk_rope_head_dim
499
507
  self.layer_num = layer_num
508
+ self.start_layer = start_layer or 0
509
+ self.end_layer = end_layer or layer_num - 1
500
510
 
501
511
  memory_saver_adapter = TorchMemorySaverAdapter.create(
502
512
  enable=enable_memory_saver
@@ -540,19 +550,21 @@ class MLATokenToKVPool(KVCache):
540
550
 
541
551
  def get_key_buffer(self, layer_id: int):
542
552
  if self.layer_transfer_counter is not None:
543
- self.layer_transfer_counter.wait_until(layer_id)
553
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
544
554
 
545
555
  if self.store_dtype != self.dtype:
546
- return self.kv_buffer[layer_id].view(self.dtype)
547
- return self.kv_buffer[layer_id]
556
+ return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
557
+ return self.kv_buffer[layer_id - self.start_layer]
548
558
 
549
559
  def get_value_buffer(self, layer_id: int):
550
560
  if self.layer_transfer_counter is not None:
551
- self.layer_transfer_counter.wait_until(layer_id)
561
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
552
562
 
553
563
  if self.store_dtype != self.dtype:
554
- return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
555
- return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
564
+ return self.kv_buffer[layer_id - self.start_layer][
565
+ ..., : self.kv_lora_rank
566
+ ].view(self.dtype)
567
+ return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]
556
568
 
557
569
  def get_kv_buffer(self, layer_id: int):
558
570
  return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -568,9 +580,11 @@ class MLATokenToKVPool(KVCache):
568
580
  if cache_k.dtype != self.dtype:
569
581
  cache_k = cache_k.to(self.dtype)
570
582
  if self.store_dtype != self.dtype:
571
- self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
583
+ self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
584
+ self.store_dtype
585
+ )
572
586
  else:
573
- self.kv_buffer[layer_id][loc] = cache_k
587
+ self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
574
588
 
575
589
  def set_mla_kv_buffer(
576
590
  self,
@@ -605,7 +619,7 @@ class MLATokenToKVPool(KVCache):
605
619
  def transfer_per_layer(self, indices, flat_data, layer_id):
606
620
  # transfer prepared data from host to device
607
621
  flat_data = flat_data.to(device=self.device, non_blocking=False)
608
- self.kv_buffer[layer_id][indices] = flat_data
622
+ self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
609
623
 
610
624
 
611
625
  class DoubleSparseTokenToKVPool(KVCache):
@@ -620,6 +634,8 @@ class DoubleSparseTokenToKVPool(KVCache):
620
634
  device: str,
621
635
  heavy_channel_num: int,
622
636
  enable_memory_saver: bool,
637
+ start_layer: Optional[int] = None,
638
+ end_layer: Optional[int] = None,
623
639
  ):
624
640
  self.size = size
625
641
  self.page_size = page_size
@@ -657,17 +673,23 @@ class DoubleSparseTokenToKVPool(KVCache):
657
673
  for _ in range(layer_num)
658
674
  ]
659
675
 
676
+ self.start_layer = start_layer or 0
677
+ self.end_layer = end_layer or layer_num - 1
678
+
660
679
  def get_key_buffer(self, layer_id: int):
661
- return self.k_buffer[layer_id]
680
+ return self.k_buffer[layer_id - self.start_layer]
662
681
 
663
682
  def get_value_buffer(self, layer_id: int):
664
- return self.v_buffer[layer_id]
683
+ return self.v_buffer[layer_id - self.start_layer]
665
684
 
666
685
  def get_label_buffer(self, layer_id: int):
667
- return self.label_buffer[layer_id]
686
+ return self.label_buffer[layer_id - self.start_layer]
668
687
 
669
688
  def get_kv_buffer(self, layer_id: int):
670
- return self.k_buffer[layer_id], self.v_buffer[layer_id]
689
+ return (
690
+ self.k_buffer[layer_id - self.start_layer],
691
+ self.v_buffer[layer_id - self.start_layer],
692
+ )
671
693
 
672
694
  def set_kv_buffer(
673
695
  self,
@@ -679,9 +701,9 @@ class DoubleSparseTokenToKVPool(KVCache):
679
701
  ):
680
702
  # NOTE(Andy): ignore the dtype check
681
703
  layer_id = layer.layer_id
682
- self.k_buffer[layer_id][loc] = cache_k
683
- self.v_buffer[layer_id][loc] = cache_v
684
- self.label_buffer[layer_id][loc] = cache_label
704
+ self.k_buffer[layer_id - self.start_layer][loc] = cache_k
705
+ self.v_buffer[layer_id - self.start_layer][loc] = cache_v
706
+ self.label_buffer[layer_id - self.start_layer][loc] = cache_label
685
707
 
686
708
  def get_flat_data(self, indices):
687
709
  pass
@@ -930,7 +952,7 @@ class MHATokenToKVPoolHost(HostKVCache):
930
952
  return self.kv_buffer[:, :, indices]
931
953
 
932
954
  def get_flat_data_by_layer(self, indices, layer_id):
933
- return self.kv_buffer[:, layer_id, indices]
955
+ return self.kv_buffer[:, layer_id - self.start_layer, indices]
934
956
 
935
957
  def assign_flat_data(self, indices, flat_data):
936
958
  self.kv_buffer[:, :, indices] = flat_data
@@ -955,12 +977,20 @@ class MHATokenToKVPoolHost(HostKVCache):
955
977
  for i in range(len(device_indices_cpu)):
956
978
  h_index = host_indices[i * self.page_size]
957
979
  d_index = device_indices_cpu[i]
958
- device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
959
- self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
980
+ device_pool.k_buffer[layer_id - self.start_layer][
981
+ d_index : d_index + self.page_size
982
+ ].copy_(
983
+ self.kv_buffer[
984
+ 0, layer_id - self.start_layer, h_index : h_index + self.page_size
985
+ ],
960
986
  non_blocking=True,
961
987
  )
962
- device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
963
- self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
988
+ device_pool.v_buffer[layer_id - self.start_layer][
989
+ d_index : d_index + self.page_size
990
+ ].copy_(
991
+ self.kv_buffer[
992
+ 1, layer_id - self.start_layer, h_index : h_index + self.page_size
993
+ ],
964
994
  non_blocking=True,
965
995
  )
966
996
 
@@ -1015,7 +1045,7 @@ class MLATokenToKVPoolHost(HostKVCache):
1015
1045
  return self.kv_buffer[:, indices]
1016
1046
 
1017
1047
  def get_flat_data_by_layer(self, indices, layer_id):
1018
- return self.kv_buffer[layer_id, indices]
1048
+ return self.kv_buffer[layer_id - self.start_layer, indices]
1019
1049
 
1020
1050
  def assign_flat_data(self, indices, flat_data):
1021
1051
  self.kv_buffer[:, indices] = flat_data
@@ -1036,7 +1066,11 @@ class MLATokenToKVPoolHost(HostKVCache):
1036
1066
  for i in range(len(device_indices_cpu)):
1037
1067
  h_index = host_indices[i * self.page_size]
1038
1068
  d_index = device_indices_cpu[i]
1039
- device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
1040
- self.kv_buffer[layer_id, h_index : h_index + self.page_size],
1069
+ device_pool.kv_buffer[layer_id - self.start_layer][
1070
+ d_index : d_index + self.page_size
1071
+ ].copy_(
1072
+ self.kv_buffer[
1073
+ layer_id - self.start_layer, h_index : h_index + self.page_size
1074
+ ],
1041
1075
  non_blocking=True,
1042
1076
  )